Coder Social home page Coder Social logo

danny0559 / dat Goto Github PK

View Code? Open in Web Editor NEW

This project forked from leaplabthu/dat

0.0 0.0 0.0 13.63 MB

Repository of Vision Transformer with Deformable Attention (CVPR2022)

Home Page: https://arxiv.org/abs/2201.00520

License: Apache License 2.0

Shell 1.04% Python 98.96%

dat's Introduction

Vision Transformer with Deformable Attention

This repository contains the code for the paper Vision Transformer with Deformable Attention (CVPR2022) [arXiv][video][poster].

Introduction

Motivation

Motivation

(a) Vision Transformer(ViT) has proved its superiority over many tasks thanks to its large or even global receptive field. However, this global attention leads to excessive computational costs. (b) Swin Transformer proposes shifted window attention, which is a more efficient sparse attention mechanism with linear computation complexity. Nevertheless, this hand-crafted attention pattern is likely to drop important features outside one window, and shifting windows impedes the growth of the receptive field, limiting modeling the long-range dependencies. (c) DCN expands the receptive fields of the standard convolutions with the learned offsets for each different query. Howbeit, directly applying this technique to the Vision Transformer is non-trivial for the quadratic space complexity and the training difficulties. (d) Deformable Attention (DAT) is proposed to model the relations among tokens effectively under the guidance of the important regions in the feature maps. This flexible scheme enables the self-attention module to focus on relevant regions and capture more informative features.

Method

Deform_Attn

By learning several groups of offsets for the grid reference points, the deformed keys and values are sampled from these shifted locations. This deformable attention can capture the most informative regions in the image. On this basis, we present Deformable Attention Transformer (DAT), a general backbone model with deformable attention for both image classification and other dense prediction tasks.

Visualizations

Visualizations

Visualizations show the most important keys denotes in orange circles, where larger circles indicates higher attention scores. That the important keys cover the main parts of the objects demonstrates the effectiveness of DAT.

Dependencies

  • NVIDIA GPU + CUDA 11.3
  • Python 3.9 (>=3.6, recommend to use Anaconda)
  • cudatoolkit == 11.3.1
  • PyTorch == 1.11.0
  • torchvision == 0.12.0
  • numpy
  • timm == 0.5.4
  • einops
  • PyYAML
  • yacs
  • termcolor

Evaluate Pretrained Models

We provide the pretrained models in the tiny, small, and base versions of DAT, as listed below.

model resolution acc@1 config pretrained weights
DAT-Tiny 224x224 82.0 config GoogleDrive / TsinghuaCloud
DAT-Small 224x224 83.7 config GoogleDrive / TsinghuaCloud
DAT-Base 224x224 84.0 config GoogleDrive / TsinghuaCloud
DAT-Base 384x384 84.8 config GoogleDrive / TsinghuaCloud

To evaluate one model, please download the pretrained weights to your local machine and run the script evaluate.sh as follow.

bash evaluate.sh <gpu_nums> <path-to-config> <path-to-pretrained-weights>

E.g., suppose evaluating the DAT-Tiny model (dat_tiny_in1k_224.pth) with 8 GPUs, the command should be:

bash evaluate.sh 8 configs/dat_tiny.yaml dat_tiny_in1k_224.pth

And the evaluation result should give:

[2022-06-07 04:08:50 dat_tiny] (main.py 288): INFO  * Acc@1 82.034 Acc@5 95.850
[2022-06-07 04:08:50 dat_tiny] (main.py 150): INFO Accuracy of the network on the 50000 test images: 82.0%

Outputs of the other models are:

[2022-06-07 04:19:42 dat_small] (main.py 288): INFO  * Acc@1 83.686 Acc@5 96.392
[2022-06-07 04:19:42 dat_small] (main.py 150): INFO Accuracy of the network on the 50000 test images: 83.7%

[2022-06-07 04:24:35 dat_base] (main.py 288): INFO  * Acc@1 84.028 Acc@5 96.686
[2022-06-07 04:24:35 dat_base] (main.py 150): INFO Accuracy of the network on the 50000 test images: 84.0%

[2022-06-07 06:43:07 dat_base_384] (main.py 288): INFO  * Acc@1 84.754 Acc@5 96.982
[2022-06-07 06:43:07 dat_base_384] (main.py 150): INFO Accuracy of the network on the 50000 test images: 84.8%

Train Models from Scratch

To train a model from scratch, we provide a simple script train.sh. E.g, to train a model with 8 GPUs on a single node, you can use this command:

bash train.sh 8 <path-to-config> <experiment-tag>

We also provide a training script train_slurm.sh for training models on multiple machines with a larger batch-size like 4096.

bash train_slurm.sh 32 <path-to-config> <slurm-job-name>

Remember to change the <path-to-imagenet> in the script files to your own ImageNet directory.

TODO

  • Classification pretrained models.
  • Object Detection codebase & models.
  • Semantic Segmentation codebase & models.
  • CUDA operators to accelerate sampling operations.

Acknowledgements

This code is developed on the top of Swin Transformer, we thank to their efficient and neat codebase. The computational resources supporting this work are provided by Hangzhou High-Flyer AI Fundamental Research Co.,Ltd.

Citation

If you find our work is useful in your research, please consider citing:

@InProceedings{Xia_2022_CVPR,
    author    = {Xia, Zhuofan and Pan, Xuran and Song, Shiji and Li, Li Erran and Huang, Gao},
    title     = {Vision Transformer With Deformable Attention},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {4794-4803}
}

Contact

If you have any questions or concerns, please send mail to [email protected].

dat's People

Contributors

vladimir2506 avatar leaplabthu avatar panxuran avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.