
simsiam-tf's Introduction

SimSiam-TF

Minimal implementation of SimSiam (Exploring Simple Siamese Representation Learning by Xinlei Chen & Kaiming He) in TensorFlow 2. For an introduction, please see this blog post: Self-supervised contrastive learning with SimSiam.

The purpose of this repository is to demonstrate the workflow of SimSiam, NOT to reproduce it note for note, while still covering the major bits discussed in the paper. For that, I'll be using the Flowers dataset.
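
For context, the Flowers dataset can be pulled in through TensorFlow Datasets; a minimal sketch (the exact splits and preprocessing used in the notebooks may differ):

import tensorflow_datasets as tfds

# tf_flowers ships as a single "train" split of 3,670 labeled images across 5 classes.
train_ds = tfds.load("tf_flowers", split="train", as_supervised=True)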

The following figure depicts the workflow of SimSiam (taken from the paper) -

The authors have also provided PyTorch-like pseudocode in the paper (how cool!) -

# f: backbone + projection mlp
# h: prediction mlp
for x in loader: # load a minibatch x with n samples
    x1, x2 = aug(x), aug(x) # random augmentation
    z1, z2 = f(x1), f(x2) # projections, n-by-d
    p1, p2 = h(z1), h(z2) # predictions, n-by-d
    L = D(p1, z2)/2 + D(p2, z1)/2 # loss
    L.backward() # back-propagate
    update(f, h) # SGD update

def D(p, z): # negative cosine similarity
    z = z.detach() # stop gradient
    p = normalize(p, dim=1) # l2-normalize
    z = normalize(z, dim=1) # l2-normalize
    return -(p*z).sum(dim=1).mean()

The authors emphasize the stop_gradient operation, which helps the network avoid collapsed solutions; further details are available in the paper. SimSiam eliminates the need for large batch sizes, momentum encoders, memory banks, negative samples, etc., which are important components of other modern self-supervised learning frameworks for visual recognition. This makes SimSiam an easily approachable framework for practical problems.
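
For reference, here is a minimal TensorFlow 2 sketch of the negative cosine similarity loss with the stop-gradient applied to the target projection (the function name is illustrative, not necessarily the one used in the notebooks):

import tensorflow as tf

def negative_cosine_similarity(p, z):
    # Treat the target projection as a constant: no gradients flow through z.
    z = tf.stop_gradient(z)
    p = tf.math.l2_normalize(p, axis=1)
    z = tf.math.l2_normalize(z, axis=1)
    # Negative mean cosine similarity over the batch.
    return -tf.reduce_mean(tf.reduce_sum(p * z, axis=1))

# Symmetrized SimSiam loss for a pair of augmented views:
# loss = 0.5 * negative_cosine_similarity(p1, z2) + 0.5 * negative_cosine_similarity(p2, z1)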

About the notebooks

  • SimSiam_Pre_training.ipynb: Pre-trains a ResNet50 using SimSiam.
  • SimSiam_Evaluation.ipynb: Runs a linear evaluation of the ResNet50 pre-trained in SimSiam_Pre_training.ipynb (see the sketch below).
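
Linear evaluation here means freezing the pre-trained backbone and training only a linear classifier on top of it. A minimal sketch of that setup, assuming a hypothetical get_backbone() helper that returns the pre-trained ResNet50 encoder without the projection/prediction heads:

import tensorflow as tf

backbone = get_backbone()   # hypothetical helper returning the pre-trained encoder
backbone.trainable = False  # freeze all backbone weights

inputs = tf.keras.Input(shape=(224, 224, 3))  # placeholder input size
features = backbone(inputs, training=False)
outputs = tf.keras.layers.Dense(5, activation="softmax")(features)  # 5 flower classes
linear_model = tf.keras.Model(inputs, outputs)

linear_model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)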

Results

Pre-training Schedule    Validation Accuracy (Linear Evaluation)
50 epochs                45.64%
75 epochs                44.91%

I think these scores can be improved with further hyperparameter tuning and regularization.

Supervised training (results are taken from here and here):

Training Type                                    Validation Accuracy (Linear Evaluation)
Supervised ImageNet-trained ResNet50 features    48.36%
From-scratch training with ResNet50              63.64%

Observations

The figure below shows the training loss plots from two different pre-training schedules (50 epochs and 75 epochs) -

We see that the loss plateaus after about 35 epochs. We can experiment with the following components to improve it further -

  • data augmentation pipeline
  • architectures of the two MLP heads
  • learning rate schedule used during pre-training (a sketch follows below)

and so on.
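
As one concrete example, a cosine-decay learning rate schedule could be tried for pre-training; a hedged sketch with placeholder values for the step count and base learning rate:

import tensorflow as tf

EPOCHS = 50
STEPS_PER_EPOCH = 100   # placeholder; depends on dataset size and batch size

lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.03,   # placeholder base learning rate
    decay_steps=EPOCHS * STEPS_PER_EPOCH,
)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)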

Pre-trained weights

Acknowledgements

Thanks to Connor Shorten's video on the paper, which helped me quickly get a grasp of it. Thanks to the ML-GDE program for providing the GCP credits used to run the experiments.


simsiam-tf's Issues

Loaders shuffled by different seeds

Thank you for your prompt contribution.

Maybe this is a mistake: SimSiam_Pre_training.ipynb implements something different from the paper's pseudocode.

The authors' version:

# f: backbone + projection mlp
# h: prediction mlp
for x in loader: # load a minibatch x with n samples
    x1, x2 = aug(x), aug(x) # random augmentation
    z1, z2 = f(x1), f(x2) # projections, n-by-d
    p1, p2 = h(z1), h(z2) # predictions, n-by-d
    L = D(p1, z2)/2 + D(p2, z1)/2 # loss
    L.backward() # back-propagate
    update(f, h) # SGD update

Your implementation effectively does this:

for (x_a, x_b) in zip(loader_a, loader_b): # two independently shuffled loaders
    x1, x2 = aug(x_a), aug(x_b) # random augmentations of (likely) different images
    z1, z2 = f(x1), f(x2) # projections, n-by-d
    p1, p2 = h(z1), h(z2) # predictions, n-by-d
    L = D(p1, z2)/2 + D(p2, z1)/2 # loss
    L.backward() # back-propagate
    update(f, h) # SGD update

I think the easiest fix is to use the same shuffle seed for both datasets:

dataset_one = (
    train_ds
    .shuffle(1024, seed=0)  # same seed as dataset_two so both views share the shuffle order
    .map(custom_augment, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

dataset_two = (
    train_ds
    .shuffle(1024, seed=0)  # same seed as dataset_one
    .map(custom_augment, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

Regards.
