APT: Attention Prompt Tuning

A Parameter-Efficient Adaptation of Pre-Trained Models for Action Recognition ...

Wele Gedara Chaminda Bandara, Vishal M Patel
Johns Hopkins University

Accepted at FG'24

Paper (on arXiv)

Overview of Proposed Method

Comparison of our Attention Prompt Tuning (APT) for video action classification with other existing tuning methods: linear probing, adapter tuning, visual prompt tuning (VPT), and full fine-tuning.

Unlike VPT, which prepends prompts to the input token sequence, Attention Prompt Tuning (APT) injects learnable prompts directly into the multi-head attention (MHA) blocks.
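To make this concrete, here is a minimal PyTorch sketch of the idea (illustrative only, not the repository's exact implementation): learnable prompt tokens are prepended to the keys and values inside self-attention, so queries attend over both the original tokens and the prompts while the pre-trained weights stay frozen. All names and sizes below are illustrative.

import torch
import torch.nn as nn

class AttentionWithPrompts(nn.Module):
    # Illustrative sketch: self-attention whose keys/values are extended with
    # learnable prompts (the only new parameters); qkv/proj come from the
    # frozen pre-trained ViT.
    def __init__(self, dim=768, num_heads=12, num_prompts=2):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.prompt_k = nn.Parameter(torch.zeros(1, num_heads, num_prompts, self.head_dim))
        self.prompt_v = nn.Parameter(torch.zeros(1, num_heads, num_prompts, self.head_dim))

    def forward(self, x):                               # x: (B, N, C)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]                # each: (B, heads, N, head_dim)
        # Prepend learnable prompts to keys and values; queries are untouched,
        # so the output sequence length stays N.
        k = torch.cat([self.prompt_k.expand(B, -1, -1, -1), k], dim=2)
        v = torch.cat([self.prompt_v.expand(B, -1, -1, -1), v], dim=2)
        attn = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)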

Getting Started

Step 1: Conda Environment

Set up the conda environment using environment.yml:

conda env create -f environment.yml

Then activate the conda environment:

conda activate apt
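
Before moving on, you may want to confirm that PyTorch sees your GPUs (plain PyTorch, nothing repository-specific):

python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.cuda.device_count())"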

Step 2: Download the VideoMAE Pre-trained Models

We use VideoMAE pre-trained on the Kinetics-400 dataset for our experiments.

The pre-trained models for the ViT-Small and ViT-Base backbones can be downloaded from the links below:

Method Extra Data Backbone Epoch #Frame Pre-train
VideoMAE no ViT-S 1600 16x5x3 checkpoint
VideoMAE no ViT-B 1600 16x5x3 checkpoint

If you need other pre-trained models, please refer to MODEL_ZOO.md.
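
After downloading, a quick sanity check of the checkpoint can save a failed launch later. The snippet below assumes the weights are nested under a 'model' or 'module' key, which is the usual convention in the VideoMAE codebase; adjust the path to wherever you saved the file:

import torch

ckpt = torch.load("checkpoint.pth", map_location="cpu")
state = ckpt.get("model", ckpt.get("module", ckpt))  # weights are typically nested under one of these keys
print(len(state), "tensors; sample keys:", list(state)[:3])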

Step 3: Download the datasets

We conduct experiments on three action recognition datasets: 1) UCF101, 2) HMDB51, and 3) Something-Something-V2.

Please refer to DATASETS.md for the dataset links and pre-processing steps.

Step 4: Attention Prompt Tuning

We provide example scripts to run attention prompt tuning on the UCF101, HMDB51, and SSv2 datasets in the scripts/ folder.

Inside scripts/ you will find two folders, corresponding to APT fine-tuning with the ViT-Small and ViT-Base architectures.

To fine-tune with APT, you just need to execute the corresponding finetune.sh file, which launches the job with distributed training.

For example, to fine-tune ViT-Base on SSv2 with APT, you may run:

sh scripts/ssv2/vit_base/finetune.sh

The finetune.sh script looks like this:

# APT on SSv2
OUTPUT_DIR='experiments/APT/SSV2/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/adam_mome9e-1_wd1e-5_lr5se-2_pl2_ps0_pe11_drop10'
DATA_PATH='datasets/ss2/list_ssv2/'
MODEL_PATH='experiments/pretrain/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/checkpoint.pth'

NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \
    run_class_apt.py \
    --model vit_base_patch16_224 \
    --transfer_type prompt \
    --prompt_start 0 \
    --prompt_end 11 \
    --prompt_num_tokens 2 \
    --prompt_dropout 0.1 \
    --data_set SSV2 \
    --nb_classes 174 \
    --data_path ${DATA_PATH} \
    --finetune ${MODEL_PATH} \
    --log_dir ${OUTPUT_DIR} \
    --output_dir ${OUTPUT_DIR} \
    --batch_size 8 \
    --batch_size_val 8 \
    --num_sample 2 \
    --input_size 224 \
    --short_side_size 224 \
    --save_ckpt_freq 10 \
    --num_frames 16 \
    --opt adamw \
    --lr 0.05 \
    --weight_decay 0.00001 \
    --epochs 100 \
    --warmup_epochs 10 \
    --test_num_segment 2 \
    --test_num_crop 3 \
    --dist_eval \
    --pin_mem \
    --enable_deepspeed \
    --prompt_reparam \
    --is_aa \
    --aa rand-m4-n2-mstd0.2-inc1

Here,

  • OUTPUT_DIR: where you want to save the results (i.e., logs and checkpoints)
  • DATA_PATH: path to where the dataset is stored
  • MODEL_PATH: path to the downloaded VideoMAE pre-trained model
  • CUDA_VISIBLE_DEVICES=... specifies which GPUs (GPU ids) to use for fine-tuning
  • nproc_per_node is the number of GPUs used for fine-tuning
  • model is the backbone: ViT-Base (vit_base_patch16_224) or ViT-Small (vit_small_patch16_224)
  • transfer_type specifies which fine-tuning method to use: 'random' means random initialization, 'end2end' means full end-to-end fine-tuning, 'prompt' means APT (ours), and 'linear' means linear probing
  • prompt_start is the transformer block at which attention prompts start being added; 0 means learnable prompts are added from the 1st transformer block of the ViT
  • prompt_end is the transformer block at which attention prompts stop being added; ViT-Base / ViT-Small have 12 transformer blocks, so 11 means prompts are added up to the last block
  • data_set specifies the dataset
  • all the other parameters are hyperparameters of APT fine-tuning; a sketch of how the prompt options interact is given below
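
As a rough, self-contained illustration of how prompt_start, prompt_end, prompt_num_tokens, and transfer_type 'prompt' interact (hypothetical stand-in modules, not the repository's code):

import torch
import torch.nn as nn

class Block(nn.Module):                 # stand-in for one ViT transformer block
    def __init__(self, dim=768, num_prompts=0):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=12, batch_first=True)
        # Learnable attention prompts, only present in selected blocks.
        self.prompt = nn.Parameter(torch.zeros(1, num_prompts, dim)) if num_prompts else None

prompt_start, prompt_end, num_tokens = 0, 11, 2      # values from the script above
blocks = nn.ModuleList(
    Block(num_prompts=num_tokens if prompt_start <= i <= prompt_end else 0)
    for i in range(12)                               # ViT-B / ViT-S have 12 blocks
)
head = nn.Linear(768, 174)                           # SSv2 has 174 classes

# transfer_type 'prompt': freeze the backbone, train only prompts and head.
for name, p in blocks.named_parameters():
    p.requires_grad = "prompt" in name
n_trainable = sum(p.numel() for p in blocks.parameters() if p.requires_grad) \
            + sum(p.numel() for p in head.parameters())
print(f"{n_trainable} trainable parameters")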

✏️ Citation

If you find this project helpful, please feel free to leave a star and cite our paper:

@misc{bandara2024attention,
      title={Attention Prompt Tuning: Parameter-efficient Adaptation of Pre-trained Models for Spatiotemporal Modeling}, 
      author={Wele Gedara Chaminda Bandara and Vishal M. Patel},
      year={2024},
      eprint={2403.06978},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

✏️ Disclaimer

This repository is built on top of the VideoMAE codebase (https://github.com/MCG-NJU/VideoMAE), and we appreciate the authors of VideoMAE for making their code publicly available.
