GroupMixFormer

Official PyTorch implementation of GroupMixFormer.

🐱 Abstract

TL;DR:

We introduce GroupMixFormer, which employs Group-Mix Attention (GMA) as an advanced substitute for conventional self-attention. GMA is designed to concurrently capture correlations between tokens as well as between different groups of tokens, accommodating diverse group sizes.

Full abstract

Vision Transformers (ViTs) have been shown to enhance visual recognition by modeling long-range dependencies with multi-head self-attention (MHSA), which is typically formulated as a Query-Key-Value computation. However, the attention map generated from the Query and Key captures only token-to-token correlations at a single granularity. In this paper, we argue that self-attention should have a more comprehensive mechanism that captures correlations among tokens and groups (i.e., multiple adjacent tokens) for higher representational capacity. We therefore propose Group-Mix Attention (GMA) as an advanced replacement for traditional self-attention, which can simultaneously capture token-to-token, token-to-group, and group-to-group correlations with various group sizes. To this end, GMA uniformly splits the Query, Key, and Value into segments and performs different group aggregations to generate group proxies. The attention map is computed based on the mixtures of tokens and group proxies, and is used to re-combine the tokens and groups in the Value. Based on GMA, we introduce a powerful backbone, namely GroupMixFormer, which achieves state-of-the-art performance in image classification, object detection, and semantic segmentation with fewer parameters than existing models. For instance, GroupMixFormer-L (with 70.3M parameters and 384^2 input) attains 86.2% Top-1 accuracy on ImageNet-1K without external data, while GroupMixFormer-B (with 45.8M parameters) attains 51.2% mIoU on ADE20K.
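The mechanism described above can be sketched roughly as follows. This is a hedged, minimal illustration (single head, average pooling as the group aggregator, uniform channel split; the function name is mine), not the repository's actual implementation:

```python
# Minimal sketch of Group-Mix Attention (GMA). Illustrative only: the real
# GroupMixFormer code uses different aggregators and a multi-head layout.
import torch
import torch.nn.functional as F

def group_mix_attention(q, k, v, group_sizes=(1, 2, 4)):
    """q, k, v: (batch, tokens, dim); dim must be divisible by len(group_sizes)."""
    B, N, D = q.shape
    seg = D // len(group_sizes)
    outs = []
    for i, g in enumerate(group_sizes):
        # Take the i-th channel segment of Q, K, V.
        qs, ks, vs = (t[..., i * seg:(i + 1) * seg] for t in (q, k, v))
        if g > 1:
            # Aggregate g adjacent tokens into one group proxy.
            def pool(t):
                return F.avg_pool1d(t.transpose(1, 2), g, stride=g,
                                    ceil_mode=True).transpose(1, 2)
            qs, ks, vs = pool(qs), pool(ks), pool(vs)
        # Standard scaled dot-product attention at this granularity.
        attn = (qs @ ks.transpose(-2, -1)) / seg ** 0.5
        out = attn.softmax(dim=-1) @ vs
        if g > 1:
            # Broadcast group outputs back to token length.
            out = out.repeat_interleave(g, dim=1)[:, :N]
        outs.append(out)
    # Concatenating segments mixes token- and group-level correlations.
    return torch.cat(outs, dim=-1)
```

Each channel segment attends at its own granularity, so the concatenated output combines token-to-token and group-to-group correlations in a single attention layer.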

🚩 Updates

New features

  • ✅ Oct. 18, 2023. Release the training code.
  • ✅ Oct. 18, 2023. Release the inference code.
  • ✅ Oct. 18, 2023. Release the pretrained models for classification.

Catalog

  • ImageNet-1K Training Code
  • Downstream Transfer (Detection, Segmentation) Code

βš™οΈ Usage

1 - Installation

  • Create a new conda virtual environment:
conda create -n groupmixformer python=3.8 -y
conda activate groupmixformer
  • Install PyTorch >= 1.8.0 and torchvision >= 0.9.0, following the official instructions. For example:
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
  • Clone this repo and install required packages:
git clone https://github.com/AILab-CVC/GroupMixFormer.git
pip install timm==0.4.12 tensorboardX six tensorboard ipdb yacs tqdm fvcore
  • The results in the paper were produced with torch==1.8.0+cu111, torchvision==0.9.0+cu111, and timm==0.4.12.

  • Other dependencies: mmdetection and mmsegmentation are optional, needed only for downstream transfer.

2 - Data Preparation

Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is:

│path/to/imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
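The layout above can be sanity-checked before launching training. This is a small stdlib helper of my own (not part of the repo); the `.JPEG` extension matches the tree shown above:

```python
# Hypothetical helper to verify an ImageNet-style directory tree.
from pathlib import Path

def count_imagenet(root):
    """Return {split: (num_classes, num_images)} for an ImageNet-style tree."""
    stats = {}
    for split in ("train", "val"):
        class_dirs = [d for d in (Path(root) / split).iterdir() if d.is_dir()]
        n_images = sum(len(list(d.glob("*.JPEG"))) for d in class_dirs)
        stats[split] = (len(class_dirs), n_images)
    return stats

# For full ImageNet-1K, expect 1000 classes per split,
# ~1.28M train images and 50,000 val images.
```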

3 - Training Scripts

To train GroupMixFormer-Small on ImageNet-1K on a single node with 8 GPUs for 300 epochs, run:

python3 -m torch.distributed.launch --nproc_per_node 8 --nnodes 1 --use_env train.py \
  --data-path <Your data path> \
  --batch-size 64 \
  --output <Your target output path> \
  --cfg ./configs/groupmixformer_small.yaml \
  --model-type groupmixformer \
  --model-file groupmixformer.py \
  --tag groupmixformer_small

or you can simply run the following script:

bash launch_scripts/run_train.sh

For multi-node training, please refer to the code: multi_machine_start.py

4 - Inference Scripts

To evaluate GroupMixFormer-Small on ImageNet-1K on a single node, specify the path to the pretrained weights and run:

CUDA_VISIBLE_DEVICES=1 OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node 1 --nnodes 1 --use_env test.py \
  --data-path <Your data path> \
  --batch-size 64 \
  --output <Your target output path> \
  --cfg ./configs/groupmixformer_small.yaml \
  --model-type groupmixformer \
  --model-file groupmixformer.py \
  --tag groupmixformer_small

or you can simply run the following script:

bash launch_scripts/run_eval.sh

This should give

* Acc@1 83.400 Acc@5 96.464
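Acc@1 and Acc@5 are the standard top-k accuracies. A minimal sketch of the metric (the standard definition, not the repo's exact evaluation code):

```python
# Standard top-k accuracy: a prediction counts as correct if the target
# class appears among the k highest-scoring logits.
import torch

def topk_accuracy(logits, targets, ks=(1, 5)):
    maxk = max(ks)
    _, pred = logits.topk(maxk, dim=1)        # (N, maxk) predicted classes
    correct = pred.eq(targets.unsqueeze(1))   # (N, maxk) hit matrix
    return [correct[:, :k].any(dim=1).float().mean().item() * 100 for k in ks]
```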

⏬ Model Zoo

We provide GroupMixFormer models pretrained on ImageNet-1K (2012). Download the corresponding pretrained weights and move them to the ./pretrained folder.

name              resolution  acc@1  #params  FLOPs  download
GroupMixFormer-M  224x224     79.6   5.7M     1.4G   model - configs
GroupMixFormer-T  224x224     82.6   11.0M    3.7G   model - configs
GroupMixFormer-S  224x224     83.4   22.4M    5.2G   model - configs
GroupMixFormer-B  224x224     84.7   45.8M    17.6G  model - configs
GroupMixFormer-L  224x224     85.0   70.3M    36.1G  model - configs
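Loading a downloaded checkpoint can be sketched as below. The "model" wrapper key and the file name are assumptions (check test.py for the exact loading logic), and the nn.Linear stand-in is hypothetical: build the real model from the matching config instead.

```python
# Hedged checkpoint-loading sketch; the "model" key and file name are
# assumptions, not guaranteed by the repo.
import torch
import torch.nn as nn

def load_checkpoint(model, path):
    ckpt = torch.load(path, map_location="cpu")
    # Many training frameworks wrap the weights under a "model" key.
    state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
    missing, unexpected = model.load_state_dict(state, strict=False)
    return missing, unexpected

# Usage (stand-in module shown; point the path at ./pretrained/<name>.pth):
# load_checkpoint(model, "./pretrained/groupmixformer_small.pth")
```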

🤗 Acknowledgement

This repository is built using the timm library and the DeiT and Swin repositories.

πŸ—œοΈ License

This project is released under the MIT license. Please see the LICENSE file for more information.

📖 Citation

If you find this repository helpful, please consider citing:

@Article{xxx
}

