CSWin-Transformer, CVPR 2022

This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows".

Introduction

CSWin Transformer (the name CSWin stands for Cross-Shaped Window) is introduced in arXiv and is a new general-purpose backbone for computer vision. It is a hierarchical Transformer that replaces the traditional full attention with our newly proposed cross-shaped window self-attention. The cross-shaped window self-attention mechanism computes self-attention in horizontal and vertical stripes in parallel that together form a cross-shaped window, with each stripe obtained by splitting the input feature into stripes of equal width. With CSWin, we can realize global attention at a limited computation cost.
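
To make the mechanism concrete, below is a heavily simplified, illustrative PyTorch sketch of stripe-based attention (single head per orientation, no qkv projections, no locally-enhanced positional encoding); it is only a sketch of the idea, not the repository's implementation:

import torch

def cross_shaped_window_attention(x, stripe_width=2):
    # Illustrative sketch: half of the channels attend within horizontal
    # stripes, the other half within vertical stripes, so together they
    # cover a cross-shaped window.
    B, H, W, C = x.shape
    x_h, x_v = x.split(C // 2, dim=-1)

    def stripe_attention(feat, horizontal):
        Bf, Hf, Wf, Cf = feat.shape
        if horizontal:   # stripes of size (stripe_width, W)
            windows = feat.reshape(Bf, Hf // stripe_width, stripe_width, Wf, Cf)
            windows = windows.reshape(-1, stripe_width * Wf, Cf)
        else:            # stripes of size (H, stripe_width)
            windows = feat.reshape(Bf, Hf, Wf // stripe_width, stripe_width, Cf)
            windows = windows.permute(0, 2, 1, 3, 4).reshape(-1, Hf * stripe_width, Cf)
        # Plain single-head attention inside each stripe, just for illustration.
        attn = torch.softmax(windows @ windows.transpose(-2, -1) / Cf ** 0.5, dim=-1)
        return attn @ windows

    out_h = stripe_attention(x_h, horizontal=True)
    out_v = stripe_attention(x_v, horizontal=False)
    return out_h, out_v

feat = torch.randn(1, 8, 8, 16)   # a toy 8x8 feature map with 16 channels
out_h, out_v = cross_shaped_window_attention(feat)
print(out_h.shape, out_v.shape)   # both: torch.Size([4, 16, 8])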

CSWin Transformer achieves strong performance on ImageNet classification (87.5% top-1 accuracy on val with only 97G FLOPs) and ADE20K semantic segmentation (55.7 mIoU on val), surpassing previous models by a large margin.

[Teaser figure]

Main Results on ImageNet

model pretrain resolution acc@1 #params FLOPs 22K model 1K model
CSWin-T ImageNet-1K 224x224 82.8 23M 4.3G - model
CSWin-S ImageNet-1K 224x224 83.6 35M 6.9G - model
CSWin-B ImageNet-1K 224x224 84.2 78M 15.0G - model
CSWin-B ImageNet-1K 384x384 85.5 78M 47.0G - model
CSWin-L ImageNet-22K 224x224 86.5 173M 31.5G model model
CSWin-L ImageNet-22K 384x384 87.5 173M 96.8G - model

Main Results on Downstream Tasks

COCO Object Detection

Backbone Method pretrain Lr Schd box mAP mask mAP #params FLOPs
CSWin-T Mask R-CNN ImageNet-1K 3x 49.0 43.6 42M 279G
CSWin-S Mask R-CNN ImageNet-1K 3x 50.0 44.5 54M 342G
CSWin-B Mask R-CNN ImageNet-1K 3x 50.8 44.9 97M 526G
CSWin-T Cascade Mask R-CNN ImageNet-1K 3x 52.5 45.3 80M 757G
CSWin-S Cascade Mask R-CNN ImageNet-1K 3x 53.7 46.4 92M 820G
CSWin-B Cascade Mask R-CNN ImageNet-1K 3x 53.9 46.4 135M 1004G

ADE20K Semantic Segmentation (val)

Backbone Method pretrain Crop Size Lr Schd mIoU mIoU (ms+flip) #params FLOPs
CSWin-T Semantic FPN ImageNet-1K 512x512 80K 48.2 - 26M 202G
CSWin-S Semantic FPN ImageNet-1K 512x512 80K 49.2 - 39M 271G
CSWin-B Semantic FPN ImageNet-1K 512x512 80K 49.9 - 81M 464G
CSWin-T UPerNet ImageNet-1K 512x512 160K 49.3 50.7 60M 959G
CSWin-S UPerNet ImageNet-1K 512x512 160K 50.4 51.5 65M 1027G
CSWin-B UPerNet ImageNet-1K 512x512 160K 51.1 52.2 109M 1222G
CSWin-B UPerNet ImageNet-22K 640x640 160K 51.8 52.6 109M 1941G
CSWin-L UPerNet ImageNet-22K 640x640 160K 53.4 55.7 208M 2745G

Pretrained models and code for segmentation can be found at segmentation.

Requirements

timm==0.3.4, pytorch>=1.4, opencv, ...; to install them, run:

bash install_req.sh

Apex is used for mixed-precision training during finetuning. To install Apex, run:

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
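
For reference, a minimal sketch of how Apex amp is typically wired into a PyTorch training step (illustrative only; the actual training scripts in this repo handle mixed precision via the '--amp' flag):

import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Wrap the model and optimizer; "O1" is the common mixed-precision opt level.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(4, 10).cuda()
loss = model(x).sum()

# Scale the loss so FP16 gradients do not underflow.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()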

Data preparation: organize ImageNet with the following folder structure (you can extract ImageNet with this script):

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
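
If you want to sanity-check the layout, a small snippet using torchvision's ImageFolder (illustrative, not part of this repo) can count the classes and images:

from torchvision import datasets

# ImageFolder expects exactly the class-subfolder layout shown above.
train_set = datasets.ImageFolder('imagenet/train')
val_set = datasets.ImageFolder('imagenet/val')
print(len(train_set.classes), 'classes,',
      len(train_set), 'train images,', len(val_set), 'val images')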

Train

Train the three lite variants: CSWin-Tiny, CSWin-Small and CSWin-Base:

bash train.sh 8 --data <data path> --model CSWin_64_12211_tiny_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.2
bash train.sh 8 --data <data path> --model CSWin_64_24322_small_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.4
bash train.sh 8 --data <data path> --model CSWin_96_24322_base_224 -b 128 --lr 1e-3 --weight-decay .1 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99992 --drop-path 0.5

If you want to train our CSWin on images with 384x384 resolution, please use '--img-size 384'.

If the GPU memory is not enough, please use '-b 128 --lr 1e-3 --model-ema-decay 0.99992' or enable gradient checkpointing with '--use-chk'.
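
The '--use-chk' option refers to activation checkpointing via torch.utils.checkpoint, which saves GPU memory by recomputing intermediate activations during the backward pass. A minimal, repo-independent illustration:

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
)

x = torch.randn(8, 64, requires_grad=True)
# Activations inside `block` are not stored; they are recomputed on backward.
y = checkpoint(block, x)
y.sum().backward()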

Finetune

Finetune CSWin-Base with 384x384 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_96_24322_base_384 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 384 --warmup-epochs 0 --model-ema-decay 0.9998 --finetune <pretrained 224 model> --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1

Finetune ImageNet-22K pretrained CSWin-Large with 224x224 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_144_24322_large_224 -b 64 --lr 2.5e-4 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9996 --finetune <22k-pretrained model> --epochs 30 --mixup 0.01 --cooldown-epochs 10 --interpolation bicubic  --lr-scale 0.05 --drop-path 0.2 --cutmix 0.3 --use-chk --fine-22k --ema-finetune

If the GPU memory is not enough, please enable gradient checkpointing with '--use-chk'.

Cite CSWin Transformer

@misc{dong2021cswin,
      title={CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows},
      author={Xiaoyi Dong and Jianmin Bao and Dongdong Chen and Weiming Zhang and Nenghai Yu and Lu Yuan and Dong Chen and Baining Guo},
      year={2021},
      eprint={2107.00652},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

This repository is built using the timm library and the DeiT repository.

License

This project is licensed under the license found in the LICENSE file in the root directory of this source tree.

Microsoft Open Source Code of Conduct

Contact Information

For help or issues using CSWin Transformer, please submit a GitHub issue.

For other communications related to CSWin Transformer, please contact Jianmin Bao ([email protected]), Dong Chen ([email protected]).

cswin-transformer's Issues

Welcome update to OpenMMLab 2.0

I am Vansin, the technical operator of OpenMMLab. In September of last year, we announced the release of OpenMMLab 2.0 at the World Artificial Intelligence Conference in Shanghai. We invite you to upgrade your algorithm library to OpenMMLab 2.0 using MMEngine, which can be used for both research and commercial purposes. If you have any questions, please feel free to join us on the OpenMMLab Discord at https://discord.gg/amFNsyUBvm or add me on WeChat (van-sin) and I will invite you to the OpenMMLab WeChat group.

Here are the OpenMMLab 2.0 repos branches:

OpenMMLab 1.0 branch OpenMMLab 2.0 branch
MMEngine 0.x
MMCV 1.x 2.x
MMDetection 0.x, 1.x, 2.x 3.x
MMAction2 0.x 1.x
MMClassification 0.x 1.x
MMSegmentation 0.x 1.x
MMDetection3D 0.x 1.x
MMEditing 0.x 1.x
MMPose 0.x 1.x
MMDeploy 0.x 1.x
MMTracking 0.x 1.x
MMOCR 0.x 1.x
MMRazor 0.x 1.x
MMSelfSup 0.x 1.x
MMRotate 1.x 1.x
MMYOLO 0.x

Attention: please create a new virtual environment for OpenMMLab 2.0.

How do you produce Table 9 (ablation on different attention mechanisms) in the paper?

Hi, thanks for your nice work. I'm doing some comparisons of different attention mechanisms and want to follow your experimental settings. I have two questions:

  1. Why is the reported mIoU 41.9 for Swin-T in Table 9, while it is 46.1 in the Swin paper?
  2. Can you provide the detailed experimental settings for semantic segmentation and object detection in Table 9?

Inference Strategy

  1. I notice that there is a "--tta" option in args. Do you use TTA at inference to derive the results reported in the paper?
  2. I notice that you specify carefully chosen "model-ema-decay" values (e.g. 0.99984 for CSWin-Tiny and 0.99992 for CSWin-Base). Do you use the EMA model for inference to derive the results in the paper?
  3. How do the two factors mentioned above impact model performance?

About the CIFAR-100 dataset

Hello, can you provide a script showing how to prepare and train on the CIFAR-100 dataset? Thank you.

Pretrained settings for object detection

Hi, I'm impressed by your excellent work.

I have a question.

I wonder which of the pre-trained weights (224x224 or 384x384 fine-tuned) is used for object detection.

I know both the 224x224 and 384x384 models are pre-trained on ImageNet-1K.

The weight sizes of the pretrained CSWin-L model do not match the model defined in the script?

[04/20 19:11:50] d2.checkpoint.c2_model_loading WARNING: merge3.conv.weight will not be loaded. Please double check and see if this is desired.
[04/20 19:11:50] d2.checkpoint.c2_model_loading WARNING: Shape of merge3.norm.bias in checkpoint is torch.Size([1152]), while shape of backbone.merge3.norm.bias in model is torch.Size([576]).
[04/20 19:11:50] d2.checkpoint.c2_model_loading WARNING: merge3.norm.bias will not be loaded. Please double check and see if this is desired.
[04/20 19:11:50] d2.checkpoint.c2_model_loading WARNING: Shape of merge3.norm.weight in checkpoint is torch.Size([1152]), while shape of backbone.merge3.norm.weight in model is torch.Size([576]).
[04/20 19:11:50] d2.checkpoint.c2_model_loading WARNING: merge3.norm.weight will not be loaded. Please double check and see if this is desired.
[04/20 19:11:50] d2.checkpoint.c2_model_loading WARNING: Shape of stage4.0.attns.0.get_v.bias in checkpoint is torch.Size([1152]), while shape of backbone.stage4.0.attns.0.get_v.bias in model is torch.Size([576]).
[04/20 19:11:50] d2.checkpoint.c2_model_loading WARNING: stage4.0.attns.0.get_v.bias will not be loaded. Please double check and see if this is desired.
[04/20 19:11:50] d2.checkpoint.c2_model_loading WARNING: Shape of stage4.0.attns.0.get_v.weight in checkpoint is torch.Size([1152, 1, 3, 3]), while shape of backbone.stage4.0.attns.0.get_v.weight in model is torch.Size([576, 1, 3, 3]).

Why can't the CSWin-L weights be fully loaded?

CSWin significantly slower than Swin?

Greetings,

From my benchmarks I have noticed that CSWin seems to be significantly slower than Swin when it comes to inference time; is this the expected behavior? While I can get predictions as fast as 20 milliseconds on Swin Large 384, it takes above 900 milliseconds on CSWin_144_24322_large_384.

I performed tests using FP16, TorchScript, optimize_for_inference and torch.inference_mode.

Function of drop_rate, attn_drop_rate and drop_path_rate: what values should they be set to in order to improve mAP?

1. Since drop_rate, attn_drop_rate and drop_path_rate are 0 by default, drop_path is not enabled. What should drop_rate, attn_drop_rate and drop_path_rate be set to for the model to perform better? (The paper does not mention this, and the source code defaults them to 0.) Thanks!
2. The model is compared against Swin as a backbone for Mask R-CNN. I want to know whether the initial channel number (dim) of Swin-T is 96 while that of CSWin-T is 64, i.e. is CSWin-T configured as in the following table when used as the detection backbone?
Models #Dim #Blocks sw #heads #Param. FLOPs
CSWin-T 64 1,2,21,1 1,2,7,7 2,4,8,16 23M 4.3G
CSWin-S 64 2,4,32,2 1,2,7,7 2,4,8,16 35M 6.9G
CSWin-B 96 2,4,32,2 1,2,7,7 4,8,16,32 78M 15.0G
CSWin-L 144 2,4,32,2 1,2,7,7 6,12,24,48 173M 31.5G

Using transfer learning to train on Food-101

Hi all! I'm trying to train a model for Food-101 using the CSWin_64_12211_tiny_224 model with its pretrained weights. The thing is, during execution it looks like it's training from scratch rather than reusing the pretrained weights: the initial top-5 accuracy is around 5%, but my initial thought was that it should be higher than this.

For this I loaded the pretrained model, changed its classification layer in a separate script, and saved it for later use as follows:

model = create_model(
    'CSWin_64_12211_tiny_224', pretrained=True, num_classes=1000,
    drop_rate=0.0, drop_connect_rate=None,  # DEPRECATED, use drop_path
    drop_path_rate=0.2, drop_block_rate=None, global_pool=None,
    bn_tf=False, bn_momentum=None, bn_eps=None, checkpoint_path='',
    img_size=224, use_chk=True)
chk_path = './pretrained/cswin_tiny_224.pth'
load_checkpoint(model, chk_path)
model.reset_classifier(101, 'max')

These are some of the runs I tried:

bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1 --use-chk --num-classes 101 --pretrained --finetune ./pretrained/CSWin_64_12211_tiny_224101.pth

bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --cooldown-epochs 10 --drop-path 0.2 --ema-finetune --cutmix 0.1 --use-chk --num-classes 101 --initial-checkpoint ./pretrained/CSWin_64_12211_tiny_224101.pth --lr-scale 1.0 --output ./full_base

Is there something I'm missing or a proper way I should try this?

Thanks in advance for any help! :)

The downstream-task results of my implementation are poor

There are some questions:

  1. Is the split size still [1, 2, 7, 7]?
  2. Is the last-stage branch_num 2 or 1? For downstream tasks the feature resolution in the last stage cannot equal 7 (the split size). If it is not 1, the pretrained weight sizes do not match.
  3. Is the padding correct in my implementation?
    pad_l = pad_t = 0
    pad_r = (W_sp - W % W_sp) % W_sp
    pad_b = (H_sp - H % H_sp) % H_sp
    q = q.transpose(-2, -1).contiguous().view(B, H, W, C)
    k = k.transpose(-2, -1).contiguous().view(B, H, W, C)
    v = v.transpose(-2, -1).contiguous().view(B, H, W, C)
    if pad_r > 0 or pad_b > 0:
        q = F.pad(q, (0, 0, pad_l, pad_r, pad_t, pad_b))
        k = F.pad(k, (0, 0, pad_l, pad_r, pad_t, pad_b))
        v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b))
    _, Hp, Wp, _ = q.shape

About the patches_resolution of the segmentation model

Hello, this work is interesting but I have some questions about the 'patches_resolution' of the segmentation model. I notice that the long side of the cross-shaped windows is the 'patches_resolution' rather than the real feature resolution. For example, in stage 3 the long side is 224 / 16 = 14. Do I understand it correctly? Does that make it impossible to exchange information outside the 'patches_resolution'?

Why is the GPU memory of CSWin-Tiny twice that of Swin-Tiny?

I adapted CSWin-Tiny to the object detection task, only adding some padding operations. When I used a Faster R-CNN model, I found that the GPU memory of CSWin-Tiny is twice that of Swin-Tiny; did I do something wrong?

About the setting of --use-chk

Passing the '--use-chk' parameter enables torch.utils.checkpoint to save GPU memory; I wonder whether this could hurt the final performance. Thanks a lot!

CSWin-L for sota performance ade20k

Hi there, thank you for the great work. Do you have a training script for reproducing the 55.7 mIoU ADE20K result? I am eager to try it out and use CSWin to see if we can beat Swin on the ADE20K benchmark.

about the test result on imagenet-2012

Hi! I've tested the released CSWin-Tiny-224 pretrained weights; these are my data transforms during testing:

DEFAULT_CROP_SIZE = 0.9
scale_size = int(math.floor(image_size / DEFAULT_CROP_SIZE))
transform = transforms.Compose(
    [
        transforms.Resize(scale_size, interpolation=3)  # 3: bicubic
        if image_size == 224
        else transforms.Resize(image_size, interpolation=3),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ]
)

I can only get 80.5% on the ImageNet-2012 dataset, which is inconsistent with the results reported in this repo. Did I miss some details about the data augmentation during testing?

Hi, very glad to see this new variant of the Swin Transformer. Could I ask a question about mixed-precision training?

Dear author,

I notice that the repo recommends using Apex mixed precision for fine-tuning.
How about training from scratch on ImageNet-1K (should I also enable Apex mixed-precision training in this case)?
Previously, I found that mixed precision could decrease the results of training CNNs on ImageNet from scratch.
Hence, I wonder whether mixed-precision training was the default setting for the CSWin (or Swin) experiments.
Thank you so much!

It was hard to fine-tune on another dataset.

I used this network for an image defect classification task, and it was very hard to train and got low accuracy, but other models, such as the MLP-based ViP model or a plain ResNet-50, are really easy to fine-tune on my dataset. I also adjusted my learning rate, weight decay, batch size and data augmentation policy (to suit my data) and changed the optimizer, but it did not help.

timm version

I was trying to train this model on Google Colab. At first, I installed the latest version of timm (0.9.12) and I encountered this error: cannot import name 'Dataset' from 'timm.data'

Then I reinstalled the timm version that was recommended in the readme.md file (0.3.4) and I saw another error:
File "/root/.local/lib/python3.10/site-packages/timm/models/layers/helpers.py", line 6, in <module>
    from torch._six import container_abcs
ModuleNotFoundError: No module named 'torch._six'

Can you please suggest any ways to fix this? Thank you so much.

The labels of ImageNet-22K

Could you provide the corresponding labels for ImageNet-22K? Looking forward to your reply. Thank you!

error

[Screenshot of the error]
I get this error when training the model; do you know what it is? It is definitely not my GPU memory, because I tried with the smallest model.

the memory is nearly twice that of swin transformer

Hi, thanks for the great work. I tried to run inference on an image: I used the CSWin-Small segmentation model to test some pictures and found that the memory is nearly twice that of Swin-Small. Is this normal?

The problem is shown in the figure

model = CSWinTransformer(patch_size=4, embed_dim=96, depth=[2, 4, 32, 2],
                         split_size=[1, 2, 12, 12], num_heads=[4, 8, 16, 32],
                         mlp_ratio=4.).cuda().eval()
inp = torch.rand(1, 3, 224, 224).cuda()
outs = model(inp)
for out in outs:
    print(out.shape)

RuntimeError: shape '[1, 192, 1, 14, 1, 12]' is invalid for input of size 37632

Why does this happen?

Experiment settings for semantic segmentation

Hi, thank you for the code.
I implemented CSWin-T with FPN for semantic segmentation on ADE20K but couldn't reach the 48.2 mIoU you report in the table. The maximum I could get was 39.9 mIoU; it would be great if you could share the exact experiment settings you used.
Thanks
