SimSiam

This is an unofficial PyTorch implementation of Exploring Simple Siamese Representation Learning [1]. It is heavily inspired by PatrickHua's repository. The key features of this repository are:

  1. It supports DistributedDataParallel (DDP), not the old-fashioned DataParallel.
  2. It separates the loss function (criterion) from the model, so the loss can be modified (e.g., by adding a regularizer) without changing the model.
  3. It automatically switches between the CIFAR-ResNet and the official ResNet based on the training dataset.
  4. It supports both CIFAR10 and CIFAR100.
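
As an illustration of feature 2, a criterion kept separate from the model might look like the minimal sketch below. It uses the symmetrized negative cosine similarity with stop-gradient described in [1]. The names negative_cosine and SimSiamCriterion are hypothetical and are not taken from this repository, whose actual criterion may differ.

```python
import torch
import torch.nn.functional as F


def negative_cosine(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Negative cosine similarity between predictions p and targets z."""
    # z.detach() acts as the stop-gradient: no gradient flows back into z.
    return -F.cosine_similarity(p, z.detach(), dim=-1).mean()


class SimSiamCriterion(torch.nn.Module):
    """Hypothetical stand-alone criterion; it holds no reference to the model."""

    def forward(self, p1, p2, z1, z2):
        # p1, p2: predictor outputs for the two augmented views.
        # z1, z2: projector outputs for the two augmented views.
        return 0.5 * negative_cosine(p1, z2) + 0.5 * negative_cosine(p2, z1)


# Usage with random stand-in features (batch of 4, feature dimension 8).
p1 = torch.randn(4, 8, requires_grad=True)
p2 = torch.randn(4, 8, requires_grad=True)
z1 = torch.randn(4, 8, requires_grad=True)
z2 = torch.randn(4, 8, requires_grad=True)

criterion = SimSiamCriterion()
loss = criterion(p1, p2, z1, z2)  # scalar in [-1, 1]
loss.backward()                   # gradients reach p1 and p2 only, not z1 or z2
```

Because the criterion only consumes tensors, a regularizer could be added inside forward without touching the model code.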

Requirements

  • Python 3+ [Tested on 3.7]
  • PyTorch 1.x [Tested on 1.7.1]

Usage example

First, update constants.py with your dataset and checkpoint directories.

To train a model, run python pretrain_main.py

To train a linear classifier, run python classifier_main.py. Make sure to update the path of the pretrained model before running.

The CIFAR hyperparameters are already hard-coded in the Python script. However, these hyperparameters can be overridden by providing at least one parameter when running the script (e.g., python pretrain_main.py --arch SimCLR).

By default, the code uses DistributedDataParallel. Two key parameters identify the training GPUs:

  • --world_size denotes how many GPUs to use
  • --base_gpu denotes the index of the first GPU to use (0-based).

For example, if a machine has four GPUs and the user sets world_size=3 and base_gpu=1, the model will be trained on the 2nd, 3rd and 4th GPUs.
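
The mapping from --world_size and --base_gpu to GPU indices can be sketched as follows. The helper training_gpus and its available_gpus argument are hypothetical, introduced only for illustration; the repository may compute this differently.

```python
from typing import List


def training_gpus(world_size: int, base_gpu: int, available_gpus: int) -> List[int]:
    """Return the 0-based GPU indices a run would use (illustrative helper)."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    if base_gpu < 0:
        raise ValueError("base_gpu must be non-negative")
    # One process per GPU: rank r is assumed to train on GPU base_gpu + r.
    gpus = [base_gpu + rank for rank in range(world_size)]
    if gpus[-1] >= available_gpus:
        raise ValueError("requested GPUs exceed the GPUs available on the machine")
    return gpus


# Four GPUs, world_size=3, base_gpu=1: the 2nd, 3rd and 4th GPUs.
print(training_gpus(world_size=3, base_gpu=1, available_gpus=4))  # [1, 2, 3]
```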

The code logs the training progress in two files:

  • train_X.log: Assuming base_gpu=X, this file contains a text log of the training progress. Sample log files are available in the ./sample_runs dir; they include both complete and incomplete runs.
  • exp_name.csv: This CSV file keeps track of the KNN accuracy at each --test_interval.

To train a SimCLR model, change the arch parameter to simCLR. This code achieves 89.56% KNN accuracy with SimCLR (see the ./sample_runs dir).

Quantitative Evaluation on CIFAR10

                               | Paper [1] | Ours
KNN accuracy (%)               |           |
Linear classifier accuracy (%) | 91.8      | 90.27

Release History

  • 1.0.0
    • First commit on 8 Jan 2021

TODO LIST

  • Revise the readme file
  • Test pretrain_main.py with SimCLR
  • Add the classifier_main.py

References

[1] Chen, Xinlei and He, Kaiming. Exploring Simple Siamese Representation Learning. arXiv preprint arXiv:2011.10566, 2020.
