byol's Introduction

BYOL - Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning

PyTorch implementation of "Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning" by J.B. Grill et al.

Link to paper
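For readers new to the method: in the paper, an online network's prediction for one augmented view is regressed onto a target network's projection of a second view. The loss is the squared distance between the two L2-normalized vectors, which equals 2 - 2 × cosine similarity, and it is symmetrized by swapping the views. The sketch below illustrates that loss in PyTorch. The function names are hypothetical and are not taken from this repository.

```python
import torch
import torch.nn.functional as F


def byol_loss(online_prediction: torch.Tensor, target_projection: torch.Tensor) -> torch.Tensor:
    """Per-sample BYOL regression loss.

    Squared L2 distance between the L2-normalized vectors, which equals
    2 - 2 * cosine_similarity(online_prediction, target_projection).
    """
    p = F.normalize(online_prediction, dim=-1)
    # Stop-gradient: the target network is not trained by this loss.
    z = F.normalize(target_projection.detach(), dim=-1)
    return 2 - 2 * (p * z).sum(dim=-1)


def symmetric_byol_loss(p1, p2, z1, z2) -> torch.Tensor:
    """Symmetrized loss over two augmented views.

    p1, p2: online predictions for views 1 and 2.
    z1, z2: target projections for views 1 and 2.
    Each view's prediction is regressed onto the other view's projection.
    """
    return (byol_loss(p1, z2) + byol_loss(p2, z1)).mean()
```

Identical directions give a loss of 0 and opposite directions give the maximum of 4.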

This repository includes a practical implementation of BYOL with:

  • Distributed Data Parallel training
  • Benchmarks on vision datasets (CIFAR-10 / STL-10)
  • Support for PyTorch <= 1.5.0

Open BYOL in Google Colab Notebook

Results

These are the top-1 accuracies of linear classifiers trained on the (frozen) representations learned by BYOL:

| Method | Batch size | Image size | ResNet | Projection output dim. | Pre-training epochs | Optimizer | STL-10 | CIFAR-10 |
|---|---|---|---|---|---|---|---|---|
| BYOL + linear eval. | 192 | 224x224 | ResNet18 | 256 | 100 | Adam | _ | 0.832 |
| Logistic Regression | - | - | - | - | - | - | 0.358 | 0.389 |

Installation

git clone https://github.com/spijkervet/byol --recurse-submodules -j8
pip3 install -r requirements.txt
python3 main.py

Usage

Using a pre-trained model

The following commands will train a logistic regression model on a pre-trained ResNet18, yielding a top-1 accuracy of 83.2% on CIFAR-10.

curl https://github.com/Spijkervet/BYOL/releases/download/1.0/resnet18-CIFAR10-final.pt -L -O
rm features.p
python3 logistic_regression.py --model_path resnet18-CIFAR10-final.pt

Pre-training

To run pre-training using BYOL with the default arguments (1 node, 1 GPU), use:

python3 main.py

Which is equivalent to:

python3 main.py --nodes 1 --gpus 1

The pre-trained models are saved as *.pt files every n epochs, where n is set by --checkpoint_epochs. The final model is saved as model-final.pt.

Finetuning

Linear evaluation (training a linear classifier on top of the frozen, pre-trained ResNet model) can be run with:

python3 logistic_regression.py --model_path=./model-final.pt

where model-final.pt is the file containing the pre-trained network from the pre-training stage.
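In outline, linear evaluation freezes the encoder, computes the features once, and fits a single linear layer with a softmax cross-entropy loss, which is multinomial logistic regression. The following is a minimal sketch under those assumptions. The function names, optimizer, learning rate, batch size, and epoch count are illustrative and are not necessarily what logistic_regression.py uses.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_linear_probe(features, labels, num_classes, epochs=300, lr=3e-4, batch_size=256):
    """Multinomial logistic regression on frozen features.

    features: (N, D) float tensor computed once by the frozen encoder.
    labels:   (N,) int64 tensor of class indices.
    """
    classifier = nn.Linear(features.shape[1], num_classes)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()  # softmax + negative log-likelihood
    loader = DataLoader(TensorDataset(features, labels), batch_size=batch_size, shuffle=True)
    for _ in range(epochs):
        for x, y in loader:
            optimizer.zero_grad()
            loss = criterion(classifier(x), y)
            loss.backward()
            optimizer.step()
    return classifier


@torch.no_grad()
def top1_accuracy(classifier, features, labels) -> float:
    predictions = classifier(features).argmax(dim=1)
    return (predictions == labels).float().mean().item()
```

Only the linear layer receives gradients. The reported number therefore measures how linearly separable the frozen representation is.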

Multi-GPU / Multi-node training

Use python3 main.py --gpus 2 to train on, e.g., 2 GPUs, and python3 main.py --gpus 2 --nodes 2 to train on 2 nodes with 2 GPUs each (4 GPUs in total). See https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html for an excellent explanation.
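The linked tutorial derives one process per GPU. It computes the world size as nodes × GPUs per node and each process's global rank as nr × GPUs per node + local GPU index, and each node runs the same command with its own --nr. The sketch below shows that arithmetic, assuming main.py follows the tutorial's convention. The helper names are hypothetical.

```python
def world_size(nodes: int, gpus_per_node: int) -> int:
    # Total number of processes: one per GPU on every node.
    return nodes * gpus_per_node


def global_rank(nr: int, gpu: int, gpus_per_node: int) -> int:
    # nr is the node rank (--nr); gpu is the local GPU index on that node.
    return nr * gpus_per_node + gpu


def process_layout(nodes: int, gpus_per_node: int):
    """List (node rank, local GPU, global rank) for every training process.

    world_size(...) and global_rank(...) are the values a process would pass as
    world_size and rank to torch.distributed.init_process_group.
    """
    return [
        (nr, gpu, global_rank(nr, gpu, gpus_per_node))
        for nr in range(nodes)
        for gpu in range(gpus_per_node)
    ]


# --gpus 2 --nodes 2: node 0 runs with --nr 0, node 1 with --nr 1.
print(world_size(2, 2), process_layout(2, 2))
```

For --gpus 2 --nodes 2 this gives 4 processes with global ranks 0 and 1 on node 0 and 2 and 3 on node 1.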

Arguments

--image_size, default=224, "Image size"
--learning_rate, default=3e-4, "Initial learning rate."
--batch_size, default=42, "Batch size for training."
--num_epochs, default=100, "Number of epochs to train for."
--checkpoint_epochs, default=10, "Number of epochs between checkpoints/summaries."
--dataset_dir, default="./datasets", "Directory where dataset is stored."
--num_workers, default=8, "Number of data loading workers (caution with nodes!)"
--nodes, default=1, "Number of nodes"
--gpus, default=1, "Number of GPUs per node"
--nr, default=0, "Rank of this node among all nodes (0-indexed)"
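The argument list above maps naturally onto an argparse parser. The following is a hypothetical reconstruction. Names and defaults are copied from the list, but the value types are assumptions and the real main.py may define them differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reconstruction of the documented CLI; types are assumed.
    parser = argparse.ArgumentParser(description="BYOL pre-training (sketch)")
    parser.add_argument("--image_size", type=int, default=224, help="Image size")
    parser.add_argument("--learning_rate", type=float, default=3e-4, help="Initial learning rate.")
    parser.add_argument("--batch_size", type=int, default=42, help="Batch size for training.")
    parser.add_argument("--num_epochs", type=int, default=100, help="Number of epochs to train for.")
    parser.add_argument("--checkpoint_epochs", type=int, default=10,
                        help="Number of epochs between checkpoints/summaries.")
    parser.add_argument("--dataset_dir", type=str, default="./datasets",
                        help="Directory where dataset is stored.")
    parser.add_argument("--num_workers", type=int, default=8,
                        help="Number of data loading workers (caution with nodes!)")
    parser.add_argument("--nodes", type=int, default=1, help="Number of nodes")
    parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs per node")
    parser.add_argument("--nr", type=int, default=0, help="Rank of this node among all nodes")
    return parser


# Equivalent of: python3 main.py --gpus 2 --nodes 2
args = build_parser().parse_args(["--gpus", "2", "--nodes", "2"])
```

Running with no flags reproduces the defaults, so python3 main.py and python3 main.py --nodes 1 --gpus 1 parse to the same values.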

byol's People

Contributors

spijkervet

byol's Issues

About the number of epochs

Hi there
When I train on CIFAR-10 with ResNet18, does the accuracy improve if I increase the number of epochs?

What accuracy do you get if you run the experiment for about 2000 epochs?

Confusion about the mock image tensor

Thanks for your excellent work! I am a newcomer.
The following error occurred when I combined the network structure with my work:
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
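Here is one common way to trigger this error, though the report does not show whether it is the cause in this case. copy.deepcopy refuses to copy a tensor that is not a graph leaf, i.e. one produced by a differentiable operation. A BYOL-style target network is created by deep-copying the online network (copy.deepcopy(self.online_encoder), as quoted in the target-encoder issue). If the online network keeps such a tensor as an attribute, for example a cached activation or weights reparametrized with torch.nn.utils.weight_norm, the copy fails. The hypothetical module below reproduces the error and shows that storing a detached tensor avoids it.

```python
import copy

import torch
import torch.nn as nn


class CachesActivation(nn.Module):
    """Hypothetical module that keeps its last output as an attribute."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.last_output = None

    def forward(self, x: torch.Tensor, keep_graph: bool = True) -> torch.Tensor:
        out = self.linear(x)
        # out has a grad_fn, so it is not a graph leaf. Keeping it on the module
        # makes copy.deepcopy(module) raise the RuntimeError quoted above;
        # out.detach() is a leaf tensor and can be deep-copied.
        self.last_output = out if keep_graph else out.detach()
        return out


module = CachesActivation()
module(torch.randn(2, 4))
try:
    copy.deepcopy(module)  # fails: last_output is a non-leaf tensor
except RuntimeError as err:
    print("deepcopy failed:", str(err).splitlines()[0])

module(torch.randn(2, 4), keep_graph=False)
target = copy.deepcopy(module)  # succeeds once only leaf tensors are stored
```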

BYOL-1.0 CIFAR-10 top-1 accuracy not reproducible with pre-trained weights

Hello @Spijkervet, thank you for sharing your BYOL implementation along with a CIFAR-10 result that is supposed to be reproducible.

My problem is that the reported top-1 accuracy of 0.832 does not reproduce in my local environment, even when using your 1.0/resnet18-CIFAR10-final.pt. A model I trained locally behaves the same way: I get around 0.43 for both your pre-trained model and my own.

I'm using torch==1.5.0. I would appreciate it if you could take a quick look at what's wrong.

Log:

$ ~/lab/BYOL/BYOL-1.0$ python logistic_regression.py --model_path ./resnet18-CIFAR10-final.pt
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./datasets/cifar-10-python.tar.gz
99%|███████████████████████████▋ | 168206336/170498071 [00:11<00:00, 16383016.35it/s]
Extracting ./datasets/cifar-10-python.tar.gz to ./datasets
Files already downloaded and verified
### Creating features from pre-trained model ###
Step [0/65] Computing features...
Step [5/65] Computing features...
Step [10/65] Computing features...
Step [15/65] Computing features...
Step [20/65] Computing features...
Step [25/65] Computing features...
170500096it [00:30, 16383016.35it/s]
Step [30/65] Computing features...
Step [35/65] Computing features...
Step [40/65] Computing features...
Step [45/65] Computing features...
Step [50/65] Computing features...
Step [55/65] Computing features...
Step [60/65] Computing features...
Features shape (49920, 512)
Step [0/13] Computing features...
Step [5/13] Computing features...
Step [10/13] Computing features...
Features shape (9984, 512)
Epoch [0/300]: Loss/train: 2.257830924987793 Accuracy/train: 0.16569661458333332
Epoch [1/300]: Loss/train: 2.083581094741821 Accuracy/train: 0.24340494791666667
Epoch [2/300]: Loss/train: 2.0132872915267943 Accuracy/train: 0.2754361979166667
Epoch [3/300]: Loss/train: 1.9688853120803833 Accuracy/train: 0.2946354166666667
Epoch [4/300]: Loss/train: 1.9352103281021118 Accuracy/train: 0.31092447916666666
Epoch [5/300]: Loss/train: 1.9084804725646973 Accuracy/train: 0.32227213541666666
Epoch [6/300]: Loss/train: 1.8860071468353272 Accuracy/train: 0.33306640625
  :
Epoch [295/300]: Loss/train: 1.5378710174560546 Accuracy/train: 0.4568359375
Epoch [296/300]: Loss/train: 1.537794508934021 Accuracy/train: 0.4569140625
Epoch [297/300]: Loss/train: 1.5377185583114623 Accuracy/train: 0.45677734375
Epoch [298/300]: Loss/train: 1.5376431322097779 Accuracy/train: 0.456796875
Epoch [299/300]: Loss/train: 1.537568416595459 Accuracy/train: 0.45681640625
### Calculating final testing performance ###
Final test performance: Accuracy/test: 0.423828125

Logistic regression on a model without pre-training also achieves 83+% on CIFAR-10

Before pre-training, I saved a checkpoint whose representations should be very weak, and performed logistic regression with:
python3 logistic_regression.py --model_path=./model-no-train.pt
and the result is
Epoch [296/300]: Loss/train: 0.4051439380645752 Accuracy/train: 0.8583268229166667
Epoch [297/300]: Loss/train: 0.4050535774230957 Accuracy/train: 0.8583854166666667
Epoch [298/300]: Loss/train: 0.4049636995792389 Accuracy/train: 0.8583333333333333
Epoch [299/300]: Loss/train: 0.40487425565719604 Accuracy/train: 0.8583528645833333

### Calculating final testing performance ###

Final test performance: Accuracy/test: 0.8356724330357143

Missing key(s) and size mismatch


RuntimeError Traceback (most recent call last)

in ()
1 # pre-trained model
2 resnet = models.resnet50()
----> 3 resnet.load_state_dict(torch.load(args.model_path, map_location=device))
4 resnet = resnet.to(device)
5

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1050 if len(error_msgs) > 0:
1051 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1052 self.__class__.__name__, "\n\t".join(error_msgs)))
1053 return _IncompatibleKeys(missing_keys, unexpected_keys)
1054

RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.3.conv1.weight", "layer2.3.bn...
size mismatch for layer1.0.conv1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 1, 1]).
size mismatch for layer1.1.conv1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 256, 1, 1]).
size mismatch for layer2.0.conv1.weight: copying a param with shape torch.Size([128, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).
size mismatch for layer2.0.downsample.0.weight: copying a param with shape torch.Size([128, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
size mismatch for layer2.0.downsample.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for layer2.0.downsample.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for layer2.0.downsample.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for layer2.0.downsample.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for layer2.1.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 512, 1, 1]).
size mismatch for layer3.0.conv1.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
size mismatch for layer3.0.downsample.0.weight: copying a param with shape torch.Size([256, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 512, 1, 1]).
size mismatch for layer3.0.downsample.1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for layer3.0.downsample.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for layer3.0.downsample.1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for layer3.0.downsample.1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for layer3.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
size mismatch for layer4.0.conv1.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 1024, 1, 1]).
size mismatch for layer4.0.downsample.0.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([2048, 1024, 1, 1]).
size mismatch for layer4.0.downsample.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
size mismatch for layer4.0.downsample.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
size mismatch for layer4.0.downsample.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
size mismatch for layer4.0.downsample.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
size mismatch for layer4.1.conv1.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 2048, 1, 1]).
size mismatch for fc.weight: copying a param with shape torch.Size([1000, 512]) from checkpoint, the shape in current model is torch.Size([1000, 2048]).

construct target encoder

I would like to ask about target_encoder = self._get_target_encoder() here: isn't the online encoder just copied directly (target_encoder = copy.deepcopy(self.online_encoder))? I don't see where the momentum is used. Maybe I'm just not familiar enough with Python syntax, but after searching online for a long time I still can't find the problem, so I'd appreciate it if anyone could help.
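In BYOL as described in the paper, the deepcopy is only the initialization. Momentum enters through a separate exponential-moving-average update of the target weights, target ← τ·target + (1 − τ)·online, applied after each optimization step. In the paper, τ follows a cosine schedule that starts at 0.996 and rises to 1. The question does not show where (or whether) this repository performs that update. The sketch below, with hypothetical function names, illustrates only the mechanism.

```python
import copy
import math

import torch
import torch.nn as nn


def make_target_encoder(online: nn.Module) -> nn.Module:
    # The deepcopy only initializes the target with the online weights.
    target = copy.deepcopy(online)
    for param in target.parameters():
        param.requires_grad = False  # the target is never updated by gradients
    return target


@torch.no_grad()
def update_target(target: nn.Module, online: nn.Module, tau: float) -> None:
    # Exponential moving average, applied after every optimizer step:
    # target <- tau * target + (1 - tau) * online
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.mul_(tau).add_(o_param, alpha=1 - tau)


def cosine_tau(step: int, total_steps: int, tau_base: float = 0.996) -> float:
    # Schedule from the BYOL paper: tau rises from tau_base to 1 over training.
    return 1 - (1 - tau_base) * (math.cos(math.pi * step / total_steps) + 1) / 2
```

After the one-time copy, the target trails the online network. With τ close to 1, it moves only a small fraction of the way toward the online weights at each step.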

Result does not match

I ran linear evaluation on a randomly initialized CIFAR-10 model and got 45%; with model-85 the result is still 45%, with no improvement. Could you help? I used the default settings: pre-training on CIFAR-10 and linear evaluation on CIFAR-10.
