
csra's Introduction

CSRA

This is the official code of ICCV 2021 paper:
Residual Attention: A Simple But Effective Method for Multi-Label Recognition


Demo, Train and Validation code have been released! (including VIT on Wider-Attribute)

This package is developed by Mr. Ke Zhu (http://www.lamda.nju.edu.cn/zhuk/), and we have just finished implementing the ViT models. If you have any questions about the code, please feel free to contact Mr. Ke Zhu ([email protected]). The package is free for academic usage. You can run it at your own risk. For other purposes, please contact Prof. Jianxin Wu (mail to [email protected]).

Requirements

  • Python 3.7
  • pytorch 1.6
  • torchvision 0.7.0
  • pycocotools 2.0
  • tqdm 4.49.0, pillow 7.2.0

Dataset

We expect VOC2007, COCO2014 and Wider-Attribute dataset to have the following structure:

Dataset/
|-- VOCdevkit/
|---- VOC2007/
|------ JPEGImages/
|------ Annotations/
|------ ImageSets/
......
|-- COCO2014/
|---- annotations/
|---- images/
|------ train2014/
|------ val2014/
......
|-- WIDER/
|---- Annotations/
|------ wider_attribute_test.json
|------ wider_attribute_trainval.json
|---- Image/
|------ train/
|------ val/
|------ test/
...

Then run the following commands to generate the JSON annotation files that the implementation uses:

python utils/prepare/prepare_voc.py  --data_path  Dataset/VOCdevkit
python utils/prepare/prepare_coco.py --data_path  Dataset/COCO2014
python utils/prepare/prepare_wider.py --data_path Dataset/WIDER

These commands write the annotation JSON files to ./data/voc07, ./data/coco and ./data/wider.
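To give a rough idea of what the preparation step involves, the sketch below turns one VOC2007 XML annotation into a multi-hot label vector. It is not the code of utils/prepare/prepare_voc.py; the helper name `voc_xml_to_multihot` and the (filename, labels) record layout are hypothetical.

```python
import xml.etree.ElementTree as ET

# The 20 VOC2007 object classes in their conventional alphabetical order.
VOC_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]

def voc_xml_to_multihot(xml_text):
    """Hypothetical helper: map one VOC annotation to (filename, multi-hot list)."""
    root = ET.fromstring(xml_text)
    filename = root.findtext("filename")
    labels = [0] * len(VOC_CLASSES)
    # An image is positive for a class if at least one object of that class appears.
    for obj in root.iter("object"):
        name = obj.findtext("name").strip()
        labels[VOC_CLASSES.index(name)] = 1
    return filename, labels

example = """<annotation>
  <filename>000001.jpg</filename>
  <object><name>dog</name></object>
  <object><name>person</name></object>
  <object><name>person</name></object>
</annotation>"""

name, target = voc_xml_to_multihot(example)
print(name, [c for c, y in zip(VOC_CLASSES, target) if y])
```

Repeated objects of the same class (two persons above) still give a single positive label, which is what multi-label recognition needs.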

Demo

We provide prediction demos of our models. The demo images (picked from VOC2007) are already in ./utils/demo_images/, so you can simply run demo.py with our CSRA model pretrained on VOC2007:

CUDA_VISIBLE_DEVICES=0 python demo.py --model resnet101 --num_heads 1 --lam 0.1 --dataset voc07 --load_from OUR_VOC_PRETRAINED.pth --img_dir utils/demo_images

which produces output like this:

utils/demo_images/000001.jpg prediction: dog,person,
utils/demo_images/000004.jpg prediction: car,
utils/demo_images/000002.jpg prediction: train,
...
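Each output line lists the classes that pass a decision rule. A minimal sketch of producing such a line from one image's logits, assuming a sigmoid followed by a 0.5 threshold (demo.py may use a different rule); `format_prediction` is a hypothetical helper and the logits are made up:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def format_prediction(img_path, logits, class_names, threshold=0.5):
    """Hypothetical helper: keep classes whose sigmoid probability >= threshold."""
    picked = [c for c, z in zip(class_names, logits) if sigmoid(z) >= threshold]
    # Mirror the trailing-comma style of the demo output.
    return f"{img_path} prediction: " + "".join(c + "," for c in picked)

classes = ["car", "dog", "person", "train"]  # toy subset of the VOC classes
line = format_prediction("utils/demo_images/000001.jpg", [-3.0, 2.1, 0.7, -1.2], classes)
print(line)  # utils/demo_images/000001.jpg prediction: dog,person,
```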

Validation

We provide pretrained models on Google Drive for validation. ResNet101 trained on ImageNet with CutMix augmentation can be downloaded here.

Dataset          Backbone          Heads  mAP (%)  Resolution  Download
VOC2007          ResNet-101        1      94.7     448x448     download
VOC2007          ResNet-cut        1      95.2     448x448     download
VOC2007 (extra)  ResNet-cut        1      96.8     448x448     download
COCO             ResNet-101        4      83.3     448x448     download
COCO             ResNet-cut        6      85.6     448x448     download
COCO             VIT_L16_224       8      86.5     448x448     download
COCO             VIT_L16_224*      8      86.9     448x448     download
Wider            VIT_B16_224       1      89.0     224x224     download
Wider            VIT_L16_224       1      90.2     224x224     download

ResNet-cut denotes the ResNet-101 pretrained on ImageNet with CutMix augmentation.
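The mAP column is the mean, over classes, of per-class average precision. A plain-Python sketch of that metric using the common non-interpolated definition (the repository's evaluation code may differ, for example in tie handling or VOC "difficult" objects):

```python
def average_precision(scores, labels):
    """Non-interpolated AP for one class: mean precision at each positive's rank."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits, precisions = 0, []
    for rank, i in enumerate(order, start=1):
        if labels[i]:
            hits += 1
            precisions.append(hits / rank)
    # A class with no positives gets AP 0 in this sketch.
    return sum(precisions) / len(precisions) if precisions else 0.0

def mean_average_precision(score_matrix, label_matrix):
    """Rows are images, columns are classes; returns mAP in percent."""
    num_cls = len(score_matrix[0])
    aps = [
        average_precision([row[c] for row in score_matrix],
                          [row[c] for row in label_matrix])
        for c in range(num_cls)
    ]
    return 100.0 * sum(aps) / num_cls

scores = [[0.9, 0.1], [0.8, 0.7], [0.3, 0.6]]
labels = [[1, 0], [0, 1], [1, 1]]
m = mean_average_precision(scores, labels)
print(round(m, 2))  # 91.67
```

Here class 0 has AP (1 + 2/3) / 2 and class 1 has AP 1, giving (5/6 + 1) / 2 = 91.67%.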

For voc2007, run the following validation example:

CUDA_VISIBLE_DEVICES=0 python val.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20  --load_from MODEL.pth

For coco2014, run the following validation example:

CUDA_VISIBLE_DEVICES=0 python val.py --num_heads 4 --lam 0.5 --dataset coco --num_cls 80  --load_from MODEL.pth

For Wider-Attribute with ViT models, run the following:

CUDA_VISIBLE_DEVICES=0 python val.py --model vit_B16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14  --load_from ViT_B16_MODEL.pth
CUDA_VISIBLE_DEVICES=0 python val.py --model vit_L16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14  --load_from ViT_L16_MODEL.pth

To provide pretrained ViT models on the Wider-Attribute dataset, we recently retrained them; their performance differs slightly (~0.1% mAP) from the results reported in our paper. The ViT models follow the original ViT architecture (An image is worth 16x16 words: Transformers for image recognition at scale, link), and their implementation is derived from http://github.com/rwightman/pytorch-image-models/.

Training

VOC2007

You can run either of these two lines below

CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20
CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20 --cutmix CutMix_ResNet101.pth

Note that the first command uses the official ResNet-101 backbone, while the second command uses the ResNet-101 pretrained on ImageNet with CutMix augmentation (link), which is expected to perform better.

MS-COCO

run the ResNet-101 with 4 heads

CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 4 --lam 0.5 --dataset coco --num_cls 80

run the ResNet-101 (pretrained with CutMix) with 6 heads

CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 6 --lam 0.4 --dataset coco --num_cls 80 --cutmix CutMix_ResNet101.pth

Feel free to adjust hyper-parameters such as the number of attention heads (--num_heads) or Lambda (--lam). Still, the default values in the commands above are expected to give the best results.

Wider-Attribute

run the VIT_B16_224 with 1 head

CUDA_VISIBLE_DEVICES=0 python main.py --model vit_B16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14

run the VIT_L16_224 with 1 head

CUDA_VISIBLE_DEVICES=0,1 python main.py --model vit_L16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14

Note that the VIT_L16_224 model consumes more GPU memory, so we use 2 GPUs to train it.

Notice

To avoid confusion, please note that the 4 lines of code in Figure 1 (in the paper) are only used in the test stage (without training), which is our motivation. When our model is trained and tested end to end, multi-head attention (H=1, H=2, H=4, etc.) is used with different T values. Also, when H=1 and T=infty, the implementation code of multi-head attention is exactly the same as Figure 1.
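The note above can be made concrete with a NumPy sketch of one CSRA head and of the multi-head variant. It follows the normalized-classifier score computation of the original forward() quoted later in the issues; the softmax-with-temperature pooling for finite T, the use of `T=None` for T = infinity, and summing the head outputs are assumptions of this sketch, not a copy of the repository code. With one head and T = infinity it reduces to the mean + lambda * max rule of Figure 1.

```python
import numpy as np

def csra_head(x, weight, lam, T=None):
    """One CSRA head (sketch).

    x:      (d, HW) flattened feature map of one image
    weight: (C, d) weights of a bias-free 1x1 classifier
    lam:    residual weight lambda
    T:      softmax temperature; None stands for T = infinity (max pooling)
    """
    # Class-specific score at every location, each classifier row L2-normalized.
    score = (weight / np.linalg.norm(weight, axis=1, keepdims=True)) @ x  # (C, HW)
    base_logit = score.mean(axis=1)  # class-agnostic average pooling
    if T is None:
        att_logit = score.max(axis=1)  # limit T -> infinity
    else:
        # Numerically stable softmax over locations, per class.
        w = np.exp(T * (score - score.max(axis=1, keepdims=True)))
        w /= w.sum(axis=1, keepdims=True)
        att_logit = (w * score).sum(axis=1)
    return base_logit + lam * att_logit  # (C,)

def csra_multi_head(x, weights, lam, temps):
    """Multi-head sketch: one classifier per temperature, head logits summed."""
    return sum(csra_head(x, w, lam, T) for w, T in zip(weights, temps))

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 49))   # d = 8 channels on a 7x7 grid
W1 = rng.standard_normal((3, 8))   # C = 3 classes
W2 = rng.standard_normal((3, 8))
print(csra_head(x, W1, lam=0.1).shape)                       # (3,)
print(csra_multi_head(x, [W1, W2], 0.1, [1.0, None]).shape)  # (3,)
```

Setting lam to 0 leaves only the average-pooled logit, which is the plain global-average-pooling baseline.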

We didn't use any new augmentation such as AutoAugment or RandAugment in our ResNet series models.

Acknowledgement

We thank Lin Sui (http://www.lamda.nju.edu.cn/suil/) for his initial contribution to this project.

csra's People

Contributors

kevinz-code


csra's Issues

Is 'normalized' classifier necessary?

Thanks for your code. I wonder why we need to normalize the 'classifier' before the 'flatten' op. Does it perform better than without normalization?
Thank you for your explanation.
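One property of this normalization is easy to check: dividing each class's scores by the L2 norm of that class's classifier row makes the scores unchanged when a row is rescaled, so (one plausible motivation) no class gains larger scores just because its weight vector grew. Whether it improves accuracy is the empirical question asked above; this NumPy sketch with arbitrary shapes only demonstrates the invariance.

```python
import numpy as np

def normalized_scores(weight, x):
    """weight: (C, d) classifier rows; x: (d, HW) flattened features -> (C, HW)."""
    return (weight / np.linalg.norm(weight, axis=1, keepdims=True)) @ x

rng = np.random.default_rng(1)
W = rng.standard_normal((4, 16))
x = rng.standard_normal((16, 49))

scale = np.array([[1.0], [10.0], [0.1], [3.0]])  # rescale each class row differently
same = np.allclose(normalized_scores(W, x), normalized_scores(scale * W, x))
print(same)  # True
```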

Details about baseline resnet-101 in paper

Hi, thanks for your excellent work! But I'm confused about the baseline model settings in your paper.

Take training resnet-101 without cutmix on coco2014 as an example:

With the following training configurations as baseline setting, I get 81.3 mAP after 7 epochs (30 in total, still in training process...), which is much higher than that in your paper (79.4 mAP).
python main.py --num_heads 4 --lam 0 --dataset coco --num_cls 80 --checkpoint coco14/resnet101

So, what are the correct settings to reproduce the baseline result in your paper? Thanks again.

Global feature vector

Residual Attention: A Simple but Effective Method for Multi-Label Recognition
According to the paper, the base_logit (denoted as g in the paper) should be the global feature vector, computed by averaging the features over all spatial locations, as stated in the equation:

$$ \mathbf{g}=\frac{1}{49} \sum_{k=1}^{49} \mathbf{x}_{k} $$

Here, x_k represents the feature at location k; we sum over all 49 locations and take the average. This operation is class-agnostic: it is not specific to any class and is the same for all classes. The global feature vector g represents the overall content of the image, irrespective of specific classes, and serves as a baseline representation of the image content.

In my implementation I have this:

    def forward(self, x):
        B, _, H, W = x.size()  # batch size, _, height, width

        # Compute class-specific attention scores
        logits = self.classifier(x)  # size: (B, C, H, W)
        logits = logits.view(B, self.C, -1)  # size: (B, C, H*W)

        # Compute class-specific feature vectors
        x_flatten = x.view(B, self.d, -1)  # size: (B, d, H*W)

        # Compute global feature vector
        g = torch.mean(x_flatten, dim=2)  # size: (B, d)

I am computing base_logit (or g) as per the paper’s method. The original implementation seems to be computing something different for base_logit, which doesn’t align with the paper’s description. It’s computing the average class-specific score for each class across all spatial locations, which is not what g represents according to the paper.

This is the original implementation:

    def forward(self, x):
        # x (B d H W)
        # normalize classifier
        # score (B C HxW)
        score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0,1)
        score = score.flatten(2)
        base_logit = torch.mean(score, dim=2)  # size: (B, C)

Is there a reason for this?
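A short check of why the two computations agree: the 1x1 classifier is linear, so averaging its per-location class scores gives the same result as applying it to the averaged feature g. The NumPy sketch assumes a bias-free classifier with row-normalized weights as in the original code; with a bias b the base logits would still match, because the mean of (W x_k + b) is W g + b.

```python
import numpy as np

rng = np.random.default_rng(2)
C, d, HW = 5, 32, 49
W = rng.standard_normal((C, d))      # 1x1-conv classifier weights, no bias
x = rng.standard_normal((d, HW))     # flattened feature map of one image

W_hat = W / np.linalg.norm(W, axis=1, keepdims=True)  # row-normalized classifier

base_from_scores = (W_hat @ x).mean(axis=1)  # original code: mean of class scores
g = x.mean(axis=1)                           # paper: class-agnostic global feature
base_from_g = W_hat @ g                      # classifier applied to g

print(np.allclose(base_from_scores, base_from_g))  # True
```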

Transformer on WiderAttribute predicition

@Kevinz-code thanks for providing the source code and for the great work. I have a few questions, listed below:

  1. The ViT model predicts the attributes of a person, such as male, trousers ..., right? What accuracy are you getting on this?
  2. I looked into the ViT implementation and compared it with rwightman's implementation. You used the same implementation, and the only change is in the classifier part, i.e. MHA --> CSRA, right?
  3. Is CSRA tailored only to the ViT transformer, or can it be implemented with other transformers like CrossViT?
  4. The images in my custom dataset range in size from 80x56 to 124x128. Should I change the patch size from 16 to 8 or 4, since resizing the small images creates a pixelated, blurry effect?
  5. Did you keep the same training parameters for all 3 trained models?

Please do share your thoughts. Thanks in advance.

MobileNet; input size

Hi kevinz! As far as MobileNet is concerned, what is the input size on different datasets?
in MS-COCO
in VOC2007
in WIDER
in ImageNet

Partial-label

Can I use this repo for a dataset with partial labels?

Cross Validation

Hi,

I have two questions:

First: Why didn't you use K-fold cross-validation?
Second: What is the reason for using a different learning rate for the classifier? Is it for faster convergence?

I am trying to adapt CSRA to EfficientNetB3 on my multi-label dataset. Although I have tried various numbers of heads and lambda values, I am getting worse results than the baseline model. What is your opinion? Is there something else I could try?

Also, there is class imbalance in my dataset. Do I need data augmentation to counter the class imbalance? Is CSRA affected by data augmentation?

Thanks

How to load your pre-trained model and train on my dataset?

Hi,

Thank you so much for your great work. I'm doing a multi-label classification project, so I wonder how I can use your pretrained model for image feature extraction. What I need is to extract the features of an image. Could you please give me some hints?

Best regards,
Hui

Code for model prediction

Thank you for your excellent work; it has helped me a lot. Is there any code for prediction?

Load Model Issue?

When I try to load ''vit_L16_224_coco_head8_86.5.pth'' model in val.py, I get following error.

Error(s) in loading state_dict for VIT_CSRA:
Missing key(s) in state_dict: "classifier.multi_head.0.head.weight".
Unexpected key(s) in state_dict: "head1.weight", "head1.bias", "head2.weight", "head2.bias", "head3.weight", "head3.bias", "head4.weight", "head4.bias", "head5.weight", "head5.bias", "head6.weight", "head6.bias", "head7.weight", "head7.bias", "head8.weight", "head8.bias", "head.weight", "head.bias".
size mismatch for pos_embed: copying a param with shape torch.Size([1, 785, 1024]) from checkpoint, the shape in current model is torch.Size([1, 197, 1024]).
File "C:\Users\osivaz61\Desktop\projects\python\retina\diseaseDetection\CSRA-master\val.py", line 79, in main
model.load_state_dict(torch.load(args.load_from))
File "C:\Users\osivaz61\Desktop\projects\python\retina\diseaseDetection\CSRA-master\val.py", line 97, in
main()

As far as I understand, the vit_csra.py file has not been updated. Can you share the updated code?

Thanks
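The pos_embed size mismatch in this error is consistent with a resolution mismatch: a ViT with 16x16 patches has one token per patch plus the class token, so its positional embedding has (img_size / 16)^2 + 1 entries. The arithmetic below shows that 785 corresponds to 448x448 inputs (the resolution listed for the COCO ViT checkpoints in the table above) and 197 to 224x224, which suggests, as an unverified guess, that this checkpoint expects --img_size 448. It does not explain the unexpected head keys.

```python
def num_vit_tokens(img_size, patch_size=16, cls_token=True):
    """Tokens in a ViT sequence: one per non-overlapping patch, plus the class token."""
    assert img_size % patch_size == 0, "image size must be a multiple of the patch size"
    patches = (img_size // patch_size) ** 2
    return patches + (1 if cls_token else 0)

print(num_vit_tokens(448))  # 785 -> checkpoint shape [1, 785, 1024]
print(num_vit_tokens(224))  # 197 -> current model shape [1, 197, 1024]
```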

visualize function

I want to run visualize.py, but the cam_all parameter is not defined. How can I run this file to visualize?

The problems of val.py?

Using your provided model with val.py, why?
main.py includes a test step. Do its test results represent the real results, or should the test results be produced by val.py?
I am waiting for your reply

About the pretrained vit.

Hi Kevinz, thanks for your awesome work. I want to know whether you plan to release the weights of VIT that finetuned on the MSCOCO dataset?

Some problems about vision transformer

Hello, combining the code with your paper, I have the following questions (about vit_csra):

In the code, the class token is not used as input to the final CSRA module, so why is a class token defined in "VIT_CSRA"?
Has the final MLP head used for classification in the vision transformer been removed entirely?
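To illustrate the step these questions are about: a spatial head such as CSRA needs a d x H x W map, which one can obtain from a ViT output sequence by discarding the class token and reshaping the patch tokens onto the patch grid. This is a hypothetical NumPy sketch, not taken from vit_csra.py; it assumes the sequence layout is [class token, patch tokens in row-major order].

```python
import numpy as np

def tokens_to_feature_map(tokens, grid):
    """tokens: (B, 1 + grid*grid, D) ViT outputs -> (B, D, grid, grid) map."""
    B, N, D = tokens.shape
    assert N == 1 + grid * grid, "expected one class token plus grid*grid patches"
    patches = tokens[:, 1:, :]                # drop the class token
    fmap = patches.reshape(B, grid, grid, D)  # row-major patch layout
    return fmap.transpose(0, 3, 1, 2)         # channels first, like a CNN map

tokens = np.random.default_rng(3).standard_normal((2, 197, 768))
fmap = tokens_to_feature_map(tokens, grid=14)
print(fmap.shape)  # (2, 768, 14, 14)
```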

some questions about val.py

Thanks for sharing your code. I have some questions about your project.
In val.py, the following definition points to an empty path:
parser.add_argument("--load_from", default="models_local/resnet101_voc07_head1_lam0.1_94.7.pth", type=str)

And I cannot find the code that saves this model.
How should val.py be used in your project? Could you explain the model save path more clearly?

Is the code consistent with the description in the paper?

According to Eq. 5 and Eq. 6 in the paper, the class-specific residual attention (CSRA) feature f should be sent to the classifier to obtain the final logits, but in your code you use f as the final logits. What's the difference?
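One way to see why the two views agree: if the classifier for class i is linear and bias-free with (normalized) weight m_i, it can be pushed inside the sums. Writing x_k for the feature at location k of a 7x7 map, s_i^k = m_i^T x_k for the class-specific score, and w_i^k for the softmax attention weight with temperature T (symbols chosen for this sketch; they may not match the paper's exactly):

```latex
\begin{aligned}
w_i^k &= \frac{\exp(T\, s_i^k)}{\sum_{l=1}^{49} \exp(T\, s_i^l)}, \qquad
\mathbf{a}_i = \sum_{k=1}^{49} w_i^k\, \mathbf{x}_k, \qquad
\mathbf{g} = \frac{1}{49}\sum_{k=1}^{49} \mathbf{x}_k, \\
\mathbf{f}_i &= \mathbf{g} + \lambda\, \mathbf{a}_i, \\
\mathbf{m}_i^{\top}\mathbf{f}_i
  &= \frac{1}{49}\sum_{k=1}^{49} \mathbf{m}_i^{\top}\mathbf{x}_k
   + \lambda \sum_{k=1}^{49} w_i^k\, \mathbf{m}_i^{\top}\mathbf{x}_k
   = \underbrace{\frac{1}{49}\sum_{k=1}^{49} s_i^k}_{\text{base logit}}
   + \lambda \underbrace{\sum_{k=1}^{49} w_i^k\, s_i^k}_{\text{attention logit}}.
\end{aligned}
```

So, under these assumptions, computing base logit + lambda * attention logit from the class scores gives the same number as sending f_i to the classifier.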

Question about Attention Image or Heatmap Generation

Hi Kevinz, thanks for your awesome work. I'd like to do a visual analysis to get a better understanding of the CSRA. Could you please give me some advice on how to visualize the attention score (or heatmap, attention image)? Thank you very much!
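A common way to turn per-location class scores into a heatmap is to take one class's score row, reshape it to the feature-map grid, min-max normalize it, and upsample it to the image size before overlaying. The NumPy sketch below does this with nearest-neighbour upsampling; the function name and the normalization choice are assumptions, and it is unrelated to the repository's visualize.py.

```python
import numpy as np

def class_heatmap(score, class_idx, grid_h, grid_w, img_h, img_w):
    """score: (C, grid_h*grid_w) per-location class scores -> (img_h, img_w) map in [0, 1]."""
    cam = score[class_idx].reshape(grid_h, grid_w)
    cam = cam - cam.min()
    if cam.max() > 0:
        cam = cam / cam.max()  # min-max normalize to [0, 1]
    # Nearest-neighbour upsampling by integer block repetition (exact multiples only).
    assert img_h % grid_h == 0 and img_w % grid_w == 0
    return np.kron(cam, np.ones((img_h // grid_h, img_w // grid_w)))

rng = np.random.default_rng(4)
score = rng.standard_normal((20, 14 * 14))  # e.g. 20 classes on a 14x14 map
heat = class_heatmap(score, class_idx=11, grid_h=14, grid_w=14, img_h=448, img_w=448)
print(heat.shape)  # (448, 448)
```

Nearest-neighbour blocks keep the sketch dependency-free; bilinear upsampling would give a smoother overlay.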

MobileNet implementation of CSRA

I'm trying to implement CSRA using MobileNet as the backbone, but I'm running into some troubles. This is kind of related to #5.
First of all, from the paper it was not clear to me whether CSRA is to be applied before, after or instead of the classifier.

Now, I have a question: which version of MobileNet was CSRA implemented with? In my case, I'm trying to use MobileNetV3Large, but the paper states it is MobileNetV2.

In my use case, I would like to use the MobileNetV3 classification head, except with a different number of target classes. Where is CSRA supposed to be placed?

This is the structure of the MobileNetV3 classifier:
[Image: structure of the MobileNetV3 classifier]

Is the CSRA supposed to replace the Avg Pool on the (7,7,960) tensor? to replace the 1x1 Conv after the (1,1,1280) tensor? To take place after the last 1x1 Conv?

I think most of the confusion comes from Fig 1 and Fig 2 in the CSRA paper.

  • In Fig 1, the output of the backbone is run through the classifier, then through CSRA. It is stated that Fig 1 is a special case of CSRA, but it still remains confusing.

  • In Fig 2, f seems to act directly as the class scores, while the text before Eq 6 states "Finally, all these class-specific feature vectors are sent to the classifier to obtain the final logits". Fig 2 does not make clear that the output of the CSRA module is sent to the classifier, and this adds to the confusion about where the CSRA module is supposed to be placed.
