

Accept the Modality Gap: An Exploration in the Hyperbolic Space (CVPR 2024 Highlight) [Paper]

Sameera Ramasinghe, Violetta Shevchenko, Gil Avraham, Ajanthan Thalaiyasingam

Note: This codebase is heavily based on the MERU repo. We thank the MERU authors for their generous contributions.

Installation

git clone https://github.com/samgregoost/atmg.git
cd atmg
conda create -n atmg python=3.9 --yes
conda activate atmg

Install torch and torchvision following the instructions on pytorch.org. Then install the remaining dependencies and this codebase as a dev package:

python -m pip install --pre timm
python -m pip install -r requirements.txt
python setup.py develop
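To confirm that the steps above took effect, you can query the installed distribution versions. This is a minimal stdlib sketch. It checks only the three packages named in this section (the full list lives in requirements.txt), and the helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the installation steps above; requirements.txt
# may list further packages that are not checked here.
REQUIRED = ["torch", "torchvision", "timm"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(REQUIRED).items():
        print(f"{name}: {ver if ver is not None else 'NOT INSTALLED'}")
```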

Training data

The models are trained on the RedCaps dataset. Use the RedCaps downloader tool and follow its instructions to download the dataset and organize it as TAR files.

The dataset is expected relative to the project repository as follows:

atmg  # project repository.
└── datasets
    └── redcaps
        ├── shard_00000000.tar
        ├── shard_00000001.tar
        └── {any_name}.tar

This dataset format is general enough to support training on similar image-text datasets, such as Conceptual Captions (12M, 3M), provided they are structured as TAR files as described above.
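Before launching training, it can help to sanity-check the shards. The sketch below uses only Python's `tarfile` module. It assumes, hypothetically, that the files belonging to one image-text pair share a basename inside each TAR (for example `123.jpg` and `123.json`). The actual member names depend on the RedCaps downloader, so adjust the grouping if your shards differ. `list_shards` and `group_samples` are illustrative names, not part of this codebase.

```python
import tarfile
from collections import defaultdict
from pathlib import Path


def list_shards(root="datasets/redcaps"):
    """Return the sorted TAR shards under `root` (any *.tar name is accepted)."""
    return sorted(Path(root).glob("*.tar"))


def group_samples(shard_path):
    """Group the files of one shard by basename, e.g. {'123': {'jpg', 'json'}}.

    Assumes (hypothetically) that all files of one image-text pair share a
    basename and differ only in their extension.
    """
    samples = defaultdict(set)
    with tarfile.open(shard_path) as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue  # skip directories, links, etc.
            stem, _, ext = Path(member.name).name.partition(".")
            samples[stem].add(ext)
    return dict(samples)


if __name__ == "__main__":
    for shard in list_shards():
        print(f"{shard.name}: {len(group_samples(shard))} samples")
```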

Training the models

To train a ViT-small model, use the command below:

python scripts/train.py --config configs/train_atmg_vit_s.py --num-gpus 8 --output-dir ./output

To change other configuration options, edit configs/train_atmg_vit_l.py: the ViT-small and ViT-base configurations inherit from the ViT-large config file.
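This codebase's config loader is not shown here, so the following is only a generic, hypothetical illustration of the inherit-then-override pattern described above. A child config copies the base and replaces a few keys, so any setting it does not override follows the base file. The keys and values are placeholders, not the repo's actual settings.

```python
import copy


def deep_update(base, overrides):
    """Return a copy of `base` with the nested `overrides` applied recursively."""
    merged = copy.deepcopy(base)  # never mutate the base config
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_update(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical ViT-large base config; placeholder keys and values only.
VIT_L = {
    "model": {"visual_arch": "vit_large", "embed_dim": 512},
    "train": {"seed": 0},
}

# ViT-small inherits everything and overrides only the backbone name, so a
# change to VIT_L["train"] would also apply to VIT_S.
VIT_S = deep_update(VIT_L, {"model": {"visual_arch": "vit_small"}})
```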

Pretrained checkpoints

Our trained model checkpoints can be downloaded from the links below.

Image traversals

To perform image traversals, run the command below:

python scripts/image_traversals.py --image-path assets/taj_mahal.jpg \
    --checkpoint-path checkpoints/atmg_vit_s.pth --train-config configs/train_atmg_vit_s.py
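The traversal script itself is not reproduced here. The following is a hedged, self-contained sketch of the underlying idea. It assumes (as in MERU, on which this codebase builds) that embeddings live on the Lorentz model of hyperbolic space with curvature -c and are stored by their space components, and that a traversal walks along the geodesic from an image embedding toward the origin, retrieving the closest text embedding at each step. All function names and the default curvature are hypothetical.

```python
import math


def exp_map0(v, curv=1.0):
    """Map a tangent vector at the origin onto the hyperboloid (space components)."""
    norm = math.sqrt(sum(a * a for a in v))
    if norm == 0.0:
        return [0.0] * len(v)
    scale = math.sinh(math.sqrt(curv) * norm) / (math.sqrt(curv) * norm)
    return [scale * a for a in v]


def log_map0(x, curv=1.0):
    """Inverse of exp_map0: map a hyperboloid point to the origin's tangent space."""
    norm = math.sqrt(sum(a * a for a in x))
    if norm == 0.0:
        return [0.0] * len(x)
    scale = math.asinh(math.sqrt(curv) * norm) / (math.sqrt(curv) * norm)
    return [scale * a for a in x]


def lorentz_dist(x, y, curv=1.0):
    """Geodesic distance between two points given by their space components."""
    x_time = math.sqrt(1.0 / curv + sum(a * a for a in x))
    y_time = math.sqrt(1.0 / curv + sum(b * b for b in y))
    inner = sum(a * b for a, b in zip(x, y)) - x_time * y_time  # Lorentzian inner product
    # Clamp at 1 so rounding error cannot push the argument below arccosh's domain.
    return math.acosh(max(-curv * inner, 1.0)) / math.sqrt(curv)


def traverse(image_emb, text_embs, steps=5, curv=1.0):
    """Walk from the image embedding to the origin (steps >= 2) and return, for
    each step, the index of the nearest text embedding."""
    tangent = log_map0(image_emb, curv)
    path = []
    for i in range(steps):
        t = i / (steps - 1)  # t = 0 is the image, t = 1 is the origin
        point = exp_map0([(1.0 - t) * a for a in tangent], curv)
        nearest = min(range(len(text_embs)),
                      key=lambda j: lorentz_dist(point, text_embs[j], curv))
        path.append(nearest)
    return path
```

Scaling the tangent vector and mapping back with `exp_map0` keeps every intermediate point on the geodesic through the origin, so the distance to the origin shrinks linearly from the image's tangent norm to zero.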

Evaluate trained models

This codebase supports evaluation on 22 datasets on three types of tasks. See the instructions below.

  1. Zero-shot image classification:

Download the ImageNet dataset (Torchvision ImageFolder style) and symlink it at ./datasets/eval/imagenet. The evaluation script will automatically download and cache the other 19 datasets in ./datasets/eval. Run the following command to evaluate ATMG ViT-small on all 20 datasets:

python scripts/evaluate.py --config configs/eval_zero_shot_classification.py \
    --checkpoint-path checkpoints/atmg_vit_s.pth \
    --train-config configs/train_atmg_vit_s.py

  2. Zero-shot image and text retrieval:

Two datasets are supported: COCO captions and Flickr30k captions. Arrange their files in ./datasets/coco and ./datasets/flickr30k, respectively.

python scripts/evaluate.py --config configs/eval_zero_shot_retrieval.py \
    --checkpoint-path checkpoints/atmg_vit_s.pth \
    --train-config configs/train_atmg_vit_s.py
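For the zero-shot classification step, which asks for ImageNet to be symlinked at ./datasets/eval/imagenet, a minimal `pathlib` sketch follows. The source is wherever your ImageNet copy lives. Whether the evaluation expects `train/` and `val/` split folders underneath is an assumption to check against configs/eval_zero_shot_classification.py, and the helper names are hypothetical.

```python
from pathlib import Path


def link_imagenet(src, dst="datasets/eval/imagenet"):
    """Symlink an existing ImageNet copy to the location the evaluation expects."""
    src, dst = Path(src).resolve(), Path(dst)
    dst.parent.mkdir(parents=True, exist_ok=True)
    # Leave an existing file or link (even a dangling one) untouched.
    if not (dst.exists() or dst.is_symlink()):
        dst.symlink_to(src, target_is_directory=True)
    return dst


def imagefolder_classes(split_dir):
    """List class subdirectories in the sorted order ImageFolder uses for labels."""
    return sorted(p.name for p in Path(split_dir).iterdir() if p.is_dir())
```

For example, `link_imagenet("/path/to/imagenet")` (a placeholder path) followed by `imagefolder_classes("datasets/eval/imagenet/val")` should list the class folders in the sorted order that torchvision's ImageFolder uses to assign label indices.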

If you find our work useful, please consider citing it as follows:

@inproceedings{ramasinghe2024accept,
  title={Accept the modality gap: An exploration in the hyperbolic space},
  author={Ramasinghe, Sameera and Shevchenko, Violetta and Avraham, Gil and Thalaiyasingam, Ajanthan},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={27263--27272},
  year={2024}
}
