Comments (4)
Or even better, why not compute the full cross-correlation matrix (i.e. gather all embedding vectors onto one device and compute the cross-correlation there)?
Summing the per-device cross-correlation matrices (like we do in our code) is equivalent to computing the full cross-correlation matrix by gathering all embedding vectors onto one device. Both give exactly the same result.
Think about how to distribute a dot product operation across n machines (computing the cross correlation matrix is basically just a bunch of dot products, one for each pair of features). You could split the vectors into n chunks, compute n smaller dot products (one dot product for each of the n chunks) and sum them to get the final result. Or if you prefer code:
>>> import torch
>>> x = torch.Tensor(8).normal_()
>>> y = torch.Tensor(8).normal_()
>>> torch.allclose(x[:4] @ y[:4] + x[4:] @ y[4:], x @ y)
True
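The same argument carries over from a single dot product to the whole matrix. Below is a minimal sketch, not the repository's code: the sizes and the names `z1`, `z2`, `world_size` are made up for illustration, the two "devices" are simulated by chunking one batch, and a plain Python `sum` stands in for `torch.distributed.all_reduce`. It also assumes the embeddings are already normalized with statistics computed over the whole batch; if each device normalized with its own local statistics, the two results would generally differ.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: 8 samples in the global batch, 5 features, 2 simulated devices.
batch_size, dim, world_size = 8, 5, 2
z1 = torch.randn(batch_size, dim, dtype=torch.float64)
z2 = torch.randn(batch_size, dim, dtype=torch.float64)

# Full matrix, as if all embeddings had been gathered onto one device.
c_full = z1.T @ z2 / batch_size

# One matrix per simulated device, each divided by the GLOBAL batch size,
# then summed (the sum plays the role of all_reduce).
c_sum = sum(
    a.T @ b / batch_size
    for a, b in zip(z1.chunk(world_size), z2.chunk(world_size))
)

matrices_match = torch.allclose(c_full, c_sum)
```

Each entry of the matrix is a dot product over the batch dimension, so splitting the batch and summing the partial results recovers the full matrix.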
from barlowtwins.
Thanks for the explanation. It's clear to me, now.
Hi, how do we do this on a single GPU? torch.distributed does not seem to work on a single GPU. I use torch.nn.DataParallel to run the code on a single GPU.
# empirical cross-correlation matrix
c = self.bn(z1).T @ self.bn(z2)
# sum the cross-correlation matrix between all gpus
c.div_(self.args.batch_size)
torch.distributed.all_reduce(c)
on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
off_diag = off_diagonal(c).pow_(2).sum()
loss = on_diag + self.args.lambd * off_diag
Hi @ltnghia.
Is the problem solved? I think the following code may not be needed on a single GPU, right?
# sum the cross-correlation matrix between all gpus
c.div_(self.args.batch_size)
torch.distributed.all_reduce(c)
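One possible way to handle the single-GPU case, offered as a sketch rather than the repository's implementation: keep the division by the batch size (the normalization is still needed) and skip only the `all_reduce` when no process group has been initialized. The function name `barlow_twins_loss` and its signature are hypothetical, `z1` and `z2` are assumed to be already batch-normalized, `batch_size` is assumed to be the total batch size across all devices, and `off_diagonal` is written here as an assumed helper that returns the off-diagonal elements of a square matrix.

```python
import torch
import torch.distributed as dist


def off_diagonal(x):
    # Assumed helper: return a flattened view of the off-diagonal elements
    # of a square matrix.
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


def barlow_twins_loss(z1, z2, batch_size, lambd):
    # Hypothetical wrapper around the loss lines quoted above.
    c = z1.T @ z2
    c.div_(batch_size)  # still needed on a single GPU

    # Only sum across processes if a process group actually exists.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(c)

    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    off_diag = off_diagonal(c).pow(2).sum()
    return on_diag + lambd * off_diag


# Toy check: these embeddings give an identity cross-correlation matrix,
# so the loss should be approximately zero.
z = torch.eye(3) * (3 ** 0.5)
loss = barlow_twins_loss(z, z.clone(), batch_size=3, lambd=0.005)
```

With this guard the same function runs unchanged with or without `torch.distributed` being set up.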