
PyTorch Implementation of WS-DAN (See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification)

License: MIT License

Shell 0.91% Python 99.09%
pytorch-cnn fine-grained-classification image-classification image-recognition fine-grained-visual-categorization pytorch

ws_dan_pytorch's Introduction

PyTorch Implementation Of WS-DAN

Introduction

This is a PyTorch implementation of the paper "See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification". The paper also has an official TensorFlow implementation, WS_DAN. The core of this code follows the official version, and its performance nearly matches the results reported in the paper.
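In short, WS-DAN learns a set of attention maps over the backbone features and uses them both for pooling and for augmenting the training data: one map guides a zoom-in crop, another guides dropping (masking out) the most attended region. Below is a minimal sketch of that crop/drop augmentation for orientation only; the function name, thresholds, and shapes are illustrative, not this repo's exact API.

import torch
import torch.nn.functional as F

def attention_crop_drop(attention_maps, images, crop_thresh=0.5, drop_thresh=0.5):
    """Illustrative sketch of WS-DAN's attention-guided augmentation.

    attention_maps: (B, M, Ha, Wa) attention maps from the backbone
    images:         (B, 3, H, W)   the input batch
    Returns a cropped batch and a dropped batch, both (B, 3, H, W).
    """
    B, _, H, W = images.shape
    crops, drops = [], []
    for i in range(B):
        # Pick one attention map at random per image (as in the paper),
        # upsample it to input resolution, and normalise it to [0, 1].
        k = torch.randint(attention_maps.size(1), (1,)).item()
        att = attention_maps[i, k:k + 1].unsqueeze(0)          # (1, 1, Ha, Wa)
        att = F.interpolate(att, size=(H, W), mode='bilinear',
                            align_corners=False)[0, 0]
        att = (att - att.min()) / (att.max() - att.min() + 1e-8)

        # Attention cropping: zoom into the bounding box of high-attention pixels.
        mask = att >= crop_thresh
        if not mask.any():                  # degenerate map: keep the full image
            mask = torch.ones_like(mask)
        ys, xs = torch.nonzero(mask, as_tuple=True)
        y0, y1 = ys.min().item(), ys.max().item() + 1
        x0, x1 = xs.min().item(), xs.max().item() + 1
        crop = images[i:i + 1, :, y0:y1, x0:x1]
        crops.append(F.interpolate(crop, size=(H, W), mode='bilinear',
                                   align_corners=False))

        # Attention dropping: zero out the high-attention region instead.
        drops.append(images[i:i + 1] * (att < drop_thresh).float())
    return torch.cat(crops), torch.cat(drops)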

Environment

  • Ubuntu 16.04, 2 × GTX 1080 (8 GB), CUDA 8.0
  • Anaconda with Python 3.6.5, PyTorch 0.4.1, torchvision 0.2.1, etc.
  • Some third-party dependencies may be installed with pip or conda when needed.

Result

Dataset         ACC (this repo)   ACC refine (this repo)   ACC (paper)
CUB-200-2011    88.20             89.30                    89.4
FGVC-Aircraft   93.15             93.22                    93.0
Stanford Cars   94.13             94.43                    94.5
Stanford Dogs   86.03             86.46                    92.2

You can download pretrained models from WS_DAN_Onedrive.

Install

  1. Clone the repo:
git clone https://github.com/wvinzh/WS_DAN_PyTorch
  2. Prepare the dataset
  • Download the following datasets:
Dataset         Object     Categories   Training   Testing
CUB-200-2011    Bird       200          5994       5794
Stanford-Cars   Car        196          8144       8041
FGVC-Aircraft   Aircraft   100          6667       3333
Stanford-Dogs   Dog        120          12000      8580
  • Extract the data so the directory tree looks like the following:
Fine-grained
├── CUB_200_2011
│   ├── attributes
│   ├── bounding_boxes.txt
│   ├── classes.txt
│   ├── image_class_labels.txt
│   ├── images
│   ├── images.txt
│   ├── parts
│   ├── README
├── Car
│   ├── cars_test
│   ├── cars_train
│   ├── devkit
│   └── tfrecords
├── fgvc-aircraft-2013b
│   ├── data
│   ├── evaluation.m
│   ├── example_evaluation.m
│   ├── README.html
│   ├── README.md
│   ├── vl_argparse.m
│   ├── vl_pr.m
│   ├── vl_roc.m
│   └── vl_tpfp.m
├── dogs
│   ├── file_list.mat
│   ├── Images
│   ├── test_list.mat
│   └── train_list.mat
  • Prepare the ./data folder: generate the file-list txt files with ./utils/convert_data.py and create soft links, so it looks like the tree below (a sketch of how these lists can be read follows the tree).
python utils/convert_data.py  --dataset_name bird --root_path .../Fine-grained/CUB_200_2011
├── data
│   ├── Aircraft -> /your_root_path/Fine-grained/fgvc-aircraft-2013b/data
│   ├── aircraft_test.txt
│   ├── aircraft_train.txt
│   ├── Bird -> /your_root_path/Fine-grained/CUB_200_2011
│   ├── bird_test.txt
│   ├── bird_train.txt
│   ├── Car -> /your_root_path/Fine-grained/Car
│   ├── car_test.txt
│   ├── car_train.txt
│   ├── Dog -> /your_root_path/Fine-grained/dogs
│   ├── dog_test.txt
│   └── dog_train.txt
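For orientation, a minimal reader for these lists might look like the sketch below. It assumes each line holds an image path (relative to the dataset symlink) followed by an integer label; the class name and exact line format are assumptions, so check utils/convert_data.py for the format it actually writes.

import os
from PIL import Image
from torch.utils.data import Dataset

class FileListDataset(Dataset):
    """Hypothetical reader for the data/*_train.txt / *_test.txt lists.

    Assumes each line is "<relative/image/path> <integer label>".
    """
    def __init__(self, list_file, image_root, transform=None):
        self.image_root = image_root
        self.transform = transform
        with open(list_file) as f:
            self.samples = [line.split() for line in f if line.strip()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(os.path.join(self.image_root, path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, int(label)

# e.g. FileListDataset('data/bird_train.txt', 'data/Bird', transform=...)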

Usage

  • Train
python train_bap.py train \
    --model-name inception \
    --batch-size 12 \
    --dataset car \
    --image-size 512 \
    --input-size 448 \
    --checkpoint-path checkpoint/car \
    --optim sgd \
    --scheduler step \
    --lr 0.001 \
    --momentum 0.9 \
    --weight-decay 1e-5 \
    --workers 4 \
    --parts 32 \
    --epochs 80 \
    --use-gpu \
    --multi-gpu \
    --gpu-ids 0,1

The simplest way is to run sh train_bap.sh, or run it in the background with logs: nohup sh train_bap.sh 1>train.log 2>error.log &

  • Test
python train_bap.py test \
    --model-name inception \
    --batch-size 12 \
    --dataset car \
    --image-size 512 \
    --input-size 448 \
    --checkpoint-path checkpoint/car/model_best.pth.tar \
    --workers 4 \
    --parts 32 \
    --use-gpu \
    --multi-gpu \
    --gpu-ids 0,1


ws_dan_pytorch's Issues

Inquiry

Hello, using your code I reached similar accuracy on the CUB dataset.
I am now applying it to my own dataset, but training and validation accuracy stay very low, e.g. Prec@1 0.199 Prec@5 1.558.
Could you give me some advice?
Thanks a lot!

Question about output1 and output2

Training on my own data, I find that output1 and output2 take the same values, and even the original input image and the attention-cropped image are identical. Is this normal?

About loading bird.pth.tar

When I test the model bird.pth.tar, I get: Missing key(s) in state_dict: "Conv2d_1a_3x3.conv.weight", "Conv2d_1a_3x3.bn.weight", .................
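A common cause of such missing-key errors (worth checking here, though not confirmed for this checkpoint) is a state dict saved from an nn.DataParallel model, whose keys all carry a "module." prefix that a single-GPU model does not expect. A sketch of the usual workaround, assuming `model` is the already-constructed network:

import torch

# The released checkpoints were trained with --multi-gpu, so the weights may
# be stored under nn.DataParallel's "module." prefix. Stripping it lets a
# single-GPU model load them (a general fix, not specific to this repo).
ckpt = torch.load('bird.pth.tar', map_location='cpu')
state = ckpt.get('state_dict', ckpt)
state = {k.replace('module.', '', 1): v for k, v in state.items()}
model.load_state_dict(state)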

Using the bird pretrained model provided by this repo on a 2 × 2080 Ti, torch 0.4.1, CUDA 9.2 platform, I get a refine acc of 89.023; training it myself on this platform gives 88.989, both somewhat short of the 89.3 you report.

Code changes:

  • I only swapped attention_crop_drop for the attention_crop_drop2 you provide, because the mask is on the CPU.
  • correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) is from utils.py.

I just want to confirm that the model provided on OneDrive reaches 89.30 on your platform. If platform differences can be ruled out, I can re-implement another WS-DAN-based paper on top of your code: https://github.com/aioz-ai/SAC. That paper is also built on top of the official WS-DAN TensorFlow code.

attention map

Hello, I'd like to ask a few questions:
(1) In resnet50, the implementation uses 32 attention maps, and each attention map is multiplied with the convolutional features of the last bottleneck. What I want to ask is how the attention maps are generated.
(2) Could the whole pipeline be improved with a method like improved bilinear pooling?
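For orientation: in WS-DAN the attention maps are typically produced by a 1 × 1 convolution over the backbone's feature maps, and bilinear attention pooling (BAP) averages the element-wise product of each attention map with the features. A minimal sketch under those assumptions (module and argument names are illustrative, not this repo's exact code):

import torch
import torch.nn as nn

class BAPSketch(nn.Module):
    """Sketch of attention-map generation + bilinear attention pooling."""

    def __init__(self, in_channels=2048, num_parts=32):
        super().__init__()
        # Attention maps: a 1x1 conv of the backbone features (an assumption
        # about the usual WS-DAN design, not this repo's exact layer).
        self.attention = nn.Conv2d(in_channels, num_parts, kernel_size=1)

    def forward(self, features):                     # (B, C, H, W)
        atts = torch.relu(self.attention(features))  # (B, M, H, W)
        # Part feature k = spatial average of (attention map k * features);
        # the einsum yields (B, M, C), one C-dim descriptor per attention map.
        parts = torch.einsum('bmhw,bchw->bmc', atts, features)
        parts = parts / (features.size(2) * features.size(3))
        return atts, parts.flatten(1)                # (B, M*C) feature matrix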

some confusion

pooling_features = raw_features * 100: why multiply raw_features by 100?
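A plausible reading, not confirmed by the author: the pooled bilinear features are averaged over all H × W spatial positions, so their raw magnitudes are tiny, and multiplying by a constant such as 100 rescales them into a range where the subsequent normalization and the classifier behave well numerically. Since every dimension is scaled by the same constant, it changes the scale of logits and gradients but not the relative structure of the features.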

Center Loss

You define a center loss here, but it is never added during training:

# define loss
criterion = torch.nn.CrossEntropyLoss()
if use_gpu:
    criterion = criterion.cuda()

This may be why you fall slightly short of the official code (lines 569-572).
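For reference, the paper's objective adds a center loss on the pooled part features to the cross-entropy term. A minimal sketch of wiring an already-defined center loss into the training step; `model`, `center_loss`, the output names, and `alpha` are placeholders for whatever train_bap.py actually uses:

import torch

criterion = torch.nn.CrossEntropyLoss()

def training_step(model, center_loss, images, labels, alpha=1.0):
    # Forward pass: attention maps, pooled part features, and class logits.
    attention_maps, features, logits = model(images)
    # Cross-entropy on the logits plus a weighted center loss on the features.
    loss = criterion(logits, labels) + alpha * center_loss(features, labels)
    return loss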

cuda out of memory error

If I forward twice, I run into a CUDA out-of-memory error. The batch size is 512:
_, _, output2 = model(img_drop)
_, _, output3 = model(img_crop)

Why is that?
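For what it's worth: during training the raw, cropped, and dropped batches each run a full forward pass, and all three sets of activations must stay alive for the backward pass, so memory use roughly triples. A batch of 512 at 448 × 448 through Inception far exceeds an 8 GB card even for a single pass; the README trains with --batch-size 12. Reducing the batch size, optionally with gradient accumulation to keep the effective batch large, is the usual remedy.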
