tokaka22 / adversarial-distributional-training

This project is a fork of dongyp13/adversarial-distributional-training.


Adversarial Distributional Training (NeurIPS 2020)

License: MIT License



Adversarial Distributional Training

This repository contains the code for adversarial distributional training (ADT), introduced in the following paper:

Adversarial Distributional Training for Robust Deep Learning (NeurIPS 2020)

Yinpeng Dong*, Zhijie Deng*, Tianyu Pang, Hang Su, and Jun Zhu (* indicates equal contribution)

Citation

If you find our methods useful, please consider citing:

@inproceedings{dong2020adversarial,
  title={Adversarial Distributional Training for Robust Deep Learning},
  author={Dong, Yinpeng and Deng, Zhijie and Pang, Tianyu and Su, Hang and Zhu, Jun},
  booktitle={Advances in Neural Information Processing Systems},
  year={2020}
}

Introduction

Adversarial distributional training (ADT) is a framework for training robust deep learning models. It is formulated as a minimax optimization problem: the inner maximization finds, for each natural input, an adversarial distribution that characterizes its potential adversarial examples, and the outer minimization optimizes the DNN parameters under these worst-case adversarial distributions.
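The minimax problem described above can be written roughly as follows. This is a sketch in our own notation (θ for the DNN parameters, 𝒟 for the data distribution, 𝒮 for an ℓ∞ ball of radius ε around each input); the paper's exact objective may contain additional terms, for example a regularizer on p(δ).

```latex
\min_{\theta}\;
\mathbb{E}_{(x,y)\sim\mathcal{D}}
\Big[
  \max_{p(\delta)\in\mathcal{P}}\;
  \mathbb{E}_{\delta\sim p(\delta)}
  \big[\mathcal{L}\big(f_{\theta}(x+\delta),\, y\big)\big]
\Big],
\qquad
\mathcal{P}=\big\{p(\delta) : \operatorname{supp}(p)\subseteq\mathcal{S}\big\},
\quad
\mathcal{S}=\{\delta : \|\delta\|_{\infty}\le\epsilon\}
```

The inner max ranges over distributions rather than single perturbations; standard adversarial training is recovered when p(δ) is restricted to a point mass.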

In the paper, we propose three approaches to parameterizing the adversarial distributions, as illustrated in Figure 1.

Figure 1: An illustration of the three ADT methods: (a) ADT_EXP; (b) ADT_EXP-AM; (c) ADT_IMP-AM.
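For intuition about the explicit approach (a), the NumPy sketch below runs only the inner maximization for a single input. Everything in it is an assumption for illustration, not the adt_exp.py implementation: a fixed toy logistic-regression classifier stands in for the DNN, the distribution is a Gaussian over a pre-activation u = mu + sigma * z mapped into the ε-ball by eps * tanh(u), and (mu, log sigma) are updated by Monte Carlo gradient ascent using the reparameterization trick. The real method also trains the classifier in the outer loop.

```python
import numpy as np

# Hypothetical toy setting: a fixed linear classifier w, one input x, label y in {-1, +1}.


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def logistic_loss(w, x_adv, y):
    # log(1 + exp(-y * <w, x_adv>)), computed stably; works row-wise on a batch
    return np.logaddexp(0.0, -y * (x_adv @ w))


def sample_delta(mu, sigma, eps, z):
    # Reparameterized draw: u = mu + sigma * z with z ~ N(0, I),
    # then eps * tanh(u) keeps every coordinate of delta inside [-eps, eps].
    u = mu + sigma * z
    return eps * np.tanh(u), u


def ascend_distribution(w, x, y, eps, steps=50, lr=0.5, n_samples=64, seed=0):
    """Gradient ascent on (mu, log_sigma) to raise E_{delta ~ p}[L(x + delta, y)]."""
    rng = np.random.default_rng(seed)
    d = x.shape[0]
    mu, log_sigma = np.zeros(d), np.zeros(d)  # log-parameterization keeps sigma > 0
    for _ in range(steps):
        z = rng.standard_normal((n_samples, d))
        sigma = np.exp(log_sigma)
        delta, u = sample_delta(mu, sigma, eps, z)
        margin = y * ((x + delta) @ w)                           # shape (n_samples,)
        g_delta = (-y * sigmoid(-margin))[:, None] * w[None, :]  # dL/d delta
        g_u = g_delta * eps * (1.0 - np.tanh(u) ** 2)           # chain rule through eps * tanh
        mu += lr * g_u.mean(axis=0)                              # du/dmu = 1
        log_sigma += lr * (g_u * z * sigma).mean(axis=0)         # du/dlog_sigma = sigma * z
    return mu, np.exp(log_sigma)


def expected_loss(w, x, y, mu, sigma, eps, n_samples=2000, seed=1):
    # Monte Carlo estimate of E_{delta ~ p}[L(x + delta, y)]
    z = np.random.default_rng(seed).standard_normal((n_samples, x.shape[0]))
    delta, _ = sample_delta(mu, sigma, eps, z)
    return float(logistic_loss(w, x + delta, y).mean())


if __name__ == "__main__":
    w = np.array([1.0, -2.0, 0.5])
    x, y, eps = np.array([0.3, -0.4, 0.2]), 1.0, 0.3
    mu, sigma = ascend_distribution(w, x, y, eps)
    print("initial p:", expected_loss(w, x, y, np.zeros(3), np.ones(3), eps))
    print("after ascent:", expected_loss(w, x, y, mu, sigma, eps))
```

The tanh squashing is what lets an unconstrained Gaussian be used while every sample still respects the ℓ∞ budget; the amortized variants (b) and (c) would instead produce the distribution with a network conditioned on the input.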

Prerequisites

  • Python (3.6.8)
  • PyTorch (1.3.0)
  • torchvision (0.4.1)
  • numpy

Training

Each of the three ADT methods has its own training script. The command for each is given below.

Training ADT_EXP

python adt_exp.py --model-dir adt-exp --dataset cifar10  # or cifar100 / svhn

Training ADT_EXP-AM

python adt_expam.py --model-dir adt-expam --dataset cifar10  # or cifar100 / svhn

Training ADT_IMP-AM

python adt_impam.py --model-dir adt-impam --dataset cifar10  # or cifar100 / svhn

Checkpoints are saved in the folder given by --model-dir.
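To train every method on every dataset, the three commands above can be generated in a loop. This is a hypothetical helper, not part of the repository: the script names and flags come from the commands above, while appending the dataset name to --model-dir (e.g. adt-exp-cifar100) is an assumed convention so that runs on different datasets do not share a checkpoint folder. It only prints the commands unless dry_run=False.

```python
import shlex
import subprocess

# Script names and flags are taken from the training commands above; the
# "<model-dir>-<dataset>" folder naming is a hypothetical convention.
SCRIPTS = {
    "adt-exp": "adt_exp.py",
    "adt-expam": "adt_expam.py",
    "adt-impam": "adt_impam.py",
}
DATASETS = ("cifar10", "cifar100", "svhn")


def training_commands(scripts=SCRIPTS, datasets=DATASETS):
    """Return one argv list per (method, dataset) pair."""
    commands = []
    for model_dir, script in scripts.items():
        for dataset in datasets:
            commands.append([
                "python", script,
                "--model-dir", "{}-{}".format(model_dir, dataset),
                "--dataset", dataset,
            ])
    return commands


def launch(commands, dry_run=True):
    for cmd in commands:
        # shlex.quote keeps the printed line copy-pasteable into a shell
        print(" ".join(shlex.quote(arg) for arg in cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # runs the jobs one after another


if __name__ == "__main__":
    launch(training_commands())  # pass dry_run=False to actually start training
```

The helper avoids f-strings beyond what Python 3.6 supports and uses " ".join with shlex.quote rather than shlex.join (3.8+), to stay compatible with the Python version listed in the prerequisites.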

Evaluation

Evaluation under White-box Attacks

  • For the FGSM attack, run
python evaluate_attacks.py --model-path ${MODEL_PATH} --attack-method FGSM --dataset cifar10  # or cifar100 / svhn
  • For the PGD attack, run
python evaluate_attacks.py --model-path ${MODEL_PATH} --attack-method PGD --num-steps 20 --dataset cifar10  # or --num-steps 100; or cifar100 / svhn
  • For the MIM attack, run
python evaluate_attacks.py --model-path ${MODEL_PATH} --attack-method MIM --num-steps 20 --dataset cifar10  # or cifar100 / svhn
  • For the C&W attack, run
python evaluate_attacks.py --model-path ${MODEL_PATH} --attack-method CW --num-steps 30 --dataset cifar10  # or cifar100 / svhn
  • For FeaAttack, run
python feature_attack.py --model-path ${MODEL_PATH} --dataset cifar10  # or cifar100 / svhn
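To show what the FGSM, PGD and MIM attacks above do under an ℓ∞ budget, the sketch below implements them for a hypothetical toy linear classifier (binary logistic regression, labels in {-1, +1}) whose input gradient has a closed form. It is not taken from evaluate_attacks.py: the step sizes are illustrative, the [0, 1] pixel clipping a real image pipeline would add is omitted, and the C&W attack and FeaAttack are not covered.

```python
import numpy as np

# Hypothetical toy model: binary logistic regression, labels y in {-1, +1}.


def loss_and_grad(w, x, y):
    m = y * (x @ w)
    loss = float(np.logaddexp(0.0, -m))
    grad_x = (-y / (1.0 + np.exp(m))) * w  # d loss / d x
    return loss, grad_x


def fgsm(w, x, y, eps):
    # A single step of size eps along the sign of the input gradient.
    _, g = loss_and_grad(w, x, y)
    return x + eps * np.sign(g)


def pgd(w, x, y, eps, alpha, num_steps, rng=None):
    # Repeated signed-gradient steps, each followed by projection onto the l-inf ball.
    x_adv = x.copy()
    if rng is not None:  # optional random start inside the ball
        x_adv = x_adv + rng.uniform(-eps, eps, size=x.shape)
    for _ in range(num_steps):
        _, g = loss_and_grad(w, x_adv, y)
        x_adv = np.clip(x_adv + alpha * np.sign(g), x - eps, x + eps)
    return x_adv


def mim(w, x, y, eps, alpha, num_steps, decay=1.0):
    # Like PGD, but steps along an accumulated, L1-normalized gradient (momentum).
    x_adv, momentum = x.copy(), np.zeros_like(x)
    for _ in range(num_steps):
        _, g = loss_and_grad(w, x_adv, y)
        momentum = decay * momentum + g / (np.abs(g).sum() + 1e-12)
        x_adv = np.clip(x_adv + alpha * np.sign(momentum), x - eps, x + eps)
    return x_adv


if __name__ == "__main__":
    w = np.array([1.0, -2.0, 0.5])
    x, y, eps = np.array([0.3, -0.4, 0.2]), 1.0, 0.1
    print("clean loss:", loss_and_grad(w, x, y)[0])
    print("FGSM:", loss_and_grad(w, fgsm(w, x, y, eps), y)[0])
    print("PGD-20:", loss_and_grad(w, pgd(w, x, y, eps, alpha=0.02, num_steps=20), y)[0])
    print("MIM-20:", loss_and_grad(w, mim(w, x, y, eps, alpha=0.02, num_steps=20), y)[0])
```

For a linear model all three attacks converge to the same corner of the ε-ball; on a real network the iterative attacks (with more steps, e.g. the 20 or 100 used above) usually find stronger perturbations than the single FGSM step.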

Evaluation under Transfer-based Black-box Attacks

First set the --white-box-attack argument in evaluate_attacks.py to False, then run:

python evaluate_attacks.py --source-model-path ${SOURCE_MODEL_PATH} --target-model-path ${TARGET_MODEL_PATH} --attack-method PGD  # or MIM
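In the transfer setting, adversarial examples are crafted using only the source model's gradients and are then fed to the target model. The NumPy sketch below reproduces that protocol with two hypothetical linear classifiers; making their weights correlated is an assumption chosen so that the attack has something to transfer, and none of the names come from evaluate_attacks.py.

```python
import numpy as np

# Two hypothetical linear classifiers (labels in {-1, +1}). The attacker only
# uses the source model's gradients; the target model is used for evaluation only.


def pgd_on_source(w_source, X, y, eps, alpha, num_steps):
    X_adv = X.copy()
    for _ in range(num_steps):
        m = y * (X_adv @ w_source)
        grad = (-y / (1.0 + np.exp(m)))[:, None] * w_source[None, :]  # d loss / d X
        X_adv = np.clip(X_adv + alpha * np.sign(grad), X - eps, X + eps)
    return X_adv


def accuracy(w, X, y):
    return float(np.mean(np.sign(X @ w) == y))


def transfer_experiment(eps=0.3, dim=20, n=500, seed=0):
    rng = np.random.default_rng(seed)
    w_target = rng.standard_normal(dim)
    w_source = w_target + 0.3 * rng.standard_normal(dim)  # assumed: similar, not identical
    X = rng.standard_normal((n, dim))
    y = np.sign(X @ w_target)  # labels chosen so the target is correct on clean data
    X_adv = pgd_on_source(w_source, X, y, eps, alpha=eps / 4, num_steps=20)
    return accuracy(w_target, X, y), accuracy(w_target, X_adv, y)


if __name__ == "__main__":
    clean_acc, adv_acc = transfer_experiment()
    print("target accuracy, clean:", clean_acc, "| on transferred examples:", adv_acc)
```

How much accuracy drops depends on how similar the two models are; with less correlated weights (a larger noise scale on w_source) the transferred examples hurt the target less.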

Evaluation under SPSA

python spsa.py --model-path ${MODEL_PATH} --samples_per_draw 256  # or 512 / 1024 / 2048
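SPSA needs only the model's outputs: it estimates the gradient from symmetric finite differences along random directions. The sketch below assumes that samples_per_draw is the number of random directions averaged per gradient estimate; how spsa.py uses the flag internally is not shown here, and the projected update step of the full attack is omitted. The function name spsa_gradient and the toy quadratic loss are hypothetical.

```python
import numpy as np


def spsa_gradient(f, x, delta=0.01, samples_per_draw=256, rng=None):
    """Estimate grad f(x) from function values only (no backpropagation).

    Each draw uses a random Rademacher (+1/-1) direction v and the symmetric
    difference (f(x + delta*v) - f(x - delta*v)) / (2*delta).
    """
    rng = np.random.default_rng() if rng is None else rng
    v = rng.choice([-1.0, 1.0], size=(samples_per_draw,) + x.shape)
    f_plus = np.array([f(x + delta * vi) for vi in v])
    f_minus = np.array([f(x - delta * vi) for vi in v])
    coeff = (f_plus - f_minus) / (2.0 * delta)
    # For Rademacher v, 1 / v_i == v_i, so multiplying by v matches the usual SPSA form.
    return (coeff.reshape((-1,) + (1,) * x.ndim) * v).mean(axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a, x0 = rng.standard_normal(10), rng.standard_normal(10)

    def toy_loss(x):
        return float(a @ x + 0.5 * x @ x)  # known gradient: a + x

    for n in (256, 512, 1024, 2048):
        g = spsa_gradient(toy_loss, x0, samples_per_draw=n, rng=np.random.default_rng(1))
        print(n, "estimation error:", np.linalg.norm(g - (a + x0)))
```

Larger samples_per_draw values average more directions, which lowers the variance of the estimate at the cost of two extra model queries per direction.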

Pretrained Models

We provide pre-trained models on CIFAR-10; their performance is reported in Table 1 of the paper. They can be downloaded at

Contact

Yinpeng Dong: [email protected]

Zhijie Deng: [email protected]
