templeblock / swiftformer

This project forked from amshaker/swiftformer

SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications

SwiftFormer

Abdelrahman Shaker, Muhammad Maaz, Hanoona Rasheed, Salman Khan, Ming-Hsuan Yang, and Fahad Shahbaz Khan

Paper: arXiv:2303.15446

🚀 News

  • (Mar 27, 2023): Classification training and evaluation code, along with pre-trained models, has been released.


Comparison of our SwiftFormer models with the state of the art on ImageNet-1K. Latency is measured on the iPhone 14 Neural Engine (iOS 16).


Comparison with different self-attention modules. (a) Typical self-attention. (b) Transpose self-attention, where the self-attention operation is applied across the channel feature dimension (d×d) instead of the spatial dimension (n×n). (c) The separable self-attention of MobileViT-v2, which uses element-wise operations to compute the context vector from the interactions of the Q and K matrices; the context vector is then multiplied by the V matrix to produce the final output. (d) Our proposed efficient additive self-attention. Here, the query matrix is multiplied by learnable weights and pooled to produce a global query. The matrix K is then element-wise multiplied by the broadcast global query, resulting in the global context representation.
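The description of panel (d) can be made concrete with a minimal NumPy sketch. It is not the repository's implementation: the softmax used to normalize the per-token scores, the residual connection back to the queries, and all weight names (`w_q`, `w_k`, `w_a`, `w_t`) are assumptions made for illustration. What it does show is that no n×n matrix is ever formed: the queries are scored and pooled into a single global query, which is then broadcast and multiplied element-wise with K.

```python
import numpy as np


def efficient_additive_attention(x, w_q, w_k, w_a, w_t):
    """Illustrative sketch of additive attention (panel d), not the official code.

    x:   (n, d) token matrix (n tokens, d channels)
    w_q: (d, d) query projection               (hypothetical learned weight)
    w_k: (d, d) key projection                 (hypothetical learned weight)
    w_a: (d,)   query-scoring vector           ("learnable weights" in the caption)
    w_t: (d, d) linear layer on global context (hypothetical learned weight)
    """
    d = x.shape[1]
    q = x @ w_q                      # (n, d) queries
    k = x @ w_k                      # (n, d) keys
    scores = (q @ w_a) / np.sqrt(d)  # (n,) one scalar score per token
    alpha = np.exp(scores - scores.max())
    alpha /= alpha.sum()             # softmax over tokens (assumed normalization)
    global_q = alpha @ q             # (d,) pooled global query
    context = k * global_q           # (n, d) broadcast element-wise product, no n x n matrix
    return context @ w_t + q         # linear layer on the context plus an assumed residual


rng = np.random.default_rng(0)
n, d = 196, 64                       # e.g. a 14x14 feature map with 64 channels
x = rng.standard_normal((n, d))
w_q, w_k, w_t = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
w_a = rng.standard_normal(d)
out = efficient_additive_attention(x, w_q, w_k, w_a, w_t)
print(out.shape)  # (196, 64)
```

Because the only interaction between tokens goes through the pooled global query, shuffling the input tokens simply shuffles the output rows, and the cost of the attention step grows linearly with n.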

Abstract

Self-attention has become a de facto choice for capturing global context in various vision applications. However, its quadratic computational complexity with respect to image resolution limits its use in real-time applications, especially for deployment on resource-constrained mobile devices. Although hybrid approaches have been proposed to combine the advantages of convolutions and self-attention for a better speed-accuracy trade-off, the expensive matrix multiplication operations in self-attention remain a bottleneck. In this work, we introduce a novel efficient additive attention mechanism that effectively replaces the quadratic matrix multiplication operations with linear element-wise multiplications. Our design shows that the key-value interaction can be replaced with a linear layer without sacrificing any accuracy. Unlike previous state-of-the-art methods, our efficient formulation of self-attention enables its usage at all stages of the network. Using our proposed efficient additive attention, we build a series of models called "SwiftFormer" that achieve state-of-the-art performance in terms of both accuracy and mobile inference speed. Our small variant achieves 78.5% top-1 ImageNet-1K accuracy with only 0.8 ms latency on iPhone 14, making it more accurate and 2× faster than MobileViT-v2.
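As a rough illustration of the quadratic-versus-linear claim, the snippet below counts only the multiplications in the token-mixing step, ignoring the d×d projections and the softmax. The counting convention and the example sizes (n = 196 tokens from a 14×14 feature map, d = 64) are assumptions chosen for illustration, not figures from the paper.

```python
def token_mixing_mults(n, d):
    """Rough multiply counts for the token-mixing step only (illustrative).

    Ignores the d x d linear projections and the softmax.
    n: number of tokens, d: channel dimension.
    """
    # Standard self-attention: Q @ K^T (n*n*d) and softmax(.) @ V (n*n*d).
    standard = 2 * n * n * d
    # Additive attention: score Q with a weight vector (n*d), pool the
    # global query (n*d), multiply K element-wise by it (n*d).
    additive = 3 * n * d
    return standard, additive


std, add = token_mixing_mults(n=196, d=64)
print(std, add, round(std / add, 1))  # 4917248 37632 130.7
```

Doubling the number of tokens multiplies the standard count by four but the additive count only by two.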

Classification on ImageNet-1K

Models

Model           Top-1 acc.  #Params  GMACs  Latency (iPhone 14)  Ckpt  CoreML
SwiftFormer-XS  75.7%       3.5M     0.4G   0.7 ms               XS    XS
SwiftFormer-S   78.5%       6.1M     1.0G   0.8 ms               S     S
SwiftFormer-L1  80.9%       12.1M    1.6G   1.1 ms               L1    L1
SwiftFormer-L3  83.0%       28.5M    4.0G   1.9 ms               L3    L3

Detection and Segmentation Qualitative Results



Latency Measurement

The latency reported for SwiftFormer on iPhone 14 (iOS 16) is measured with the benchmark tool from Xcode 14.

ImageNet

Prerequisites

A conda virtual environment is recommended.

conda create --name=swiftformer python=3.9
conda activate swiftformer

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install timm
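To confirm that the environment matches the pinned versions above, a small standard-library check can compare installed package versions. This is a hypothetical helper, not part of the repository; it compares full version strings, including the `+cu113` local tag carried by the CUDA wheels, and treats `timm` as unpinned because the command above does not pin it.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the pip commands above; timm is unpinned there.
EXPECTED = {"torch": "1.11.0+cu113", "torchvision": "0.12.0+cu113", "timm": None}


def check_env(expected=EXPECTED, get_version=version):
    """Return a list of human-readable problems (empty if everything matches)."""
    problems = []
    for pkg, want in expected.items():
        try:
            have = get_version(pkg)
        except PackageNotFoundError:
            problems.append(f"{pkg}: not installed")
            continue
        if want is not None and have != want:
            problems.append(f"{pkg}: found {have}, expected {want}")
    return problems


if __name__ == "__main__":
    for line in check_env() or ["environment matches the pinned versions"]:
        print(line)
```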

Data preparation

Download and extract the ImageNet train and val images from http://image-net.org. The training and validation data are expected to be in the train and val folders, respectively:

|-- /path/to/imagenet/
    |-- train
    |-- val
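A quick way to catch layout mistakes before launching a long run is to check the folders programmatically. The helper below is a hypothetical sketch that assumes one subdirectory per class in both `train` and `val` (the ImageFolder-style layout used by DeiT/LeViT-derived codebases); the README only specifies the two top-level folders, so verify this against the repository's data loader.

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Check the train/val layout described above (hypothetical helper).

    Assumes each split holds one subdirectory per class; returns
    (num_train_classes, num_val_classes) or raises ValueError.
    """
    root = Path(root)
    classes = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise ValueError(f"missing folder: {split_dir}")
        # Collect class-folder names; loose image files are ignored.
        classes[split] = {p.name for p in split_dir.iterdir() if p.is_dir()}
    if classes["train"] != classes["val"]:
        raise ValueError("train and val class folders differ")
    return len(classes["train"]), len(classes["val"])
```

For a complete ImageNet-1K download the expected result would be `(1000, 1000)`. The validation archive from image-net.org extracts as a flat folder of images, so under the per-class assumption it would need to be sorted into class subfolders first.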

Single machine multi-GPU training

We provide a training script for all models in dist_train.sh, using PyTorch distributed data parallel (DDP).

To train SwiftFormer models on an 8-GPU machine:

sh dist_train.sh /path/to/imagenet 8

Note: specify in the script which model command you want to run. To reproduce the results of the paper, use a 16-GPU machine with a batch size of 128 or an 8-GPU machine with a batch size of 256. Auto Augmentation, CutMix, and MixUp are disabled for SwiftFormer-XS only.
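The two recommended configurations yield the same global batch size if the batch size in the script is per GPU, as is usual for DDP scripts derived from DeiT/LeViT (an assumption here; check dist_train.sh). A minimal sketch of the arithmetic, with a hypothetical helper for other GPU counts:

```python
# Assumes the batch size passed in dist_train.sh is per GPU (DDP convention).
def global_batch(num_gpus, per_gpu_batch):
    """Effective batch size seen by one optimizer step across all GPUs."""
    return num_gpus * per_gpu_batch


def per_gpu_batch_for(num_gpus, target_global=2048):
    """Hypothetical helper: per-GPU batch that keeps the global batch fixed."""
    if target_global % num_gpus:
        raise ValueError(f"{target_global} is not divisible by {num_gpus} GPUs")
    return target_global // num_gpus


print(global_batch(16, 128), global_batch(8, 256))  # 2048 2048
print(per_gpu_batch_for(4))  # 512
```

Keeping the global batch fixed does not guarantee that the per-GPU batch fits in memory; a 4-GPU run at 512 images per GPU may need more memory than is available.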

Multi-node training

On a Slurm-managed cluster, multi-node training can be launched as

sbatch slurm_train.sh /path/to/imagenet SwiftFormer_XS

Note: specify Slurm-specific parameters in the slurm_train.sh script.

Testing

We provide an example test script dist_test.sh using PyTorch distributed data parallel (DDP). For example, to test SwiftFormer-XS on an 8-GPU machine:

sh dist_test.sh SwiftFormer_XS 8 weights/SwiftFormer_XS_ckpt.pth

Citation

If you use our work, please consider citing us:

@article{Shaker2023SwiftFormer,
  title={SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
  author={Shaker, Abdelrahman and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
  journal={arXiv preprint arXiv:2303.15446},
  year={2023}
}

Contact:

If you have any questions, please create an issue on this repository or contact [email protected].

Acknowledgement

Our codebase builds on the LeViT and EfficientFormer repositories. We thank the authors for their open-source implementations.

Our Related Works

  • EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications, CADL'22, ECCV. Paper | Code.

Contributors: amshaker
