clip-training's Introduction

CLIP training

This repository contains code to train CLIP on MS-COCO captions. It can be easily modified to train on other multi-modal datasets (OpenImages, Conceptual Captions, ...).

Requirements

To set up the environment:

# create new env clip_train
$ conda create -n clip_train python=3.8.5

# activate clip_train
$ conda activate clip_train

# install pytorch, torchvision
$ conda install pytorch==1.7.0 torchvision==0.8.0 cudatoolkit=10.2 -c pytorch

# install other dependencies
$ pip install -r requirements.txt

Training

Preparing training dataset

The MS-COCO training set images and their captions are used to train the CLIP model. To download the dataset:

# create directory in data/ (-p also creates data/ if it does not exist)
$ mkdir -p data/mscoco

# download images
$ wget http://images.cocodataset.org/zips/train2017.zip -O data/mscoco/train2017.zip
$ unzip data/mscoco/train2017.zip -d data/mscoco


# download annotations 
$ wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -O data/mscoco/annotations_trainval2017.zip
$ unzip data/mscoco/annotations_trainval2017.zip -d data/mscoco

To check or update the training parameters, model config and dataset paths, see the following config files:

trainer/train_config.yaml   # training parameters
model/model_config.yaml     # CLIP model config
dataloader/data_config.yaml # training dataset path

To train :

To take the dataset paths from 'dataloader/data_config.yaml':

$ python train.py 

Or pass the dataset paths as command-line arguments:

$ python train.py --train_img_dir <path to training images directory> --train_annotation_file <path to annotation file>
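The two ways of supplying dataset paths can be pictured as a config default that command-line flags override. The following is only a minimal sketch of that pattern, not the repository's actual train.py: the flag names come from the command above, while `CONFIG_PATHS`, its keys, its default paths and `resolve_dataset_paths` are hypothetical names introduced for illustration.

```python
import argparse

# Hypothetical stand-in for values read from dataloader/data_config.yaml.
# The key names and default paths are assumptions made for this sketch.
CONFIG_PATHS = {
    "train_img_dir": "data/mscoco/train2017",
    "train_annotation_file": "data/mscoco/annotations/captions_train2017.json",
}

PATH_KEYS = ("train_img_dir", "train_annotation_file")


def resolve_dataset_paths(argv, config_paths):
    """Return dataset paths, letting command-line flags override the config."""
    parser = argparse.ArgumentParser(description="Resolve training dataset paths")
    for key in PATH_KEYS:
        # default=None lets us tell "flag not given" apart from a real value.
        parser.add_argument(f"--{key}", default=None)
    args = parser.parse_args(argv)

    resolved = dict(config_paths)  # copy, so the config itself is not mutated
    for key in PATH_KEYS:
        cli_value = getattr(args, key)
        if cli_value is not None:
            resolved[key] = cli_value
    return resolved


if __name__ == "__main__":
    # Only the image directory is overridden; the annotation file
    # falls back to the config value.
    print(resolve_dataset_paths(["--train_img_dir", "my_images"], CONFIG_PATHS))
```

Keeping the defaults in the config and treating flags as optional overrides means that a plain `python train.py` still works unchanged.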

Training setting :

  • Model config : Since MS-COCO is a relatively small dataset, I used ResNet50 as the image encoder instead of a Vision Transformer. I also reduced the number of transformer layers in the text encoder to 6. Detailed model config is here : model_config.yaml

  • Batch size : 256. I trained on 4 GTX1080 GPUs (batch size of 64 per GPU).

  • Optimizer : Adam optimizer with weight decay.

  • Scheduler : Cosine Scheduler with warmup for first 20% of gradient update steps. Detailed training config is here : train_config.yaml

  • Temperature parameter clipping : Added temperature clipping as mentioned in the paper for training stability. The learnable temperature parameter is clipped to prevent scaling the logits by more than 100.
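The scheduler and temperature-clipping settings above can be sketched in plain Python. This is an illustration under stated assumptions, not the repository's code: it assumes a linear warmup, a cosine decay down to zero, and a temperature stored in log space (as the CLIP paper describes). The function names, `base_lr` and `MAX_LOGIT_SCALE` are hypothetical.

```python
import math

MAX_LOGIT_SCALE = 100.0  # logits are never multiplied by more than this


def lr_at_step(step, total_steps, base_lr, warmup_fraction=0.2):
    """Learning rate with linear warmup followed by cosine decay.

    Assumes warmup is linear and the cosine phase decays to zero.
    """
    warmup_steps = int(round(warmup_fraction * total_steps))
    if step < warmup_steps:
        # Ramp linearly up to base_lr over the warmup steps.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the post-warmup phase completed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


def clip_log_logit_scale(log_scale, max_scale=MAX_LOGIT_SCALE):
    """Clamp a log-parameterised logit scale so exp(log_scale) <= max_scale."""
    return min(log_scale, math.log(max_scale))


if __name__ == "__main__":
    for step in (0, 19, 20, 60, 100):
        print(step, lr_at_step(step, total_steps=100, base_lr=1e-3))
    # A temperature of 0.07 corresponds to a scale of 1 / 0.07, below the cap.
    print(math.exp(clip_log_logit_scale(math.log(1 / 0.07))))
    # A value that has drifted too high is pulled back to the cap.
    print(math.exp(clip_log_logit_scale(10.0)))
```

In a PyTorch training loop the same clamp would typically be applied to the learnable parameter after each optimizer step, so that the scale cannot grow without bound and destabilise the softmax.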

Zero-shot classification :

For zero-shot classification, all class names are first converted into sentences using templates (such as "a photo of a {class name}"), and their text embeddings are computed using CLIP. To classify an image, its image embedding is computed using CLIP and compared with every class sentence embedding by cosine similarity; the predicted class is the one with the highest cosine similarity.
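The procedure can be sketched with toy vectors standing in for CLIP embeddings. This is a minimal illustration, not the repository's zero_shot_demo.py: the function names are hypothetical, and the two-dimensional embeddings are made up so the example runs without a trained model.

```python
import math


def build_prompts(class_names, template="a photo of a {}"):
    """Turn class names into sentences using a prompt template."""
    return [template.format(name) for name in class_names]


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def zero_shot_classify(image_embedding, class_text_embeddings):
    """Return the class whose text embedding is most similar to the image.

    class_text_embeddings maps each class name to the embedding of its
    prompt sentence (in practice produced by the CLIP text encoder).
    """
    return max(
        class_text_embeddings,
        key=lambda name: cosine_similarity(image_embedding, class_text_embeddings[name]),
    )


if __name__ == "__main__":
    print(build_prompts(["bicycle", "apple"]))
    # Toy embeddings, made up for illustration only.
    text_embeddings = {"bicycle": [1.0, 0.0], "apple": [0.0, 1.0]}
    image_embedding = [0.9, 0.1]
    print(zero_shot_classify(image_embedding, text_embeddings))
```

Because the class embeddings depend only on the class names, they can be computed once and reused for every image.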

Zero-shot demo :

Trained weights :

  • Download trained checkpoint from google drive : link
  • Or use gdown to download it :
    # first install gdown
    $ pip install gdown
    
    # then download trained weights to 'saved_checkpoints/trained_checkpoint.pt'
    $ mkdir saved_checkpoints
    $ gdown --id 1BVEY4WeFmQb3wv0A6RaLyVjnc7qmChH2 -O saved_checkpoints/trained_checkpoint.pt  
    

To classify image(s) into CIFAR100 classes, run the following:

# to classify a single image
$ python zero_shot_demo.py --checkpoint_path <path_to_trained_checkpoint.pt> --img_path <path_to_img.jpg> --show_prediction

# to classify all images in a directory
$ python zero_shot_demo.py --checkpoint_path <path_to_trained_checkpoint.pt> --img_dir <path_to_img_directory> --show_prediction

# the --show_prediction flag saves a prediction figure with class probabilities
# NOTE : Please put an even number of images in img_directory to get a nice prediction figure

Example to run zero-shot demo:

# first put trained weights at saved_checkpoints/trained_checkpoint.pt 

# for single image
$ python zero_shot_demo.py --checkpoint_path saved_checkpoints/trained_checkpoint.pt --img_path test_images/bicycle.jpeg --show_prediction

# for an image directory
$ python zero_shot_demo.py --checkpoint_path saved_checkpoints/trained_checkpoint.pt --img_dir test_images --show_prediction

# view prediction figure in "demo_output" directory

Zero-shot evaluation on vision datasets + observations :

For evaluation results and instructions on how to run eval code, check this : Observations and Eval results

clip-training's People

Contributors

revantteotia

clip-training's Issues

Load ViT Model in clip-training

Hello, @revantteotia ~
I would like to use your great code to train on my dataset, but I want to load the ViT-B/32 model. I see you have finished the VisualTransformer code, but I have a small question about the vision_patch_size parameter.
I have read other ViT model code that sets patch_size = 32 when the image resolution is 224, so I want to ask whether 32 is the right value in model_config.yaml if my image_resolution is 224.

Logging capabilities for Training & Evaluation

Hi, Thanks for the great codebase for training CLIP from scratch. It really helps with understanding how it works.

I was wondering if there are any facilities to log and plot the training and evaluation losses? Something like Tensorboard or Weights&Biases?

Thanks again.

License

Hi.

Considering that the original CLIP source code has no training code, this repo can be a very valuable resource.

Would you consider adding a license? (I would suggest the MIT license, just because it's the license of CLIP).

COCO Image Retrieval

Hi, First of all, Thanks for the great code base, it really helped me understand CLIP.

I was wondering if there is COCO Image Retrieval code available to test how well the extracted image features can find similar images. Writing the code would be easy, but a standard implementation that everyone uses to test the Image Retrieval capacity of CLIP's image features would be really good.

The benchmarks at PapersWithCode each do it in their own way and I was wondering if there was a better and easier way to do it.

Thanks again
