andrewatanov / simclr-pytorch
PyTorch implementation of SimCLR: supports multi-GPU training and closely reproduces results
License: MIT License
The SimCLR paper says:
In this work, we sequentially apply three simple augmentations: random
cropping followed by resize back to the original size, random color distortions, and random Gaussian blur
but the augmentations in this repository seem to do a random crop without resizing the crop back to the original size afterwards. Why the difference? Am I misunderstanding the SimCLR paper?
Hello.
Thanks for sharing your work!
Could you provide the training logs for ImageNet training?
Hello, thanks for the great work! I was wondering what the reason is behind using self.LARGE_NUMBER. I understand that it serves to suppress the logits from self-multiplication, but is it really necessary, given that those entries are never labeled as positives anyway?
Thanks!
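A minimal sketch of the reasoning (not the repo's exact code): negative labels alone are not enough, because cross-entropy's softmax denominator still sums over all logits. The self-similarity logit z_i·z_i/τ = 1/τ is the largest possible value, so leaving it in inflates the denominator and weakens the contribution of the real negatives. Subtracting a large constant drives its exp to ~0, effectively removing it:

```python
import torch

# Illustrative NT-Xent-style logit masking; tau and the batch size are made up.
LARGE_NUMBER = 1e9
tau = 0.5
z = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)  # unit-norm embeddings
logits = z @ z.t() / tau              # pairwise similarities; diagonal is 1/tau, the max
logits.fill_diagonal_(-LARGE_NUMBER)  # suppress self-similarity
probs = logits.softmax(dim=1)         # diagonal now contributes ~0 probability mass
```

Without the masking, each row's denominator would be dominated by exp(1/τ) from the self pair, shrinking every other probability.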
What does "it may be in std" mean? Your PyTorch implementation is awesome, but I cannot understand this big drop.
Thanks for the good code implementation.
I am using 8 GPUs in 1 node.
There was no problem when I used 8 GPUs in pretraining.
But when I use 8 GPUs in linear evaluation, there is a problem:
TypeError: forward() missing 1 required positional argument: 'x'
How can I solve it?
Hi, @AndrewAtanov. I'm wondering what cifar_head means, and why conv1 needs to be added. Can you explain? Thanks!
if cifar_head:
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
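For context, a common reason for such a CIFAR stem (an explanation of the pattern, not a statement about this repo's exact motivation): the standard ImageNet ResNet stem (7x7 stride-2 conv followed by a stride-2 max-pool) downsamples a 32x32 CIFAR image to 8x8 before the first residual block, while a 3x3 stride-1 conv keeps the full resolution:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)  # a CIFAR-sized input

# Standard ImageNet ResNet stem: aggressive early downsampling.
imagenet_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
# CIFAR stem: same conv as in the snippet above, resolution preserved.
cifar_stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

print(imagenet_stem(x).shape)  # torch.Size([1, 64, 8, 8])
print(cifar_stem(x).shape)     # torch.Size([1, 64, 32, 32])
```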
Hi, would you mind sharing some example code to finetune the model on a custom dataset?
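In the absence of an official example, here is a generic sketch of training a head on top of a frozen encoder. Everything here is a placeholder (the stand-in encoder, `feat_dim`, the random tensors): substitute the repo's loaded encoder and a DataLoader over your custom dataset. Leaving the encoder unfrozen would give full finetuning instead of linear evaluation.

```python
import torch
import torch.nn as nn

# Placeholder dimensions and a stand-in encoder; replace with the real one.
feat_dim, num_classes = 2048, 10
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, feat_dim))
for p in encoder.parameters():
    p.requires_grad = False  # freeze for linear evaluation

head = nn.Linear(feat_dim, num_classes)
opt = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# One fake training step; in practice, loop over your DataLoader.
x = torch.randn(16, 3, 32, 32)
y = torch.randint(0, num_classes, (16,))
with torch.no_grad():
    feats = encoder(x)  # frozen features
loss = criterion(head(feats), y)
opt.zero_grad()
loss.backward()
opt.step()
```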
Hi! Thanks for releasing this repo!
I have a question about training and evaluation. When the model is trained (on CIFAR-10), several checkpoints are saved, so which one should I use for linear evaluation: checkpoint-48000 or checkpoint?
Also, am I right that the last column shows the final linear-evaluation accuracy that you report?
Hi
Thanks for your nice work. Could you explain the purpose of adding BatchNorm1dNoBias after the last linear layer of the projection head (https://github.com/AndrewAtanov/simclr-pytorch/blob/master/models/encoder.py#L40)?
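One common way to implement such a layer (a guess at what BatchNorm1dNoBias does, not necessarily the repo's code) is a BatchNorm1d whose learnable shift β is frozen at its zero initialization, so the projection output stays zero-mean per feature over the batch:

```python
import torch
import torch.nn as nn

# Hypothetical reconstruction: BatchNorm1d with the learnable shift (beta)
# disabled, so the layer only centers and rescales each feature.
class BatchNorm1dNoBias(nn.BatchNorm1d):
    def __init__(self, num_features):
        super().__init__(num_features)
        self.bias.requires_grad = False  # freeze beta at its zero init

bn = BatchNorm1dNoBias(128)
out = bn(torch.randn(32, 128))  # per-feature mean of `out` is ~0 in train mode
```

A shift after the final linear layer would also be redundant with that linear layer's own bias, which is one reason to drop it.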
Hi Andrew,
Thanks so much for sharing your implementation.
Your result is the closest to the paper's that I've ever seen for SimCLR. I do have a question: how important is it to use L-BFGS logistic regression rather than a normal classifier for evaluation? Did it make a big difference?
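For context, L-BFGS logistic regression on frozen features (e.g. via scikit-learn's default solver) converges deterministically to the regularized optimum, whereas an SGD-trained linear head depends on the learning-rate schedule. A sketch with made-up features, not the repo's evaluation code:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Fake "frozen encoder" features; in practice these come from the encoder.
rng = np.random.default_rng(0)
feats = rng.normal(size=(200, 64))
labels = rng.integers(0, 10, size=200)

# lbfgs is scikit-learn's default solver for LogisticRegression.
clf = LogisticRegression(solver="lbfgs", C=1.0, max_iter=1000)
clf.fit(feats, labels)
acc = clf.score(feats, labels)  # train accuracy on the fake features
```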
When I printed the model, it showed an fc layer after avgpool and before the projection head. However, in the forward method of the ResNet I don't see the fc layer being used. I was wondering where the fc linear layer is used. Thanks!
Hi, thanks for your great work!
I was wondering why there is no channel normalization (transforms.Normalize) for ImageNet and CIFAR?
Hello,
I am getting this error when trying to download pre-trained weights with
curl -L $(yadisk-direct https://yadi.sk/d/Sg9uSLfLBMCt5g?w=1) -o pretrained_models.zip
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/urllib3/connection.py", line 174, in _new_conn
conn = connection.create_connection(
File "/usr/local/lib/python3.10/dist-packages/urllib3/util/connection.py", line 95, in create_connection
raise err
File "/usr/local/lib/python3.10/dist-packages/urllib3/util/connection.py", line 85, in create_connection
sock.connect(sa)
TimeoutError: [Errno 110] Connection timed out
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 714, in urlopen
httplib_response = self._make_request(
File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 403, in _make_request
self._validate_conn(conn)
File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 1053, in _validate_conn
conn.connect()
File "/usr/local/lib/python3.10/dist-packages/urllib3/connection.py", line 363, in connect
self.sock = conn = self._new_conn()
File "/usr/local/lib/python3.10/dist-packages/urllib3/connection.py", line 179, in _new_conn
raise ConnectTimeoutError(
urllib3.exceptions.ConnectTimeoutError: (<urllib3.connection.HTTPSConnection object at 0x7fc8dd3de500>, 'Connection to cloud-api.yandex.net timed out. (connect timeout=None)')
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/requests/adapters.py", line 487, in send
resp = conn.urlopen(
File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 798, in urlopen
retries = retries.increment(
File "/usr/local/lib/python3.10/dist-packages/urllib3/util/retry.py", line 592, in increment
raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='cloud-api.yandex.net', port=443): Max retries exceeded with url: /v1/disk/public/resources/download?public_key=https://yadi.sk/d/Sg9uSLfLBMCt5g?w=1 (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fc8dd3de500>, 'Connection to cloud-api.yandex.net timed out. (connect timeout=None)'))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/bin/yadisk-direct", line 8, in <module>
sys.exit(main())
File "/usr/local/lib/python3.10/dist-packages/wldhx/yadisk_direct/main.py", line 23, in main
print(*[get_real_direct_link(x) for x in args.sharing_link], sep=args.separator)
File "/usr/local/lib/python3.10/dist-packages/wldhx/yadisk_direct/main.py", line 23, in <listcomp>
print(*[get_real_direct_link(x) for x in args.sharing_link], sep=args.separator)
File "/usr/local/lib/python3.10/dist-packages/wldhx/yadisk_direct/main.py", line 10, in get_real_direct_link
pk_request = requests.get(API_ENDPOINT.format(sharing_link))
File "/usr/local/lib/python3.10/dist-packages/requests/api.py", line 73, in get
return request("get", url, params=params, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/requests/api.py", line 59, in request
return session.request(method=method, url=url, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/requests/sessions.py", line 587, in request
resp = self.send(prep, **send_kwargs)
File "/usr/local/lib/python3.10/dist-packages/requests/sessions.py", line 701, in send
r = adapter.send(request, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/requests/adapters.py", line 508, in send
raise ConnectTimeout(e, request=request)
requests.exceptions.ConnectTimeout: HTTPSConnectionPool(host='cloud-api.yandex.net', port=443): Max retries exceeded with url: /v1/disk/public/resources/download?public_key=https://yadi.sk/d/Sg9uSLfLBMCt5g?w=1 (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fc8dd3de500>, 'Connection to cloud-api.yandex.net timed out. (connect timeout=None)'))
I tried printing the response of yadisk-direct alone, and it seems the files have been moved somewhere else (?):
JSON Response: {'message': 'Не удалось найти запрошенный ресурс.', 'description': 'Resource not found.', 'error': 'DiskNotFoundError'}
Is there another way to get the pre-trained model weights? or could you please help me solve this error?
Thanks!
Hi, thanks for sharing your implementation. It's very helpful.
If I've understood your implementation correctly, you train with the learning rate initially set to 4.0 on CIFAR-10.
May I ask why you set it to 4.0?
The SimCLR authors seem to use one of {0.5, 1.0, 1.5} (I'm not sure which), so I'm quite confused.
I would be very thankful if you could explain the reasoning behind choosing 4.0 as your learning rate.
Thanks in advance :)
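One possible explanation (an assumption, not confirmed by this repo): large-batch contrastive training commonly scales a base learning rate with batch size via the linear scaling rule, lr = base_lr * batch_size / 256, so 4.0 could correspond to one of the base rates mentioned above at a large batch size:

```python
# Linear scaling rule (assumed, not confirmed for this repo):
#   lr = base_lr * batch_size / 256
base_lr = 0.5      # one of the base rates mentioned in the question
batch_size = 2048  # hypothetical batch size
lr = base_lr * batch_size / 256
print(lr)  # 4.0
```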