andrewatanov / simclr-pytorch

PyTorch implementation of SimCLR: supports multi-GPU training and closely reproduces results

License: MIT License

Jupyter Notebook 97.30% Python 2.70%
pytorch self-supervised-learning contrastive-learning pytorch-implementation deep-learning representation-learning

simclr-pytorch's Introduction

SimCLR PyTorch

This is an unofficial repository reproducing results of the paper A Simple Framework for Contrastive Learning of Visual Representations. The implementation supports multi-GPU distributed training on several nodes with PyTorch DistributedDataParallel.

How close are we to the original SimCLR?

The implementation closely reproduces the original ResNet50 results on ImageNet and CIFAR-10.

Dataset    Batch size   Epochs   GPUs      Training time   Top-1 linear eval accuracy (100% labels)   Reference
CIFAR-10   1024         1000     2 V100    13h             93.44                                       93.95
ImageNet   512          100      4 V100    85h             60.14                                       60.62
ImageNet   2048         200      16 V100   55h             65.58                                       65.83
ImageNet   2048         600      16 V100   170h            67.84                                       68.71

Pre-trained weights

Try out the pre-trained models: Open In Colab

You can download pre-trained weights from here.

To evaluate the pretrained CIFAR-10 linear model and encoder, use the following command:

python train.py --problem eval --eval_only true --iters 1 --arch linear \
--ckpt pretrained_models/resnet50_cifar10_bs1024_epochs1000_linear.pth.tar \
--encoder_ckpt pretrained_models/resnet50_cifar10_bs1024_epochs1000.pth.tar

To evaluate the pretrained ImageNet linear model and encoder, use the following command:

export IMAGENET_PATH=.../raw-data
python train.py --problem eval --eval_only true --iters 1 --arch linear --data imagenet \
--ckpt pretrained_models/resnet50_imagenet_bs2k_epochs600_linear.pth.tar \
--encoder_ckpt pretrained_models/resnet50_imagenet_bs2k_epochs600.pth.tar
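
To inspect or load a checkpoint directly in Python, a minimal sketch (the 'state_dict' key is an assumption; check the keys of your checkpoint file):

import torch

# Load a pretrained encoder checkpoint on CPU for inspection.
ckpt = torch.load(
    'pretrained_models/resnet50_imagenet_bs2k_epochs600.pth.tar',
    map_location='cpu',
)
# Assumption: the weights are stored under a 'state_dict' key.
state_dict = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
print(list(state_dict)[:5])  # peek at the first few parameter names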

Environment Setup

Create a Python environment with the provided config file and Miniconda:

conda env create -f environment.yml
conda activate simclr_pytorch

export IMAGENET_PATH=... # If you have enough RAM, using /dev/shm usually speeds up data loading
export EXMAN_PATH=... # A path to logs

Training

Model training consists of two steps: (1) self-supervised encoder pretraining and (2) classifier learning with the encoder representations. Both steps are done with the train.py script. To see the help for the sim-clr/eval problem, run: python source/train.py --help --problem sim-clr/eval.

Self-supervised pretraining

CIFAR-10

The config cifar_train_epochs1000_bs1024.yaml contains the parameters to reproduce the CIFAR-10 results. It requires 2 V100 GPUs. The pretraining command is:

python train.py --config configs/cifar_train_epochs1000_bs1024.yaml

ImageNet

The configs imagenet_params_epochs*_bs*.yaml contain the parameters to reproduce the ImageNet results. Depending on the batch size, they require 4 to 16 V100 GPUs. The single-node (4 V100 GPUs) pretraining command is:

python train.py --config configs/imagenet_train_epochs100_bs512.yaml

Logs

The logs and the model will be stored at ./logs/exman-train.py/runs/<experiment-id>/. You can access all the experiments from Python with exman.Index('./logs/exman-train.py').info().
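
For instance, to list the runs programmatically, a minimal sketch (it assumes info() returns a pandas DataFrame, which may differ in your exman version):

import exman

# Index the experiment directory and load per-run metadata.
index = exman.Index('./logs/exman-train.py')
runs = index.info()  # assumed to be a pandas DataFrame with one row per run
print(runs.head())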

See how to work with the logs: Open In Colab

Linear Evaluation

To train a linear classifier on top of the pretrained encoder, run the following command:

python train.py --config configs/cifar_eval.yaml --encoder_ckpt <path-to-encoder>

The above model, pretrained with batch size 1024, reaches 93.5% linear evaluation test accuracy.

Pretraining with DistributedDataParallel

To train a model with a larger batch size on several nodes, set the --dist ddp flag and specify the following parameters:

  • --dist_address: the address and port of the main node in the <address>:<port> format
  • --node_rank: 0 for the main node and 1, ... for the others
  • --world_size: the number of nodes

For example, to train with two nodes you need to run the following command on the main node:

python train.py --config configs/cifar_train_epochs1000_bs1024.yaml --dist ddp --dist_address <address>:<port> --node_rank 0 --world_size 2

and on the second node:

python train.py --config configs/cifar_train_epochs1000_bs1024.yaml --dist ddp --dist_address <address>:<port> --node_rank 1 --world_size 2

ImageNet pretraining on 4 nodes, each with 4 GPUs, looks as follows:

node1: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 0
node2: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 1
node3: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 2
node4: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 3
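
Conceptually, these flags feed torch.distributed process-group initialization. The sketch below is a hypothetical illustration of the mapping (one process per GPU), not the repo's actual code; train.py's rank arithmetic may differ:

import torch.distributed as dist

# Hypothetical mapping of the CLI flags to torch.distributed setup.
def init_distributed(dist_address, node_rank, world_size, gpus_per_node, local_rank):
    # Global rank counts processes across all nodes, one per GPU.
    global_rank = node_rank * gpus_per_node + local_rank
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://' + dist_address,       # e.g. 'tcp://<address>:<port>'
        rank=global_rank,
        world_size=world_size * gpus_per_node,     # total number of processes
    )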

Attribution

Parts of this code are based on the following repositories:

Acknowledgements

  • This work was supported in part through computational resources of HPC facilities at NRU HSE

simclr-pytorch's People

Contributors

andrewatanov, sangh0, senya-ashukha, towzeur


simclr-pytorch's Issues

Unable to download pre-trained weights

Hello,

I am getting this error when trying to download pre-trained weights with
curl -L $(yadisk-direct https://yadi.sk/d/Sg9uSLfLBMCt5g?w=1) -o pretrained_models.zip

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connection.py", line 174, in _new_conn
    conn = connection.create_connection(
  File "/usr/local/lib/python3.10/dist-packages/urllib3/util/connection.py", line 95, in create_connection
    raise err
  File "/usr/local/lib/python3.10/dist-packages/urllib3/util/connection.py", line 85, in create_connection
    sock.connect(sa)
TimeoutError: [Errno 110] Connection timed out

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 714, in urlopen
    httplib_response = self._make_request(
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 403, in _make_request
    self._validate_conn(conn)
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 1053, in _validate_conn
    conn.connect()
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connection.py", line 363, in connect
    self.sock = conn = self._new_conn()
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connection.py", line 179, in _new_conn
    raise ConnectTimeoutError(
urllib3.exceptions.ConnectTimeoutError: (<urllib3.connection.HTTPSConnection object at 0x7fc8dd3de500>, 'Connection to cloud-api.yandex.net timed out. (connect timeout=None)')

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/requests/adapters.py", line 487, in send
    resp = conn.urlopen(
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 798, in urlopen
    retries = retries.increment(
  File "/usr/local/lib/python3.10/dist-packages/urllib3/util/retry.py", line 592, in increment
    raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='cloud-api.yandex.net', port=443): Max retries exceeded with url: /v1/disk/public/resources/download?public_key=https://yadi.sk/d/Sg9uSLfLBMCt5g?w=1 (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fc8dd3de500>, 'Connection to cloud-api.yandex.net timed out. (connect timeout=None)'))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/bin/yadisk-direct", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/wldhx/yadisk_direct/main.py", line 23, in main
    print(*[get_real_direct_link(x) for x in args.sharing_link], sep=args.separator)
  File "/usr/local/lib/python3.10/dist-packages/wldhx/yadisk_direct/main.py", line 23, in <listcomp>
    print(*[get_real_direct_link(x) for x in args.sharing_link], sep=args.separator)
  File "/usr/local/lib/python3.10/dist-packages/wldhx/yadisk_direct/main.py", line 10, in get_real_direct_link
    pk_request = requests.get(API_ENDPOINT.format(sharing_link))
  File "/usr/local/lib/python3.10/dist-packages/requests/api.py", line 73, in get
    return request("get", url, params=params, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/requests/api.py", line 59, in request
    return session.request(method=method, url=url, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/requests/sessions.py", line 587, in request
    resp = self.send(prep, **send_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/requests/sessions.py", line 701, in send
    r = adapter.send(request, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/requests/adapters.py", line 508, in send
    raise ConnectTimeout(e, request=request)
requests.exceptions.ConnectTimeout: HTTPSConnectionPool(host='cloud-api.yandex.net', port=443): Max retries exceeded with url: /v1/disk/public/resources/download?public_key=https://yadi.sk/d/Sg9uSLfLBMCt5g?w=1 (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fc8dd3de500>, 'Connection to cloud-api.yandex.net timed out. (connect timeout=None)'))

I tried printing the response of yadisk-direct alone, and it seems the files have been moved somewhere else (?):

JSON Response: {'message': 'Failed to find the requested resource.', 'description': 'Resource not found.', 'error': 'DiskNotFoundError'}

Is there another way to get the pre-trained model weights? Or could you please help me solve this error?

Thanks!

Use more than 4 GPUs in linear evaluation

Thanks for the good code implementation.
I am using 8 GPUs on 1 node.
There was no problem when I used 8 GPUs in pretraining.
But when I use 8 GPUs in linear evaluation, there is a problem:

TypeError: forward() missing 1 required positional argument: 'x'

How can I solve it?

The problem of the top-1 accuracy

What does "it may be in std" mean? Your implementation in PyTorch is awesome, but I cannot understand the big drop.

What's cifar_head?

Hi, @AndrewAtanov. I'm wondering what cifar_head means, and why conv1 needs to be added. Can you explain? Thanks!

if cifar_head:
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
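
For context: the standard ResNet stem is a 7x7 stride-2 convolution sized for 224x224 ImageNet inputs, and a common CIFAR adaptation, which the snippet above matches, swaps in a 3x3 stride-1 convolution so 32x32 images are not downsampled too aggressively. A minimal comparison sketch (illustrative, not the repo's exact code):

import torch.nn as nn

# Standard ImageNet ResNet stem: aggressive downsampling for 224x224 inputs.
imagenet_stem = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
# CIFAR variant (as in the snippet above): preserves 32x32 spatial resolution.
cifar_stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)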

Learning rate in pretraining on CIFAR10

Hi, thanks for sharing your implementation. It's very helpful.
If I've understood your implementation right, it seems like you're training with the learning rate initially set to 4.0 on CIFAR-10.
May I ask why you set it to 4.0?
It seems like the authors of SimCLR use one of {0.5, 1.0, 1.5} (not sure which one, though), so I'm quite confused.
I would be very thankful if you could explain the reasoning behind choosing 4.0 as your learning rate.

Thanks in advance :)

Training log

Hello.
Thanks for sharing your work!

Could you provide the training logs for ImageNet training?

Checkpoints and linear evaluation

Hi! Thanks for releasing this repo!

I have a question about training and evaluation. When the model is trained (on CIFAR-10), several checkpoints are saved, so which one should I use for linear evaluation: checkpoint-48000 or checkpoint?

Also, am I right that this last column shows the final accuracy of linear evaluation that you report?

[screenshot: simclr_repo_q]

linear fc layer

When I printed the model, it showed that there is an fc layer after avgpool and before the projection. However, in the forward method of the ResNet, I didn't see the fc layer being used. I was wondering where the fc linear layer is used. Thanks!

no image normalization

Hi, thanks for your great work!

I was wondering why there is no channel normalization (transforms.Normalize) for ImageNet and CIFAR?

Reproduced Accuracy

Hi, Andrew,

Thanks so much for sharing your implementation.
Your result is the closest reproduction of SimCLR I've ever seen. I do have a question: how important is it to use L-BFGS logistic regression rather than a normal classifier in evaluation? Did it make a big difference?

Reason for using LARGE_NUMBER

Hello, thanks for the great work! I was wondering what the reason is behind using self.LARGE_NUMBER. I understand that it serves to suppress the logits from self-multiplication, but is it really necessary given that the labels are negative for them anyway?
Thanks!
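
For context, a minimal illustration of the pattern being asked about (not the repo's actual loss code): in an NT-Xent-style loss, subtracting a large constant from the diagonal of the similarity matrix drives the softmax probability of each sample matching itself to near zero:

import torch
import torch.nn.functional as F

# Illustrative only, not the repo's actual code.
LARGE_NUMBER = 1e9
z = F.normalize(torch.randn(8, 128), dim=1)   # 8 normalized embeddings
logits = z @ z.t()                            # cosine similarities; diagonal is 1.0
logits.fill_diagonal_(-LARGE_NUMBER)          # self-similarity -> ~0 softmax weight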

Why no upscale in image augmentations?

The SimCLR paper says:

In this work, we sequentially apply three simple augmentations: random
cropping followed by resize back to the original size, random color distortions, and random Gaussian blur

but it seems like the augmentations used in this repository first do a random crop but do not resize the crop back to the original size afterwards. Why the difference? Am I misunderstanding the SimCLR paper?
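
For reference, the paper's crop-then-resize step is commonly implemented with torchvision's RandomResizedCrop, which crops a random region and resizes it back to a fixed size in a single transform (the size and scale values below are assumptions, not the repo's settings):

import torchvision.transforms as T

# Crops a random area, then resizes back to `size`, matching the paper's
# "random cropping followed by resize back to the original size".
crop_and_resize = T.RandomResizedCrop(size=224, scale=(0.08, 1.0))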

Finetuning

Hi, would you mind sharing some example code to finetune the model on a custom dataset?
