gwinhen / moth

This is the implementation for IEEE S&P 2022 paper "Model Orthogonalization: Class Distance Hardening in Neural Networks for Better Security."

License: MIT License

moth's Introduction

Model Orthogonalization: Class Distance Hardening in Neural Networks for Better Security

Prerequisite

The code is implemented and tested in both Keras (with the TensorFlow backend) and PyTorch, using Python 3.6.9.

Keras Version

  • Keras 2.3.0
  • Tensorflow 1.14.0

PyTorch Version

  • PyTorch 1.7.0

Usage

The main functions are located in the src/main.py file.

Model Orthogonalization

To harden a model using MOTH, please use the following command:

python3 src/main.py --phase moth

The default dataset and model are CIFAR-10 and ResNet20. You can harden different model structures on other datasets by passing the arguments --dataset [dataset] and --network [model structure]. We have included four datasets (CIFAR-10, SVHN, LISA, and GTSRB) and four model structures (ResNet, VGG19, NiN, and CNN). (The datasets will be uploaded soon.)

To measure the pair-wise class distance, please run:

python3 src/main.py --phase validate --suffix [suffix of checkpoint] --seed [seed id]

Models hardened by MOTH are saved with the suffix _moth appended to the original checkpoint path. Pass the checkpoint suffix through the --suffix argument. The distance should be measured with three different random seeds, running the command separately with --seed 0, 1, and 2.
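The three validation runs and the final summary can be scripted. The sketch below is a hypothetical helper (not part of the repository); it assumes it is launched from the repository root and that the hardened checkpoint is selected with the suffix string `_moth`. It only uses the `--phase`, `--suffix`, and `--seed` arguments documented above.

```python
import subprocess


def build_commands(suffix, seeds=(0, 1, 2)):
    """Return one validate command per seed, followed by the show command."""
    base = ["python3", "src/main.py"]
    commands = [
        base + ["--phase", "validate", "--suffix", suffix, "--seed", str(seed)]
        for seed in seeds
    ]
    commands.append(base + ["--phase", "show", "--suffix", suffix])
    return commands


def run_all(suffix):
    """Run every command in order, stopping at the first failure."""
    for cmd in build_commands(suffix):
        print("Running:", " ".join(cmd))
        # check=True raises CalledProcessError instead of silently continuing
        subprocess.run(cmd, check=True)
```

Calling `run_all("_moth")` would then perform the three seeded measurements and print the summary matrix. Running the seeds sequentially keeps the sketch simple; they could also be launched in parallel if GPU memory allows.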

The final pair-wise class distance of the evaluated model can then be obtained with the following command:

python3 src/main.py --phase show --suffix [suffix of checkpoint]

It prints a matrix of the class distances for all label pairs, where each row denotes the source label and each column the target label. The average distance and the relative enlargement are printed at the end.
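To make the summary numbers concrete, here is a minimal pure-Python sketch of how they could be derived from per-seed distance matrices. It assumes that the per-seed matrices are averaged element-wise, that the diagonal (a class paired with itself) is excluded from the average, and that relative enlargement means the increase in average distance of the hardened model over the original model, divided by the original average. These are assumptions for illustration, not a description of what the show phase actually computes.

```python
def mean_matrix(matrices):
    """Element-wise mean of equally sized square matrices (e.g. one per seed)."""
    n = len(matrices[0])
    return [
        [sum(m[i][j] for m in matrices) / len(matrices) for j in range(n)]
        for i in range(n)
    ]


def average_distance(matrix):
    """Mean over all source/target label pairs, skipping the diagonal."""
    n = len(matrix)
    values = [matrix[i][j] for i in range(n) for j in range(n) if i != j]
    return sum(values) / len(values)


def relative_enlargement(original, hardened):
    """Fractional increase of the average distance after hardening (assumed definition)."""
    base = average_distance(original)
    return (average_distance(hardened) - base) / base
```

For example, an original matrix `[[0, 2], [4, 0]]` has an average distance of 3.0; a hardened matrix `[[0, 3], [6, 0]]` has 4.5, i.e. a relative enlargement of 0.5 (50%).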

Model Functionality

To test the accuracy of a model, simply run:

python3 src/main.py --phase test --suffix [suffix of checkpoint]

The robustness of a given model can be evaluated using PGD with the following command:

python3 src/main.py --phase measure --suffix [suffix of checkpoint]
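For readers unfamiliar with PGD, the sketch below shows the untargeted L-infinity variant of the attack against a toy binary logistic-regression model whose input gradient is computed analytically, so no deep learning framework is needed. The toy model, the step size `alpha`, the budget `eps`, the number of steps, and the [0, 1] input range are illustrative assumptions; the repository's PGD code (adapted from cifar10_challenge) attacks the real networks and may use different settings, such as a random start.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def loss_grad_wrt_input(w, b, x, y):
    """Gradient of -log(sigmoid(y * (w.x + b))) with respect to x, for y in {-1, +1}."""
    margin = y * (sum(wi * xi for wi, xi in zip(w, x)) + b)
    coeff = -y * (1.0 - sigmoid(margin))
    return [coeff * wi for wi in w]


def sign(v):
    return (v > 0) - (v < 0)


def pgd_linf(w, b, x, y, eps=0.1, alpha=0.02, steps=10):
    """Untargeted L-inf PGD: signed gradient ascent on the loss plus projection."""
    x_adv = list(x)
    for _ in range(steps):
        grad = loss_grad_wrt_input(w, b, x_adv, y)
        # Take a signed gradient step that increases the loss
        x_adv = [xa + alpha * sign(g) for xa, g in zip(x_adv, grad)]
        # Project back onto the eps-ball around the clean input
        x_adv = [min(max(xa, xi - eps), xi + eps) for xa, xi in zip(x_adv, x)]
        # Keep the result inside the valid input range
        x_adv = [min(max(xa, 0.0), 1.0) for xa in x_adv]
    return x_adv
```

The projection after every step is what distinguishes PGD from a single FGSM step: the iterate may move in many small steps, but it never leaves the allowed perturbation set.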

Acknowledgement

The trigger inversion code is inspired by Neural Cleanse.

The PGD code is adapted from cifar10_challenge.

Thanks for their amazing implementations.

Reference

If you use this code for any purpose, please cite:

@inproceedings{tao2022model,
  title={Model Orthogonalization: Class Distance Hardening in Neural Networks for Better Security},
  author={Tao, Guanhong and Liu, Yingqi and Shen, Guangyu and Xu, Qiuling and An, Shengwei and Zhang, Zhuo and Zhang, Xiangyu},
  booktitle={2022 IEEE Symposium on Security and Privacy (SP)},
  year={2022},
  organization={IEEE}
}

moth's Issues

possibly model.train() forgotten

Hi,
Here, the model is supposed to be trained.

Even though model.train() is called on line 79, trigger.generate() and TriggerCombo.generate() are called later and in turn call model.eval(), but model.train() is not called again before the optimization step.

I think it might be an issue, as it raises an error.
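One way to avoid the problem described in this issue is to restore the model's previous mode after any helper that switches it to evaluation mode. The sketch below is a hypothetical context manager, not code from the repository. It relies only on the `training` attribute and the `train(mode)` / `eval()` methods that PyTorch's `torch.nn.Module` provides, so it is demonstrated here with a stand-in class instead of a real network.

```python
from contextlib import contextmanager


@contextmanager
def eval_mode(model):
    """Temporarily put `model` in eval mode and restore its previous mode afterwards."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        # Restore the caller's mode, even if the body raised an exception
        model.train(was_training)


class DummyModule:
    """Stand-in exposing the same mode API as torch.nn.Module."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)
```

Wrapping the trigger generation calls mentioned above in `with eval_mode(model): ...` would leave the model in training mode for the subsequent optimization step. Simply calling `model.train()` again right after those calls is an equally valid, more local fix.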

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.