
guillaumeerhard / supervised_contrastive_loss_pytorch


Independent implementation of Supervised Contrastive Loss. Straight to the point and beyond

License: MIT License

Language: Python (100 %)
Topics: deep-learning, contrastive-learning, image-classification, cifar10

supervised_contrastive_loss_pytorch's Introduction

Supervised Contrastive Loss PyTorch

[Figure: SimCLR illustration]

This is an independent reimplementation of the Supervised Contrastive Learning paper.
Go here for an implementation from one of the authors in PyTorch, and here for the official implementation in TensorFlow. The goal of this repository is to provide a straight-to-the-point implementation and experiments that answer specific questions.

Results

Accuracy (%) on CIFAR-10

| Architecture | Cross-entropy | Cross-entropy + AutoAugment | SupContrast + AutoAugment |
|--------------|---------------|-----------------------------|---------------------------|
| ResNet-20    | 91.25 *       | 92.71                       | 93.51                     |

Accuracy (%) on CIFAR-100

| Architecture | Cross-entropy | Cross-entropy + AutoAugment | SupContrast + AutoAugment |
|--------------|---------------|-----------------------------|---------------------------|
| ResNet-20    | 66.05         | 66.28                       | 68.42                     |

How to use it

Installation

After creating a virtual environment with your software of choice, just run:

pip install -r requirements.txt

Usage

Running the following command will list the available script options. The default values will help you replicate my results:

python train.py -h

Some insights gathered

Some claims from the paper:

Does a contrastive epoch take 50 % more time than a cross-entropy one?
Yes, this claim is in line with what I observed in both my implementation and the official one.

Is heavy data augmentation necessary?
It seems like it. A run without hyper-parameter tuning and without AutoAugment, using only the data augmentation from the original ResNet paper, yielded a 5 % drop in accuracy compared to cross-entropy. That said, the other augmentation policies tested in the paper are close behind AutoAugment, so contrastive approaches do not seem to depend on one specific sophisticated augmentation strategy; see the original SimCLR paper.
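
To make the comparison concrete, here is a minimal sketch (not the repo's exact code) of the two pipelines discussed above, using torchvision: the basic ResNet-paper CIFAR augmentation versus the same pipeline with the AutoAugment CIFAR-10 policy added. The normalization statistics are the commonly used CIFAR-10 values, an assumption rather than something taken from the repo.

```python
from torchvision import transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy

# Commonly used CIFAR-10 channel statistics (assumption, not taken from the repo).
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2470, 0.2435, 0.2616))

# Basic augmentation from the original ResNet paper: pad-and-crop + horizontal flip.
basic_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Heavier pipeline: same base augmentation plus the AutoAugment CIFAR-10 policy.
auto_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    AutoAugment(policy=AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    normalize,
])
```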

Do you need only a few epochs to train the decoder on the embedding?
Yes, definitely. Only 1-2 epochs of cross-entropy training on the embedding gave a model close to the best accuracy. Better configurations were found after tens of epochs, but the improvement was usually only in the 0.1 % accuracy range.
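
A minimal sketch of this step, assuming a hypothetical frozen `encoder` plus `embedding_dim`, `num_classes` and `train_loader` placeholders (not the repo's actual names):

```python
import torch
import torch.nn as nn

# Freeze the contrastively trained encoder and train only a classification head.
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False

head = nn.Linear(embedding_dim, num_classes)
optimizer = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()

for epoch in range(2):                    # 1-2 epochs are usually enough here
    for images, labels in train_loader:
        with torch.no_grad():
            features = encoder(images)    # embeddings, no gradient through the encoder
        loss = criterion(head(features), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```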

Some findings and personal notes

How many contrastive epochs are needed?
The number of epochs necessary to get a good embedding after the contrastive step is higher than for regular cross-entropy training. I ran 400-500 epochs, while the default in the official GitHub repo is 1000 epochs and the paper mentions 700 epochs for ILSVRC-2012. For my tests with cross-entropy it was at most 700 epochs.

Why does the loss never reach zero?
The supervised contrastive loss defined in the paper converges to a constant value, which is batch-size dependent.
The loss as described in the paper is analogous to the Tammes problem: each cluster, where the projections of a particular class land, repels the other clusters. Although the problem is unsolved for a dimension as high as 128, an approximate solution over the dataset statistics can easily be calculated. Doing this for the random label configuration of every batch could be computationally intensive, but that could be avoided with a sampler that always returns the same label configuration. I suspect this might be an easy avenue to reduce the number of epochs needed before convergence.
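
For reference, here is a minimal, unoptimized sketch of the supervised contrastive loss as defined in the paper (not the repo's exact code). Every anchor's denominator contains all other samples in the batch and the log-probabilities are averaged over its positives, which is why the attainable minimum depends on the batch composition rather than being zero.

```python
import torch

def supcon_loss(features, labels, temperature=0.1):
    """Supervised contrastive loss, minimal sketch.

    features: (N, D) L2-normalized projections of one batch (all views stacked).
    labels:   (N,) integer class labels, repeated once per view.
    """
    n = features.shape[0]

    # Pairwise similarities scaled by the temperature.
    logits = features @ features.T / temperature

    # Exclude self-similarity from every denominator.
    self_mask = torch.eye(n, dtype=torch.bool, device=features.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # Positives: same label as the anchor, excluding the anchor itself.
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask

    # Log-probability of each pair against all other samples in the batch.
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

    # Average over positives per anchor, then over anchors.
    mean_log_prob_pos = (log_prob.masked_fill(~pos_mask, 0.0).sum(1)
                         / pos_mask.sum(1).clamp(min=1))
    return -mean_log_prob_pos.mean()
```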

Will this work for very small networks?
This approach also seems to work on small networks, which is one of the additions of this repo. As you can see in the ResNet-20 results above, this approach beat cross-entropy with a model of only 0.3 M parameters, drastically fewer than the 20+ M of the ResNet-50 used on ILSVRC-2012 and in the official GitHub repo.

Would I recommend using this approach for your specific task? And will I use it?
One thing I do like, and the main selling point of this technique, is that it trades the tedious process of hyper-parameter tuning for computation. Every result presented here needed only one training attempt: you just decrease the learning rate along the way (see the scheduler sketch below), whereas with cross-entropy I had to rerun the experiment on average three times with different learning-rate strategies to get the best results shown.
The other thing that emerges from this paper is that this method seems to be one of the best in a tabula rasa setting, i.e. training from scratch, although you can also look at GradAug, CutMix or Bag of Tricks. So it might be a great fit when you are dealing with non-standard images, i.e. no ILSVRC-2012-like dataset is available to pretrain on and it is also difficult to collect a ton of unlabelled data. If you can gather a lot of unlabelled data, you might get better results with a self-supervised or semi-supervised approach like SimCLRv2 or BYOL, but I guess if you are here you already know about them.
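
For the learning-rate decrease mentioned above, here is a minimal sketch of a step-decay schedule in PyTorch; the optimizer settings, milestones and the `model` / `train_one_epoch` names are hypothetical placeholders, not the repo's configuration.

```python
import torch

# Hypothetical optimizer over some model's parameters; values are illustrative only.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Decrease the learning rate by 10x at a few fixed epochs along the way.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[250, 350, 450], gamma=0.1)

for epoch in range(500):
    train_one_epoch(model, optimizer)  # hypothetical training loop
    scheduler.step()
```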


supervised_contrastive_loss_pytorch's Issues

Two different Data augmentations?

Hi,
According to the theory, in stage 1, two different types of augmentations should be applied to a batch of samples.

  1. crop and resize
  2. AutoAugment / RandAugment / SimAugment

It seems the code doesn't include augmentation number (2). Are the reported results from only augmentation number (1), applied twice to the same set of samples?

Thanks
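
For context on the question above: a common pattern in SimCLR/SupCon-style training is to wrap the base augmentation so that every sample yields two independently augmented views, which are then stacked into one batch. A minimal sketch of that pattern, not necessarily how this repository handles it:

```python
class TwoCropTransform:
    """Return two independently augmented views of the same image."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # Calling the stochastic transform twice gives two different views.
        return [self.base_transform(x), self.base_transform(x)]

# Usage sketch: wrap any augmentation pipeline (e.g. the AutoAugment one above),
# then concatenate the two views along the batch dimension and repeat the labels
# before computing the supervised contrastive loss.
```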

Low Val Accuracy on CIFAR 100

Do you mind sharing the parameters you used for training on CIFAR-100? For me, validation accuracy seems to saturate at 40 % (using ResNet-20, though), even with Auto Augment=True. I am interested in the learning rate, batch size, ResNet model and temperature you used for CIFAR-100.
