Stochastic positional embeddings improve masked image modeling (ICML 2024)

This repository is the official implementation of StoP.

Introduction

StoP

Given a partial image of a dog, can you precisely determine the location of its tail? Existing Masked Image Modeling (MIM) models such as MAE and I-JEPA predict masked tokens deterministically and do not model location uncertainty (a). We propose to predict the targets (masked tokens) in stochastic positions (StoP), which prevents overfitting to location features. StoP leads to improved MIM performance on downstream tasks, including linear probing on ImageNet (b).
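As a rough illustration of the idea (not the repository's implementation; the function name and sigma value below are hypothetical), stochastic positions can be modeled by adding Gaussian noise to the positional embeddings of masked tokens before prediction:

```python
import numpy as np

def stochastic_positions(pos_embed, sigma=0.25, rng=None):
    """Sample noisy positional embeddings: pos + sigma * eps, eps ~ N(0, I).

    Conceptual sketch of StoP: the predictor sees a stochastic position
    for each masked token instead of its exact location.
    """
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.standard_normal(pos_embed.shape)
    return pos_embed + sigma * noise

# Example: 4 masked tokens with 8-dim positional embeddings.
pos = np.zeros((4, 8))
noisy = stochastic_positions(pos, sigma=0.25, rng=np.random.default_rng(0))
```

With `sigma=0` this reduces to the usual deterministic positional embedding, which is the baseline StoP is compared against.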

Installation

Please follow the installation instructions from the I-JEPA repo.

Usage

Pretraining on ImageNet

The commands below pretrain I-JEPA + StoP. In the original setting, all models were trained on 4 V100 GPU nodes. ViT-B/L were trained in float32, while ViT-H was trained in half precision (float16).

Training command:

torchrun --nnodes=4 --nproc-per-node=8 --node_rank=<node_rank 0-3> --master_addr=<master_addr> --master_port=<master_port> --backend=nccl main.py --fname configs/pretrain/vit-b16.yaml
  • Prior to running, set the path to ImageNet (image_folder, root_folder) in the config file.
  • To train ViT-L/ViT-H, change vit-b16 to vit-l16/vit-h16.
  • Set --node_rank to 0 on the first node. On the other nodes, run the same command with --node_rank=1,...,3 respectively. --master_addr is the IP of node 0.
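For example, the per-node launches for the first two of the four nodes look like this (the address and port below are placeholders):

```shell
# On node 0 (also the master):
torchrun --nnodes=4 --nproc-per-node=8 --node_rank=0 \
  --master_addr=10.0.0.1 --master_port=29500 --backend=nccl \
  main.py --fname configs/pretrain/vit-b16.yaml

# On node 1 (identical command; only --node_rank changes):
torchrun --nnodes=4 --nproc-per-node=8 --node_rank=1 \
  --master_addr=10.0.0.1 --master_port=29500 --backend=nccl \
  main.py --fname configs/pretrain/vit-b16.yaml
```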

Linear probing eval on ImageNet with 1% of labels (fast)

python logistic_eval.py \
  --subset-path imagenet_subsets1/1percent.txt \
  --root-path /path/to/datasets --image-folder imagenet_folder/ \
  --device cuda:0 \
  --pretrained /path/to/checkpoint/folder \
  --fname checkpoint_name.pth.tar \
  --model-name deit_base \
  --patch-size 16 \
  --penalty l2 \
  --lambd 0.0025
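Conceptually, this eval fits an L2-penalized logistic regression probe on frozen features. A self-contained sketch of the binary case on toy features (the actual script uses a dedicated solver and multi-class loss; all names and values here are illustrative):

```python
import numpy as np

def logistic_probe(X, y, lambd=0.0025, lr=0.1, steps=500):
    """Fit an L2-penalized logistic probe on frozen features by gradient descent."""
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w)))          # sigmoid predictions
        grad = X.T @ (p - y) / n + lambd * w        # logistic grad + L2 penalty
        w -= lr * grad
    return w

# Toy "frozen features": two linearly separable blobs.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-2, 1, (50, 4)), rng.normal(2, 1, (50, 4))])
y = np.concatenate([np.zeros(50), np.ones(50)])
w = logistic_probe(X, y)
acc = np.mean((X @ w > 0) == (y == 1))
```

The `--lambd` flag plays the role of `lambd` above: the strength of the L2 penalty on the probe weights.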

Linear probing classification eval using VISSL

To perform linear probing on ImageNet, you can follow the instructions from VISSL. Alternatively, we provide a bash script to convert checkpoints to the VISSL format and launch experiments on SLURM using 8 V100 machines, each with 8 GPUs:

bash bash/in1k_eval_vissl.sh <output_dir> <checkpoint_path> <dataset_root> <arch>
  • arch is vitb, vitl or vith.
  • checkpoint_path is the full path to the checkpoint saved during training.
  • dataset_root is the path to the ImageNet folder.

Pretrained Models Zoo

| Arch.    | Dataset  | Epochs | Checkpoint |
|----------|----------|--------|------------|
| ViT-B/16 | ImageNet | 600    | link       |
| ViT-L/16 | ImageNet | 600    | link       |
| ViT-H/16 | ImageNet | 300    | link       |

Differences compared to the official I-JEPA implementation

  • I-JEPA was trained using bfloat16, which is supported on newer NVIDIA GPUs (e.g., A100, H100) and helps stabilize training. We used older GPUs, hence float32 for the smaller ViT-B/L models and float16 for ViT-H.
  • For ViT-H training, we deviated from the cosine LR schedule after 250 epochs and continued training with a fixed low learning rate to push performance.
  • StoP relies on a previous internal implementation of I-JEPA that used additional image augmentations compared to the official repo. The main difference is the use of horizontal flips, Gaussian blur, and random grayscale during training. In the ablation experiments (Figures 3-4, Tables 5-6 & 8) we compare different positional embeddings using this same set of augmentations.
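The modified ViT-H schedule from the second point can be sketched as cosine decay that is simply frozen after epoch 250 (the base/final learning rates below are illustrative, not the repo's actual config values):

```python
import math

def lr_at(epoch, total=300, freeze_at=250, base_lr=1e-3, final_lr=1e-6):
    """Cosine LR decay, held constant after `freeze_at` epochs."""
    e = min(epoch, freeze_at)  # past freeze_at, reuse the frozen value
    cos = 0.5 * (1 + math.cos(math.pi * e / total))
    return final_lr + (base_lr - final_lr) * cos

lr_start = lr_at(0)    # base_lr
lr_250 = lr_at(250)    # value at the freeze point
lr_299 = lr_at(299)    # unchanged from epoch 250 onward
```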

Acknowledgments

The codebase relies on the implementation of I-JEPA.

Citation

If you found this code helpful, feel free to cite our work:

@inproceedings{barstochastic,
  title={Stochastic positional embeddings improve masked image modeling},
  author={Bar, Amir and Bordes, Florian and Shocher, Assaf and Assran, Mido and Vincent, Pascal and Ballas, Nicolas and Darrell, Trevor and Globerson, Amir and LeCun, Yann},
  booktitle={Forty-first International Conference on Machine Learning}
}
