junya-chen / flatclr

FlatNCE: A Novel Contrastive Representation Learning Objective

Home Page: https://arxiv.org/pdf/2107.01152.pdf

This is the official code repository for the paper Simpler, Faster, Stronger: Breaking The log-K Curse On Contrastive Learners With FlatNCE.

InfoNCE-based contrastive representation learners, such as SimCLR, have been tremendously successful in recent years. However, these contrastive schemes are notoriously resource-demanding, as their effectiveness breaks down with small-batch training (the log-K curse, where K is the batch size). In this work, we reveal mathematically why contrastive learners fail in the small-batch regime, and present FlatNCE, a simple yet non-trivial contrastive objective that fixes this issue. Unlike InfoNCE, FlatNCE no longer explicitly appeals to a discriminative classification goal for contrastive learning. Theoretically, we show that FlatNCE is the mathematical dual formulation of InfoNCE, thus bridging to the classical literature on energy modeling; empirically, we demonstrate that, with minimal code modification, FlatNCE yields an immediate performance boost independent of subject-matter engineering efforts. The significance of this work is furthered by the powerful generalization of contrastive learning techniques and the introduction of new tools to monitor and diagnose contrastive training. We substantiate our claims with empirical evidence on CIFAR10, ImageNet, and other datasets, where FlatNCE consistently outperforms InfoNCE.
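For context, the log-K curse is usually stated through the standard InfoNCE bound. With a batch supplying one positive y⁺ and K − 1 negatives (y_1, …, y_K includes y⁺), and a critic score f(x, y), the notation below is ours rather than the repository's:

```latex
\mathcal{L}_{\mathrm{InfoNCE}}
  = -\,\mathbb{E}\left[\log \frac{e^{f(x,\,y^{+})}}{\sum_{k=1}^{K} e^{f(x,\,y_k)}}\right] \;\ge\; 0,
\qquad
I(X;Y) \;\ge\; \log K - \mathcal{L}_{\mathrm{InfoNCE}} .
```

Because the loss is non-negative, the estimate log K − L_InfoNCE can never exceed log K, so with a small batch (small K) the bound saturates early.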

Usage

To train on ImageNet (or any other dataset), first download and decompress the dataset, then place it under ./datasets (e.g., ./datasets/imagenet for ImageNet).

Pretraining

Two data-augmentation pipelines are available: a faster version and a slower version (the original SimCLR implementation). The faster version only supports CIFAR-10 and CIFAR-100.

To pretrain SimCLR on CIFAR-10 with the faster augmentation, run:

python main.py --dataset_name=cifar10 --clr=simclr --faster_version=True

To pretrain FlatCLR on CIFAR-10 with the normal (slower) augmentation, run:

python main.py --dataset_name=cifar10 --clr=flatclr --faster_version=False

To pretrain SimCLR on ImageNet, run:

python main.py --dataset_name=imagenet --clr=simclr

To pretrain FlatCLR on ImageNet, run:

python main.py --dataset_name=imagenet --clr=flatclr

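The trained models are saved at results/{dataset_name}/{batch_size}_SimCLR/{date_string}/checkpoint_{:04d}.pth.tar, where the {:04d} field is the zero-padded epoch number (e.g., checkpoint_0026.pth.tar). FlatCLR runs use a {batch_size}_FlatCLR directory instead (e.g., results/imagenet/512_FlatCLR/01-06-2021-21-52-10).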
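A small sketch of how such a checkpoint path can be assembled, for example to build a resume command. It assumes the run directory is named {batch_size}_{method} (as in 512_FlatCLR) and that the epoch is zero-padded to four digits (as in checkpoint_0026.pth.tar). The helper checkpoint_path is hypothetical and not part of the repository; the date string is passed through unchanged because its exact format is not documented here.

```python
def checkpoint_path(dataset_name, batch_size, method, date_string, epoch):
    # Hypothetical helper: assumes the run directory is "{batch_size}_{method}"
    # (e.g. "512_FlatCLR") and a four-digit, zero-padded epoch number.
    run_dir = f"results/{dataset_name}/{batch_size}_{method}/{date_string}"
    return f"{run_dir}/checkpoint_{epoch:04d}.pth.tar"


path = checkpoint_path("imagenet", 512, "FlatCLR", "01-06-2021-21-52-10", 26)

# Split into the two values used by the resume command:
# --log_dir (the run directory) and --train_from (the checkpoint file name).
log_dir, train_from = path.rsplit("/", 1)
print(f"python main.py --clr=flatclr --log_dir={log_dir} --train_from={train_from}")
```

With these inputs the printed command is the same as the resume example further down.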

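Note that a base learning rate of 0.3 with learning_rate_scaling=linear (lr = 0.3 × BatchSize / 256) is equivalent to a base learning rate of 0.075 with learning_rate_scaling=sqrt (lr = 0.075 × sqrt(BatchSize)) when the batch size is 4096: both give an effective rate of 4.8. At smaller batch sizes, however, sqrt scaling gives a larger learning rate and trains better.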

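Quick lookup (effective learning rate by batch size):

| Batch size | Linear scaling lr (base 0.3) | Sqrt scaling lr (base 0.075) |
|---|---|---|
| 128 | 0.15 | 0.85 |
| 256 | 0.3 | 1.20 |
| 512 | 0.6 | 1.70 |
| 1024 | 1.2 | 2.40 |
| 2048 | 2.4 | 3.39 |
| 4096 | 4.8 | 4.80 |
| 8192 | 9.6 | 6.79 |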
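The lookup values can be reproduced with the two scaling rules: linear scaling multiplies the base rate by BatchSize / 256, while sqrt scaling multiplies it by sqrt(BatchSize). The function scaled_lr below is a hypothetical illustration of that arithmetic, not the repository's own flag handling.

```python
import math


def scaled_lr(base_lr, batch_size, scaling):
    """Effective learning rate under linear or sqrt batch-size scaling."""
    if scaling == "linear":
        # Linear rule: base rate defined for a reference batch of 256.
        return base_lr * batch_size / 256
    if scaling == "sqrt":
        # Square-root rule: grows more slowly with the batch size.
        return base_lr * math.sqrt(batch_size)
    raise ValueError(f"unknown scaling: {scaling!r}")


for bs in (128, 256, 512, 1024, 2048, 4096, 8192):
    linear = scaled_lr(0.3, bs, "linear")
    sqrt = scaled_lr(0.075, bs, "sqrt")
    print(f"{bs:>5}  linear={linear:.2f}  sqrt={sqrt:.2f}")
```

The two rules cross at a batch size of 4096; below that, sqrt scaling gives the larger rate.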

Resume Pretraining

For example, to resume pretraining a FlatCLR model from epoch 26:

python main.py --clr=flatclr --log_dir=results/imagenet/512_FlatCLR/01-06-2021-21-52-10 --train_from=checkpoint_0026.pth.tar

Linear Classification

To train a linear classifier on ImageNet on top of a pretrained checkpoint (linear evaluation), run:

python main.py --dataset_name=imagenet --train_mode=eval --transfer_mode=linear_eval --checkpoint_dir=results/imagenet/512_FlatCLR/01-06-2021-21-52-10

Finetune

To fine-tune the pretrained model for classification on ImageNet, run:

python main.py --dataset_name=imagenet --train_mode=eval --transfer_mode=finetune --checkpoint_dir=results/imagenet/512_FlatCLR/01-06-2021-21-52-10

Citation

If you reference or use our method, code or results in your work, please consider citing the FlatNCE paper:

@article{chen2021simpler,
  title={Simpler, Faster, Stronger: Breaking The log-K Curse On Contrastive Learners With FlatNCE},
  author={Chen, Junya and Gan, Zhe and Li, Xuan and Guo, Qing and Chen, Liqun and Gao, Shuyang and Chung, Tagyoung and Xu, Yi and Zeng, Belinda and Lu, Wenlian and others},
  journal={arXiv preprint arXiv:2107.01152},
  year={2021}
}


flatclr's Issues

Can you provide a dependencies list

Hi, I was wondering if you could provide a list of dependencies in this repo (e.g., in the form of requirements.txt). I think this should help with the reproducibility of the results in the paper.

can not understand the loss of flatclr.py

Hi, just below line 114 of flatclr.py, loss_vec is a vector of all ones regardless of the value of v, so on line 119 loss_vec.mean()-1 is always 0. What is the purpose of this term? Also, .detach() is applied to the cross_entropy term that follows it in the loss. As I understand it, this means no gradient flows back through that term. How, then, does this loss optimize the network?
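One way to see what is going on, under assumptions not confirmed by this page: suppose, as the issue describes, that loss_vec = exp(v - v.detach()) with v a logsumexp over per-sample logits of the form (negative - positive) / temperature, and that the detached cross_entropy term is computed on those logits with a zero column prepended at index 0. Then exp(v - v.detach()) always evaluates to 1, but its gradient is not zero, because detach() only blocks the gradient through the subtracted copy. The sketch below mimics detach() in pure Python by freezing v as a constant and checks gradients with central finite differences. All helper names and the toy logits are hypothetical; this is an illustration for a single sample, not the repository's code.

```python
import math


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def softmax(xs):
    v = logsumexp(xs)
    return [math.exp(x - v) for x in xs]


def flat_surrogate(logits, v_frozen):
    # exp(v - v.detach()): v_frozen plays the role of the detached copy,
    # a constant that no gradient flows through.
    return math.exp(logsumexp(logits) - v_frozen)


def infonce(logits):
    # Cross-entropy with a zero logit prepended at index 0 (label 0),
    # i.e. log(1 + sum_j exp(l_j)).
    return logsumexp([0.0] + list(logits))


def numerical_grad(f, xs, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    grads = []
    for i in range(len(xs)):
        up, dn = list(xs), list(xs)
        up[i] += eps
        dn[i] -= eps
        grads.append((f(up) - f(dn)) / (2 * eps))
    return grads


logits = [-2.0, -0.5, -3.0, -1.0]  # toy (negative - positive) / temperature values
v = logsumexp(logits)

value = flat_surrogate(logits, v)  # exp(0.0), so exactly 1.0
flat_grad = numerical_grad(lambda l: flat_surrogate(l, v), logits)
info_grad = numerical_grad(infonce, logits)
sigmoid_v = 1 / (1 + math.exp(-v))

print(value)
print([round(g, 6) for g in flat_grad])                # matches softmax(logits)
print([round(g / sigmoid_v, 6) for g in info_grad])    # also matches softmax(logits)
```

Under these assumptions, loss_vec.mean() - 1 contributes a value of 0 but a gradient equal to softmax(logits), which is the InfoNCE gradient divided by sigmoid(v); the detached cross_entropy term contributes the InfoNCE value (useful for logging) but no gradient. The reported loss therefore tracks InfoNCE, while the parameter update is driven by the exp(v - v.detach()) term.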
