
tf_to_torch_model

In this repo, we convert some common TensorFlow models used in adversarial attack research to PyTorch and provide the converted models. Because these models are converted from their TensorFlow versions, their inputs need the same normalization, i.e., pixel values scaled to [-1, 1]. We have already included this normalization in the model wrapper below, so you can use the models directly.

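import torch.nn as nn

# `net` is one of the converted model modules (torch_nets), `model_path` is its converted
# weight file (torch_nets_weight), and `Normalize` applies (x - mean) / std per channel.
model = nn.Sequential(
    # Images for inception classifier are normalized to be in [-1, 1] interval.
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    net.KitModel(model_path).eval().cuda())

# `input`: a batch of RGB images in [0, 1] on the GPU, shape (N, 3, 299, 299).
logit = model(input)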
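For reference, a mean/std normalization layer of the kind used above can be sketched as follows. The class is a hypothetical stand-in written for this illustration (the repo's own Normalize may be implemented differently), and it assumes NCHW batches with pixel values in [0, 1]. With mean = std = 0.5 it computes (x - 0.5) / 0.5 = 2x - 1, which maps [0, 1] onto [-1, 1].

```python
import torch
import torch.nn as nn


class Normalize(nn.Module):
    """Per-channel (x - mean) / std for NCHW image batches.

    Hypothetical stand-in for illustration. With mean = std = 0.5,
    inputs in [0, 1] are mapped to [-1, 1].
    """

    def __init__(self, mean, std):
        super().__init__()
        # Buffers shaped (1, C, 1, 1) broadcast over batch, height and width,
        # and move together with the module under .cuda() / .to(device).
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std


# Usage: every channel holds the pixel values 0.0, 0.5 and 1.0.
norm = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
x = torch.tensor([0.0, 0.5, 1.0]).view(1, 1, 1, 3).expand(1, 3, 1, 3)
y = norm(x)  # each channel becomes -1.0, 0.0, 1.0
```

Storing mean and std as buffers rather than plain attributes keeps them on the same device as the rest of the model when the whole nn.Sequential is moved to the GPU.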

We also provide PyTorch code for running attacks, e.g., I-FGSM, against the converted models. To run it:

python torch_attack.py
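The core I-FGSM loop can be sketched as follows. This is a minimal illustration, not the code in torch_attack.py: the function name i_fgsm, the toy linear model, and the defaults eps = 16/255 and num_iter = 10 (step size eps / num_iter) are assumptions chosen for illustration. It assumes images in [0, 1] and a model that returns logits with normalization built in, like the wrapper above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def i_fgsm(model, images, labels, eps=16 / 255, num_iter=10):
    """Untargeted I-FGSM under an L-infinity budget `eps` (illustrative sketch).

    Assumes `images` are in [0, 1] and `model` returns logits.
    """
    alpha = eps / num_iter  # step size: num_iter steps can just reach eps
    x_adv = images.clone().detach()
    for _ in range(num_iter):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), labels)
        (grad,) = torch.autograd.grad(loss, x_adv)
        with torch.no_grad():
            # Step along the sign of the gradient to increase the loss ...
            x_adv = x_adv + alpha * grad.sign()
            # ... then project back into the eps-ball around the clean images
            # and into the valid pixel range.
            x_adv = torch.min(torch.max(x_adv, images - eps), images + eps)
            x_adv = x_adv.clamp(0.0, 1.0)
        x_adv = x_adv.detach()
    return x_adv


# Usage with a toy stand-in model (hypothetical; not one of the converted networks).
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
images = torch.rand(4, 3, 8, 8)
labels = torch.randint(0, 10, (4,))
adv = i_fgsm(model, images, labels)
```

Because each step is followed by a projection, the perturbation never exceeds eps in any pixel, regardless of how many iterations are run.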

File Description

dataset: Test images.

nets: Original TensorFlow models.

nets_weight: Put the original TensorFlow network weight files in this directory.

torch_nets: Converted PyTorch models.

torch_nets_weight: Put the converted PyTorch network weight files in this directory (they are available in Releases).

tf_attack.py: Sample attack implemented in TensorFlow.

torch_attack.py: Sample attack implemented in PyTorch.

Model Accuracy

The following table lists the source of each converted model and its accuracy on the 1,000 provided test images (selected from ImageNet), for both the converted PyTorch model and the original TensorFlow model.

Converted model | Model source | PyTorch accuracy (%) | TensorFlow accuracy (%) | Input size
--- | --- | --- | --- | ---
tf2torch_inception_v3 | inception_v3_2016_08_28 | 96.20 | 96.20 | 299×299
tf2torch_inception_v4 | inception_v4_2016_09_09 | 97.40 | 97.40 | 299×299
tf2torch_resnet_v2_50 | resnet_v2_50_2017_04_14 | 94.90 | 94.90 | 299×299
tf2torch_resnet_v2_101 | resnet_v2_101_2017_04_14 | 96.30 | 96.30 | 299×299
tf2torch_resnet_v2_152 | resnet_v2_152_2017_04_14 | 95.80 | 95.80 | 299×299
tf2torch_inc_res_v2 | inception_resnet_v2_2016_08_30 | 99.80 | 99.80 | 299×299
tf2torch_adv_inception_v3 | adv_inception_v3_2017_08_18 | 94.90 | 94.90 | 299×299
tf2torch_ens3_adv_inc_v3 | ens3_adv_inception_v3_2017_08_18 | 93.70 | 93.70 | 299×299
tf2torch_ens4_adv_inc_v3 | ens4_adv_inception_v3_2017_08_18 | 91.60 | 91.60 | 299×299
tf2torch_ens_adv_inc_res_v2 | ens_adv_inception_resnet_v2_2017_08_18 | 97.60 | 97.60 | 299×299
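Top-1 accuracy as reported in the table can be computed with a helper like the one below. This is a sketch under stated assumptions: top1_accuracy is a hypothetical name, and the labels are assumed to use the same 1001-class indexing as the model output (TF-slim ImageNet models reserve index 0 for a background class). The helper also promotes an unbatched [1001] output to [1, 1001], the shape mismatch reported for batch size 1 in the issues below.

```python
import torch


def top1_accuracy(logits, labels):
    """Percentage of samples whose arg-max logit equals the label (sketch)."""
    # Promote an unbatched logit vector of shape (num_classes,) to (1, num_classes).
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)
    preds = logits.argmax(dim=1)
    return 100.0 * (preds == labels).float().mean().item()


# Toy example with 1001 classes: the model predicts classes 1, 2, 3, 4.
logits = torch.zeros(4, 1001)
logits[torch.arange(4), torch.tensor([1, 2, 3, 4])] = 1.0
labels = torch.tensor([1, 2, 3, 5])
acc = top1_accuracy(logits, labels)  # 3 of 4 correct -> 75.0
```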

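Implementation of sample attacks

The following table compares attack success rates (%) as our result / paper result, where "*" marks the white-box setting (the adversarial examples are crafted on inc_v3). The paper results are taken from "Patch-wise Attack for Fooling Deep Neural Network". The converted models reproduce them closely: most differences are within about 1 percentage point, and the largest gap is 2.6 points (PI-FGSM on resnet_v2_152). The specific parameter settings can be found in the paper.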

Attack method | inc_v3* | inc_v4 | resnet_v2_152 | inc_res_v2 | ens3_adv_inc_v3 | ens4_adv_inc_v3 | ens_adv_inc_res_v2
--- | --- | --- | --- | --- | --- | --- | ---
FGSM | 81.0/80.9 | 37.4/38.0 | 33.0/33.1 | 33.9/33.1 | 16.9/16.8 | 15.7/15.8 | 8.2/8.3
I-FGSM | 100.0/100.0 | 30.1/29.6 | 19.4/19.4 | 21.4/20.3 | 12.0/11.7 | 12.4/12.1 | 5.5/5.5
MI-FGSM | 100.0/100.0 | 55.1/54.1 | 42.8/43.5 | 51.7/50.9 | 22.2/21.9 | 21.6/21.1 | 11.2/10.5
DI-FGSM | 99.7/99.8 | 55.3/54.2 | 33.4/32.1 | 43.5/43.6 | 15.9/15.0 | 16.4/16.2 | 8.6/7.1
TI-FGSM | n/a | n/a | n/a | n/a | 31.2/30.8 | 31.1/30.6 | 22.9/22.7
PI-FGSM | 100.0/100.0 | 57.5/58.6 | 47.6/45.0 | 52.2/51.3 | 38.4/39.3 | 39.0/39.5 | 28.0/28.8
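The entries in the attack table can be read as untargeted attack success rates, i.e., the percentage of adversarial images that the target model misclassifies; this reading is an assumption, consistent with the 100.0 entries in the white-box inc_v3* column. A minimal sketch of that computation, with a hypothetical helper attack_success_rate and a toy fixed-output model standing in for a converted network:

```python
import torch


@torch.no_grad()
def attack_success_rate(model, adv_images, labels):
    """Percentage of adversarial images the model misclassifies (untargeted sketch)."""
    preds = model(adv_images).argmax(dim=1)
    return 100.0 * (preds != labels).float().mean().item()


def toy_model(x):
    # Hypothetical stand-in: ignores its input and always predicts classes 0, 1, 2, 2.
    return torch.eye(3)[[0, 1, 2, 2]]


labels = torch.tensor([0, 1, 1, 0])
asr = attack_success_rate(toy_model, torch.zeros(4, 3, 8, 8), labels)  # 2 of 4 fooled -> 50.0
```

This version counts every image in the denominator; some papers instead count only images the model classified correctly before the attack, which gives slightly different numbers.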

Notes

  1. If you want to use aux_logits, create the model with aux_logits=True:
model = nn.Sequential(
    # Images for inception classifier are normalized to be in [-1, 1] interval.
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    net.KitModel(model_path, aux_logits=True).eval().cuda())

# With aux_logits=True, the model returns a (logits, aux_logits) tuple.
logits, aux_logits = model(input)
  2. Models that provide aux_logits:

    • tf2torch_inception_v3
    • tf2torch_inception_v4
    • tf2torch_inc_res_v2
    • tf2torch_adv_inception_v3
    • tf2torch_ens3_adv_inc_v3
    • tf2torch_ens4_adv_inc_v3


Contributors

qilong-zhang, ylhz


Issues

About the README

The model names in the Model Accuracy table of the README are misleading.

For example, [tf_inception_v3] is meant to denote a model converted from TensorFlow, but I took it to be the original TensorFlow model.

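Problem when passing a batch of size 1 to the model

Hello, when the batch size is 1, the logits returned by some models have shape [1, 1001], while those from other models have shape [1001]. How can I fix this?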

About ResNet

Thank you for sharing!
When the source model is resnet50, resnet101, or resnet152, the attack success rate (ASR) drops significantly even in the white-box setting.
For example, when adversarial examples are generated with resnet50, the ASR on resnet50 itself is only 54.2%.

Input sizes accepted by the models

Hi, I have recently been using your framework to test my method. I would like to ask whether these models support other input sizes, such as 224x224, because I encountered an error in this part.

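The converted models

Hello, how can I download the converted models? When I try to download the .npy files, I get an error saying the download is not authorized. Could you share a copy with me? My email is [email protected] Thank you.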

The dataset

Could you tell me whether the images here are the same ones used in the NIPS 2017 adversarial competition? Thank you.

About DataLoader

The superclass of DataLoader should be torch.utils.data.Dataset rather than DataLoader.
