Accept the Modality Gap: An Exploration in the Hyperbolic Space (CVPR 2024 Highlight) [Paper]
Sameera Ramasinghe, Violetta Shevchenko, Gil Avraham, Ajanthan Thalaiyasingam
Note: This codebase is heavily based on the MERU repo. We thank the MERU authors for their generous contributions.
git clone https://github.com/samgregoost/atmg.git
cd atmg
conda create -n atmg python=3.9 --yes
conda activate atmg
Install torch and torchvision following the instructions on pytorch.org. Then install the remaining dependencies and this codebase as a dev package:
python -m pip install --pre timm
python -m pip install -r requirements.txt
python setup.py develop
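Once the install finishes, a quick import check confirms the environment is usable; this assumes only that torch and torchvision installed cleanly:

```python
# Sanity check: confirm torch/torchvision import and report CUDA visibility.
import torch
import torchvision

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())
```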
The models are trained on the RedCaps dataset. Use the RedCaps downloader tool, following its instructions to download the dataset, and organize the result into TAR files.
The dataset is expected relative to the project repository as follows:
atmg # project repository.
└── datasets
└── redcaps
├── shard_00000000.tar
├── shard_00000001.tar
└── {any_name}.tar
This dataset format is general enough to support training with similar image-text datasets like Conceptual Captions (12M, 3M), if they are structured as TAR files described above.
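Each shard pairs same-named members (an image and a JSON caption record), following the webdataset-style convention used by MERU. The sketch below shows how such a shard could be packed with Python's tarfile module; the member naming (NNNNNNNN.jpg / NNNNNNNN.json) and the "caption" key are assumptions for illustration, so verify them against the dataloader in this repo:

```python
# Illustrative only: pack (image, caption) pairs into a webdataset-style TAR shard.
# The paired-member convention and key names are assumed; check the dataloader.
import io
import json
import os
import tarfile

pairs = [("cat.jpg", "an orange cat on a sofa")]  # (image path, caption) examples

os.makedirs("datasets/redcaps", exist_ok=True)
with tarfile.open("datasets/redcaps/shard_00000000.tar", "w") as tar:
    for idx, (image_path, caption) in enumerate(pairs):
        tar.add(image_path, arcname=f"{idx:08d}.jpg")
        payload = json.dumps({"caption": caption}).encode("utf-8")
        info = tarfile.TarInfo(name=f"{idx:08d}.json")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
```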
To train a ViT-small model, use the command below:
python scripts/train.py --config configs/train_atmg_vit_s.py --num-gpus 8 --output-dir ./output
To change other configurations, edit configs/train_atmg_vit_l.py. The ViT-small and ViT-base model configurations are inherited from the ViT-large config file.
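Config inheritance here is plain Python, so a derived config usually imports the base file and overrides a handful of fields. The following is a hypothetical sketch of that pattern only; the field names are placeholders, and the real attributes are defined in configs/train_atmg_vit_l.py:

```python
# Hypothetical derived config showing the import-and-override pattern.
# All field names below are placeholders; consult configs/train_atmg_vit_l.py
# for the attributes the training script actually reads.
from configs.train_atmg_vit_l import model, optim  # assumed exports

model.arch = "vit_small"  # placeholder: shrink the visual backbone
optim.lr = 5e-4           # placeholder: adjust the learning rate
```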
Our trained model checkpoints can be downloaded from the links below (a quick sanity check on a downloaded file is sketched after the list).
- Model: ATMG ViT-base and config: train_atmg_vit_b.py
- Model: ATMG ViT-small and config: train_atmg_vit_s.py
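Before launching anything heavy, a downloaded checkpoint can be probed directly with torch; the layout of the saved object depends on the training script, so treat this as a quick sanity check only:

```python
# Sanity probe of a downloaded checkpoint; the dict layout is training-script specific.
import torch

ckpt = torch.load("checkpoints/atmg_vit_s.pth", map_location="cpu")
if isinstance(ckpt, dict):
    print("top-level keys:", list(ckpt.keys())[:10])
```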
To perform image traversals, run the command below:
python scripts/image_traversals.py --image-path assets/taj_mahal.jpg \
--checkpoint-path checkpoints/atmg_vit_s.pth --train-config configs/train_atmg_vit_s.py
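The script processes a single image per invocation; to traverse a whole folder, a small driver loop works. The assets/*.jpg glob below is an assumption for illustration:

```python
# Hypothetical driver: invoke the traversal script for every .jpg under assets/.
import pathlib
import subprocess

for img in sorted(pathlib.Path("assets").glob("*.jpg")):
    subprocess.run(
        [
            "python", "scripts/image_traversals.py",
            "--image-path", str(img),
            "--checkpoint-path", "checkpoints/atmg_vit_s.pth",
            "--train-config", "configs/train_atmg_vit_s.py",
        ],
        check=True,
    )
```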
This codebase supports evaluation on 22 datasets across three types of tasks. See the instructions below.
- Zero-shot image classification:
Download and symlink the ImageNet dataset (Torchvision ImageFolder style) at ./datasets/eval/imagenet; a Python symlink sketch follows the evaluation command below. The evaluation script will auto-download and cache the other 19 datasets in ./datasets/eval.
Run the following command to evaluate ATMG ViT-small on 20 datasets:
python scripts/evaluate.py --config configs/eval_zero_shot_classification.py \
--checkpoint-path checkpoints/atmg_vit_s.pth \
--train-config configs/train_atmg_vit_s.py
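As noted above, the ImageNet symlink can be created in one line; the source path /data/imagenet is an assumption, so point it at your local ImageFolder-style copy:

```python
# One-off setup: link a local ImageNet copy (ImageFolder layout) to the expected path.
import os

os.makedirs("datasets/eval", exist_ok=True)
os.symlink("/data/imagenet", "datasets/eval/imagenet")  # adjust the source path
```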
- Zero-shot image and text retrieval:
Two datasets are supported: COCO captions and Flickr30k captions. Arrange their files in ./datasets/coco and ./datasets/flickr30k.
python scripts/evaluate.py --config configs/eval_zero_shot_retrieval.py \
--checkpoint-path checkpoints/atmg_vit_s.pth \
--train-config configs/train_atmg_vit_s.py
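Retrieval on these benchmarks is conventionally reported as Recall@K. For reference, a minimal Recall@K over a similarity matrix looks like the sketch below; it assumes a one-to-one query-to-gallery ground truth and is not the repo's evaluation code:

```python
# Conceptual Recall@K over a [num_queries, num_gallery] similarity matrix,
# assuming query i's correct gallery item sits at index i.
import torch

def recall_at_k(sim: torch.Tensor, k: int) -> float:
    topk = sim.topk(k, dim=1).indices                 # [num_queries, k]
    targets = torch.arange(sim.size(0)).unsqueeze(1)  # [num_queries, 1]
    return (topk == targets).any(dim=1).float().mean().item()
```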
If you find our work useful, please consider citing it as follows.
@inproceedings{ramasinghe2024accept,
title={Accept the modality gap: An exploration in the hyperbolic space},
author={Ramasinghe, Sameera and Shevchenko, Violetta and Avraham, Gil and Thalaiyasingam, Ajanthan},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={27263--27272},
year={2024}
}