spaceml-org / self-supervised-learner


Curator can be used to train a classifier that needs fewer labeled examples, by using self-supervised learning.

Home Page: http://spaceml.org/repo/project/605b7b751644770011e850c3/false/

Python 18.30% Shell 0.63% Jupyter Notebook 81.06%
self-supervised-learning ssl curator deep-learning


Self-Supervised Learner

The Self-Supervised Learner can be used to train a classifier that needs fewer labeled examples, by using self-supervised learning. This repo is for you if you have many unlabeled images and only a small fraction (if any) of them are labeled.

What is Self-Supervised Learning?
Self-supervised learning is a subfield of machine learning focused on learning representations of images without any labels. These representations are useful for reverse image search and for categorizing and filtering images, especially when it would be infeasible for a human to inspect each image manually. They also benefit downstream classification tasks: for instance, training an SSL encoder on 100% of your data and then finetuning it on the 5% of the data that is labeled typically significantly outperforms both training a model from scratch on that 5% and transfer-learning-based approaches.
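Once an encoder maps each image to an embedding vector, reverse image search reduces to ranking images by how similar their embeddings are to a query's embedding. A minimal sketch, assuming cosine similarity as the metric; the function names and the toy 3-D vectors are hypothetical stand-ins for real encoder outputs:

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def reverse_image_search(query, gallery, k=3):
    """Return the names of the k gallery images whose embeddings are closest to the query."""
    ranked = sorted(gallery.items(),
                    key=lambda item: cosine_similarity(query, item[1]),
                    reverse=True)
    return [name for name, _ in ranked[:k]]

# Hypothetical embeddings; in practice these would come from the trained encoder.
gallery = {
    "forest_01.png": [0.9, 0.1, 0.0],
    "forest_02.png": [0.8, 0.2, 0.1],
    "harbor_01.png": [0.0, 0.2, 0.9],
}
print(reverse_image_search([1.0, 0.1, 0.0], gallery, k=2))  # ['forest_01.png', 'forest_02.png']
```

The same ranking over many embeddings also supports filtering and grouping images without a human inspecting each one.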

How To Use SSL Curator

Step 1) Self-Supervised Learning (SSL): Training an encoder without labels

  • The first step is to train a self-supervised encoder. Self-supervised learning does not require labels: the model learns an image encoder purely from unlabeled data. If you want your model to be color invariant, use grayscale images when possible.
python train.py --technique SIMCLR --model imagenet_resnet18 --DATA_PATH myDataFolder/AllImages  --epochs 100 --log_name ssl 
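SimCLR, the technique selected with --technique SIMCLR in the command above, learns without labels by pulling together the embeddings of two augmentations of the same image and pushing apart the embeddings of different images, using the NT-Xent loss. A minimal pure-Python sketch of that loss, not the repo's implementation; the function names and the temperature of 0.5 are assumptions for illustration:

```python
import math

def _normalize(v):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def nt_xent_loss(view_a, view_b, temperature=0.5):
    """NT-Xent loss as used by SimCLR.

    view_a[i] and view_b[i] are embeddings of two augmentations of image i.
    """
    z = [_normalize(v) for v in view_a + view_b]  # 2N unit vectors
    n = len(view_a)
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)  # index of the other view of the same image
        # temperature-scaled similarities to every embedding except itself
        logits = {k: sum(a * b for a, b in zip(z[i], z[k])) / temperature
                  for k in range(2 * n) if k != i}
        log_denominator = math.log(sum(math.exp(s) for s in logits.values()))
        total += log_denominator - logits[positive]  # -log softmax of the positive pair
    return total / (2 * n)

views = [[1.0, 0.0], [0.0, 1.0]]
swapped = [[0.0, 1.0], [1.0, 0.0]]  # pairs the wrong views together
print(nt_xent_loss(views, views) < nt_xent_loss(views, swapped))  # True
```

Every other image in the batch acts as a negative here, which is why SimCLR tends to benefit from large batch sizes.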

Step 2) Fine tuning: Training a classifier with labels

  • With the self-supervised training done, the encoder is used to initialize a classifier (finetuning). Because the encoder learned from the entire unlabeled dataset previously, the classifier is able to achieve higher classification accuracy than training from scratch or pure transfer learning.
python train.py --technique CLASSIFIER --model ./models/SIMCLR_ssl.ckpt --DATA_PATH myDataFolder/LabeledImages  --epochs 100 --log_name finetune 
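The two steps can also be chained from a Python script. A minimal sketch, assuming train.py is run from the repo root and writes checkpoints to ./models/<TECHNIQUE>_<log_name>.ckpt, as the commands above suggest; the helper names train_command and checkpoint_path are hypothetical:

```python
import shlex

def train_command(technique, model, data_path, log_name, epochs=100):
    """Build a train.py invocation from the arguments used in the commands above."""
    return ["python", "train.py",
            "--technique", technique,
            "--model", model,
            "--DATA_PATH", data_path,
            "--epochs", str(epochs),
            "--log_name", log_name]

def checkpoint_path(technique, log_name, models_dir="./models"):
    """Checkpoint path following the <TECHNIQUE>_<log_name>.ckpt naming (assumed location)."""
    return f"{models_dir}/{technique}_{log_name}.ckpt"

ssl_cmd = train_command("SIMCLR", "imagenet_resnet18", "myDataFolder/AllImages", "ssl")
finetune_cmd = train_command("CLASSIFIER", checkpoint_path("SIMCLR", "ssl"),
                             "myDataFolder/LabeledImages", "finetune")
for cmd in (ssl_cmd, finetune_cmd):
    # Printed here; each could instead be run with subprocess.run(cmd, check=True).
    print(shlex.join(cmd))
```

Deriving the finetuning --model path from the SSL run's technique and log name keeps the two steps consistent if either name changes.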

Requirements: a GPU with CUDA 10+ enabled, plus the packages listed in requirements.txt

Most Recent Release: ✔️ 1.0.3
Update: Package Documentation Improved
Model: Support for SIMSIAM
Processing Speed: Multi-GPU Training Supported

TL;DR Quick example

Run sh example.sh to see the tool in action on the UC Merced land use dataset.

Arguments to train.py

You use train.py to train both the SSL model and the classifier. It accepts the following arguments:

Mandatory Arguments

--model: The architecture of the encoder to train. All encoder options are defined in models/encoders.py. Currently resnet18, imagenet_resnet18, resnet50, imagenet_resnet50 and minicnn are supported. To use minicnn, append the output embedding size to its name, for example minicnn32.

--technique: Which type of SSL or classification to run. Options as of 1.0.4 are SIMCLR, SIMSIAM, or CLASSIFIER.

--log_name: Name of the output model file, prefixed with the technique name. The file is saved as a .ckpt file, for example SIMCLR_mymodel2.ckpt.

--DATA_PATH: The path to your data. If your data does not contain train and val folders, a copy with train and val splits is created automatically.

Your data must follow the folder structure expected by PyTorch's ImageFolder:

/Dataset
    /Class 1
        Image1.png
        Image2.png
    /Class 2
        Image3.png
        Image4.png

# If your dataset has no labels yet, you still need to nest the images one folder deep
/Dataset
    /Unlabelled
        Image1.png
        Image2.png
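The README does not describe how the automatic train/val copy is made. A hypothetical sketch of preparing data in this layout: nest a flat folder of unlabeled images under Dataset/Unlabelled, then copy it into train and val splits per class. The function names, the accepted file extensions, the 20% validation fraction, and the output layout are all assumptions, not the tool's actual behavior:

```python
import random
import shutil
import tempfile
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}  # assumed image extensions

def nest_unlabelled(flat_dir, dataset_dir, folder_name="Unlabelled"):
    """Copy images from a flat folder into dataset_dir/folder_name/ (a single pseudo-class)."""
    target = Path(dataset_dir) / folder_name
    target.mkdir(parents=True, exist_ok=True)
    for image in sorted(Path(flat_dir).iterdir()):
        if image.is_file() and image.suffix.lower() in IMAGE_SUFFIXES:
            shutil.copy2(image, target / image.name)

def make_train_val_copy(dataset_dir, out_dir, val_fraction=0.2, seed=0):
    """Copy an ImageFolder-style dataset into out_dir/train/<class> and out_dir/val/<class>.

    Each class folder is shuffled and split separately, so every class appears in both splits.
    """
    rng = random.Random(seed)
    for class_dir in sorted(p for p in Path(dataset_dir).iterdir() if p.is_dir()):
        images = sorted(p for p in class_dir.iterdir() if p.is_file())
        rng.shuffle(images)
        n_val = int(len(images) * val_fraction)
        for split, subset in (("val", images[:n_val]), ("train", images[n_val:])):
            target = Path(out_dir) / split / class_dir.name
            target.mkdir(parents=True, exist_ok=True)
            for image in subset:
                shutil.copy2(image, target / image.name)

# Demo on empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    raw = Path(tmp) / "raw"
    raw.mkdir()
    for i in range(5):
        (raw / f"Image{i}.png").write_bytes(b"")
    (raw / "notes.txt").write_text("not an image")  # skipped by the extension filter
    nest_unlabelled(raw, Path(tmp) / "Dataset")
    make_train_val_copy(Path(tmp) / "Dataset", Path(tmp) / "Dataset_split")
    counts = {split: len(list((Path(tmp) / "Dataset_split" / split / "Unlabelled").iterdir()))
              for split in ("train", "val")}
print(counts)  # {'train': 4, 'val': 1}
```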

Optional Arguments

--batch_size: batch size passed to the model during training

--epochs: number of epochs to train

--learning_rate: learning rate for the encoder during training

--cpus: number of CPUs to use for data loading

--gpus: number of GPUs to use for training

--seed: random seed for reproducibility

--patience: stop training early if the validation loss does not decrease for this many epochs

--image_size: the encoder receives inputs of shape 3 x image_size x image_size

--hidden_dim: hidden dimension of the projection head (SSL) or of the classification layer (finetuning), depending on the technique you are using

--OTHER ARGS: each SSL model and the classifier have their own model-specific arguments. For instance, the classifier accepts a linear_lr argument that sets a learning rate for the classification layer separate from the encoder's. These optional parameters are listed in the add_model_specific_args method of each model in the models folder.
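In PyTorch Lightning projects, add_model_specific_args is conventionally a static method that extends a parent argparse parser with a model's own flags. A minimal sketch of that pattern, assuming train.py first reads --technique and then lets the selected model register its extra arguments; HypotheticalClassifier, the MODELS mapping, and all default values are illustrative, not taken from the repo:

```python
import argparse

class HypotheticalClassifier:
    """Stand-in for a model class in the models folder; only the argument pattern is shown."""

    @staticmethod
    def add_model_specific_args(parent_parser):
        # Extend the shared parser with flags that only this model understands.
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--linear_lr", type=float, default=0.1,
                            help="learning rate for the classification layer")
        return parser

MODELS = {"CLASSIFIER": HypotheticalClassifier}

# Shared arguments documented in this README (default values are illustrative).
base = argparse.ArgumentParser(add_help=False)
base.add_argument("--technique", type=str, required=True)
base.add_argument("--learning_rate", type=float, default=1e-3)
base.add_argument("--epochs", type=int, default=100)

argv = ["--technique", "CLASSIFIER", "--linear_lr", "0.01"]
known, _ = base.parse_known_args(argv)                          # stage 1: find the technique
parser = MODELS[known.technique].add_model_specific_args(base)  # stage 2: add its own flags
args = parser.parse_args(argv)
print(args.technique, args.learning_rate, args.linear_lr)  # CLASSIFIER 0.001 0.01
```

Parsing in two stages lets unknown model-specific flags pass through the first pass without errors.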

Optional: to optimize your environment for deep learning, run this repo inside the NVIDIA PyTorch Docker container:

docker pull nvcr.io/nvidia/pytorch:20.12-py3
mkdir docker_folder
docker run --user=root -p 7000-8000:7000-8000/tcp --volume="/etc/group:/etc/group:ro" --volume="/etc/passwd:/etc/passwd:ro" --volume="/etc/shadow:/etc/shadow:ro" --volume="/etc/sudoers.d:/etc/sudoers.d:ro" --gpus all -it --rm -v "$(pwd)/docker_folder":/inside_docker nvcr.io/nvidia/pytorch:20.12-py3
apt update
apt install -y libgl1-mesa-glx
# now clone the repo inside the container, install the requirements as usual, and log in to wandb if you'd like

How to access models after training in a Python environment

Both self-supervised and finetuned models can be loaded and used as PyTorch Lightning LightningModule models. A LightningModule is a subclass of PyTorch's nn.Module, so it behaves like a regular PyTorch model, with added functionality for use with a PyTorch Lightning Trainer.

For example:

from models import SIMCLR, CLASSIFIER
simclr_model = SIMCLR.SIMCLR.load_from_checkpoint('/content/models/SIMCLR_ssl.ckpt') #Used like a normal pytorch model
classifier_model = CLASSIFIER.CLASSIFIER.load_from_checkpoint('/content/models/CLASSIFIER_ft.ckpt') #Used like a normal pytorch model

Using Your Own Encoder

If you don't want to use the predefined encoders in models/encoders.py, you can pass your own encoder as a .pt file to the --model argument and set the --embedding_size argument to tell the tool the size of the model's output embedding.

Releases

  • ✔️ (0.7.0) Dali Transforms Added
  • ✔️ (0.8.0) UC Merced Example Added
  • ✔️ (0.9.0) Model Inference with Dali Supported
  • ✔️ (1.0.0) SIMCLR Model Supported
  • ✔️ (1.0.1) GPU Memory Issues Fixed
  • ✔️ (1.0.1) Multi-GPU Training Enabled
  • ✔️ (1.0.2) Package Speed Improvements
  • ✔️ (1.0.3) Support for SimSiam and Code Restructuring
  • 🎫 (1.0.4) Cluster Visualizations for Embeddings
  • 🎫 (1.1.0) Supporting numpy, TFDS datasets
  • 🎫 (1.2.0) Saliency Maps for Embeddings

Citation

If you find Self-Supervised Learner useful in your research, please consider citing the github code for this tool:

@misc{self_supervised_learner,
  title={Self-Supervised Learner},
  url={https://github.com/spaceml-org/Self-Supervised-Learner},
  year={2021}
}

Contributors

abhishekvp, ajaykrishnan23, rudyvenguswamy, sidgan, tarunn2799, walker777007


Issues

Unclean code

The code has a lot of redundant old files and unused classes and methods.

Configurable path for saving the model.

Right now the model is saved in the models directory after training. It would be useful to be able to specify a custom path for saving models after training or fine-tuning.

Support for additional self-supervised techniques

Is there any plan to (flexibly) incorporate other SSL techniques? Barlow Twins does not require large batch sizes (in contrast to SimCLR), and can be easily implemented with a custom loss function using the same backbone as SimSiam (example).
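For reference, the Barlow Twins objective mentioned in this issue standardizes each embedding dimension across the batch, computes the cross-correlation matrix C between the two views' embeddings, and sums (1 - C_ii)^2 on the diagonal plus λ·C_ij^2 off the diagonal. A minimal pure-Python sketch, not part of this repo; λ = 5e-3 follows the Barlow Twins paper, while the function names and the small epsilon for numerical stability are assumptions:

```python
import math

def _standardize_columns(z, eps=1e-5):
    """Standardize each embedding dimension across the batch (zero mean, unit variance)."""
    n, d = len(z), len(z[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        column = [row[j] for row in z]
        mean = sum(column) / n
        var = sum((x - mean) ** 2 for x in column) / n  # biased (population) variance
        std = math.sqrt(var + eps)
        for i in range(n):
            out[i][j] = (column[i] - mean) / std
    return out

def barlow_twins_loss(z_a, z_b, lambd=5e-3):
    """Barlow Twins loss for two batches of embeddings, each a list of N vectors of size D."""
    a, b = _standardize_columns(z_a), _standardize_columns(z_b)
    n, d = len(a), len(a[0])
    loss = 0.0
    for i in range(d):
        for j in range(d):
            # cross-correlation between dimension i of view A and dimension j of view B
            c_ij = sum(a[k][i] * b[k][j] for k in range(n)) / n
            # invariance term on the diagonal, redundancy-reduction term off it
            loss += (1.0 - c_ij) ** 2 if i == j else lambd * c_ij ** 2
    return loss

z = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
z_swapped = [[row[1], row[0]] for row in z]  # same batch with the two dimensions swapped
print(barlow_twins_loss(z, z) < barlow_twins_loss(z, z_swapped))  # True
```

Unlike NT-Xent, this loss uses no negative pairs, which is why it does not depend on large batch sizes.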
