
britney-code / tf_to_pytorch_model


This project forked from ylhz/tf_to_pytorch_model


Converts TensorFlow models to PyTorch models via [MMdnn](https://github.com/microsoft/MMdnn) for use in adversarial attacks.



tf_to_torch_model

In this repo, we convert several TensorFlow models commonly used in adversarial attacks to PyTorch and provide the resulting models. Because the models are converted from TensorFlow, their inputs need the same normalization as the originals, i.e., pixel values scaled to [-1, 1]. The Normalize layer in the snippet below already takes care of this, so the models can be used directly:

model = nn.Sequential(
    # Images for inception classifier are normalized to be in [-1, 1] interval.
    Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]), 
    net.KitModel(model_path).eval().cuda())

logit = model(input)
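The Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]) layer computes (x - mean) / std for each channel, which maps pixel values in [0, 1] onto [-1, 1]. The dependency-free sketch below only illustrates that arithmetic; it assumes the input image is already scaled to [0, 1], and the helpers `normalize_pixel` and `normalize_image` are hypothetical names, not part of this repo.

```python
def normalize_pixel(x, mean=0.5, std=0.5):
    """Apply (x - mean) / std, as a per-channel Normalize layer does."""
    return (x - mean) / std


def normalize_image(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Normalize a CHW image stored as nested lists: image[c][h][w]."""
    return [
        [[normalize_pixel(v, m, s) for v in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]


# The end points and midpoint of [0, 1] land on -1, 0 and 1.
print([normalize_pixel(v) for v in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]
```

Feeding [-1, 1] images directly would normalize them a second time, so under this assumption the wrapped model should always receive images in [0, 1].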

We also provide PyTorch code for running attacks, e.g., I-FGSM, against the converted models. Run the following command:

python torch_attack.py
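I-FGSM repeats the FGSM step x ← x + α·sign(∇x J) for a fixed number of iterations and, after every step, clips the result back into the L∞ ball of radius ε around the clean image and into the valid pixel range. The dependency-free sketch below illustrates only that update rule: it uses a toy linear loss whose gradient is known in closed form instead of PyTorch autograd, and the names `i_fgsm` and `grad_fn` and the defaults eps=16/255, num_iter=10 are illustrative assumptions, not the settings used in torch_attack.py.

```python
def sign(v):
    """Return -1, 0 or 1 according to the sign of v."""
    return (v > 0) - (v < 0)


def i_fgsm(x, grad_fn, eps=16 / 255, num_iter=10):
    """Untargeted I-FGSM on a flat list of pixel values in [0, 1].

    grad_fn(x) returns dJ/dx for the loss J that the attack tries to increase.
    The defaults are illustrative assumptions, not the repo's settings.
    """
    alpha = eps / num_iter  # step size chosen so num_iter steps can just reach eps
    x_adv = list(x)
    for _ in range(num_iter):
        g = grad_fn(x_adv)
        # Signed-gradient ascent step on every pixel.
        x_adv = [xa + alpha * sign(gi) for xa, gi in zip(x_adv, g)]
        # Project back into the L_inf ball of radius eps around the clean input...
        x_adv = [min(max(xa, xo - eps), xo + eps) for xa, xo in zip(x_adv, x)]
        # ...and into the valid pixel range [0, 1].
        x_adv = [min(max(xa, 0.0), 1.0) for xa in x_adv]
    return x_adv


# Toy linear loss J(x) = sum(w_i * x_i), whose gradient is simply w.
w = [1.0, -2.0, 0.0, 3.0]
x = [0.5, 0.5, 0.5, 0.99]
x_adv = i_fgsm(x, lambda _: w)
```

In a real attack on the converted models, the gradient would instead come from backpropagating a classification loss (e.g., cross-entropy) through the network. The per-step projection is what keeps the accumulated perturbation within ε.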

File Description

dataset: Test images.

nets: Original TensorFlow models.

nets_weight: Place the original TensorFlow weight files in this directory.

torch_nets: Converted PyTorch models.

torch_nets_weight: Place the converted PyTorch weight files in this directory (available under Releases).

tf_attack.py: Sample attack implemented in TensorFlow.

torch_attack.py: Sample attack implemented in PyTorch.

Model Accuracy

The following table lists the source checkpoint of each converted model and its accuracy on the 1000 provided test images (selected from ImageNet), for both the PyTorch and the original TensorFlow version.

| Converted model | Model source | PyTorch accuracy (%) | TF accuracy (%) | Input size |
|---|---|---|---|---|
| tf2torch_inception_v3 | inception_v3_2016_08_28 | 96.20 | 96.20 | 299×299 |
| tf2torch_inception_v4 | inception_v4_2016_09_09 | 97.40 | 97.40 | 299×299 |
| tf2torch_resnet_v2_50 | resnet_v2_50_2017_04_14 | 94.90 | 94.90 | 299×299 |
| tf2torch_resnet_v2_101 | resnet_v2_101_2017_04_14 | 96.30 | 96.30 | 299×299 |
| tf2torch_resnet_v2_152 | resnet_v2_152_2017_04_14 | 95.80 | 95.80 | 299×299 |
| tf2torch_inc_res_v2 | inception_resnet_v2_2016_08_30 | 99.80 | 99.80 | 299×299 |
| tf2torch_adv_inception_v3 | adv_inception_v3_2017_08_18 | 94.90 | 94.90 | 299×299 |
| tf2torch_ens3_adv_inc_v3 | ens3_adv_inception_v3_2017_08_18 | 93.70 | 93.70 | 299×299 |
| tf2torch_ens4_adv_inc_v3 | ens4_adv_inception_v3_2017_08_18 | 91.60 | 91.60 | 299×299 |
| tf2torch_ens_adv_inc_res_v2 | ens_adv_inception_resnet_v2_2017_08_18 | 97.60 | 97.60 | 299×299 |
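Assuming the accuracy columns are top-1 accuracy (the README does not say so explicitly), each value is the percentage of test images whose highest-scoring class matches the ground-truth label. The sketch below shows that computation on plain lists. The `label_offset` parameter reflects a hedged assumption: TF-slim ImageNet checkpoints commonly use 1001 outputs with index 0 reserved for a background class, so if the converted models keep that layout, 0-based labels need an offset of 1. The function name is hypothetical.

```python
def top1_accuracy(logits, labels, label_offset=0):
    """Percentage of samples whose argmax logit equals label + label_offset."""
    correct = 0
    for row, label in zip(logits, labels):
        pred = max(range(len(row)), key=row.__getitem__)  # index of largest logit
        correct += pred == label + label_offset
    return 100.0 * correct / len(labels)


logits = [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
labels = [1, 0, 1, 0]
print(top1_accuracy(logits, labels))  # 75.0
```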

Implementation of sample attacks

The table below compares our results with the paper's results, written as ours/paper ("*" marks the white-box model). The paper results are taken from Patch-wise Attack for Fooling Deep Neural Network; the converted models yield results close to those reported there. The specific parameter settings are given in the paper.

| Attack method | inc_v3* | inc_v4 | resnet_v2_152 | inc_res_v2 | ens3_adv_inc_v3 | ens4_adv_inc_v3 | ens_adv_inc_res_v2 |
|---|---|---|---|---|---|---|---|
| FGSM | 81.0/80.9 | 37.4/38.0 | 33.0/33.1 | 33.9/33.1 | 16.9/16.8 | 15.7/15.8 | 8.2/8.3 |
| I-FGSM | 100.0/100.0 | 30.1/29.6 | 19.4/19.4 | 21.4/20.3 | 12.0/11.7 | 12.4/12.1 | 5.5/5.5 |
| MI-FGSM | 100.0/100.0 | 55.1/54.1 | 42.8/43.5 | 51.7/50.9 | 22.2/21.9 | 21.6/21.1 | 11.2/10.5 |
| DI-FGSM | 99.7/99.8 | 55.3/54.2 | 33.4/32.1 | 43.5/43.6 | 15.9/15.0 | 16.4/16.2 | 8.6/7.1 |
| TI-FGSM | / | / | / | / | 31.2/30.8 | 31.1/30.6 | 22.9/22.7 |
| PI-FGSM | 100.0/100.0 | 57.5/58.6 | 47.6/45.0 | 52.2/51.3 | 38.4/39.3 | 39.0/39.5 | 28.0/28.8 |
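The table does not name its metric, but the white-box inc_v3* column reaching 100% is consistent with untargeted attack success rates: adversarial examples crafted on inc_v3 are fed to each model, and an attack counts as successful when that model misclassifies the image. Assuming that definition (some works instead count only images the model classified correctly before the attack), the rate can be sketched as below; the function name and data are illustrative only.

```python
def attack_success_rate(adv_preds, labels):
    """Untargeted success rate (%): share of adversarial images that are misclassified."""
    fooled = sum(p != y for p, y in zip(adv_preds, labels))
    return 100.0 * fooled / len(labels)


# Hypothetical predictions of one model on 5 adversarial images.
print(attack_success_rate([3, 7, 7, 1, 0], [3, 2, 7, 4, 9]))  # 60.0
```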

Notes

  1. To use the auxiliary logits, create the model with aux_logits=True:
model = nn.Sequential(
    # Images for inception classifier are normalized to be in [-1, 1] interval.
    Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]), 
    net.KitModel(model_path, aux_logits=True).eval().cuda())
    
logits, aux_logits = model(input)
  2. The following models support aux_logits:

    • tf2torch_inception_v3
    • tf2torch_inception_v4
    • tf2torch_inc_res_v2
    • tf2torch_adv_inception_v3
    • tf2torch_ens3_adv_inc_v3
    • tf2torch_ens4_adv_inc_v3
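With aux_logits=True the forward pass returns a (logits, aux_logits) pair instead of a single output, so attack code written for plain logits would break. One way to handle both cases is a small wrapper like the hypothetical `main_logits` helper below. It assumes the pair is returned as a tuple, as the unpacking in the note above suggests; since it only inspects the container, it works the same for PyTorch tensors.

```python
def main_logits(output):
    """Return the main logits whether or not the model was built with aux_logits=True."""
    if isinstance(output, tuple):
        logits, _aux_logits = output  # (logits, aux_logits) when aux_logits=True
        return logits
    return output  # plain logits when aux_logits=False
```

For example, `loss = criterion(main_logits(model(x)), y)` stays the same regardless of how the model was created.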

Contributors

ylhz, qilong-zhang
