Comments (6)
I trained BT on the ImageNet training set (1.28M images) and computed the BT loss on the ImageNet validation set (50k images) after each training epoch. I used model.train() for both training and validation because I had some trouble getting track_running_stats=False to work. I trained for 40 epochs using the following command:
python main.py <imagenet_path> --epochs 40 --batch-size 1024 --learning-rate-weights 0.2 --learning-rate-biases 0.005 --weight-decay 1e-06 --lambd 0.005
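For context, the quantity being tracked here is the Barlow Twins objective: the cross-correlation matrix between the standardized embeddings of two augmented views of a batch, pushed toward the identity. A minimal NumPy sketch of the loss (names are illustrative; the repo's implementation standardizes through a BatchNorm layer and all-gathers the correlation matrix across GPUs):

```python
import numpy as np

def barlow_twins_loss(z1, z2, lambd=0.005):
    """Barlow Twins objective: drive the cross-correlation matrix of
    two views' embeddings toward the identity matrix."""
    # Standardize each embedding dimension over the batch.
    z1 = (z1 - z1.mean(0)) / z1.std(0)
    z2 = (z2 - z2.mean(0)) / z2.std(0)
    n = z1.shape[0]
    c = z1.T @ z2 / n  # (d, d) cross-correlation matrix
    on_diag = ((np.diag(c) - 1.0) ** 2).sum()            # invariance term
    off_diag = (c ** 2).sum() - (np.diag(c) ** 2).sum()  # redundancy-reduction term
    return on_diag + lambd * off_diag
```

Nothing in the loss depends on labels, which is why it can be evaluated on a held-out set in exactly the same way as on the training set.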
In my experiments the validation loss does not show the behavior that you observed, but rather follows the training loss pretty well. See figure below:
My best guess is that your training set is small and that BT overfits it.
In your view, is it okay for BT (pre-)training to overfit on the train set or should we stop BT pre-training when BT val loss starts diverging?
I guess it all depends on how you intend to use the representations learnt by BT. If you plan to use it for a downstream task, I would measure the accuracy on that task and stop training based on that.
from barlowtwins.
We currently don't log the BT loss on the validation set. In the original experiments we trained a linear classifier in parallel with BT training and tracked its train and val top-1 accuracy.
I don't know why the train and val BT loss would differ so much. I can run this on ImageNet and report back.
Hi @jzbontar, thanks for clarifying your val setup.
It'd be great if you could run a quick experiment (possibly on 10% of the training data, with a smaller embedding size) to check the BT loss on the val set in parallel with BT training. When you try this, you may want to set track_running_stats=False at
Line 205 in 574f589
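One way to get that behavior on an already-built model (a sketch, not the repo's code; the helper name is mine) is to flip the flag and clear the running-stat buffers on every BatchNorm layer, so eval mode falls back to batch statistics:

```python
import torch
from torch import nn

def use_batch_stats(model: nn.Module) -> nn.Module:
    """Force every BatchNorm layer to normalize with current batch
    statistics even in eval mode, as if it had been constructed with
    track_running_stats=False."""
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.track_running_stats = False
            # Clearing the buffers is what actually makes eval mode
            # use batch statistics instead of stale running averages.
            m.running_mean = None
            m.running_var = None
            m.num_batches_tracked = None
    return model

net = use_batch_stats(nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8)))
net.eval()
with torch.no_grad():
    out = net(torch.randn(32, 8))
# Each feature of `out` is now standardized over the batch.
```

Setting the attribute alone is not enough: with the buffers still populated, eval mode keeps using them, which is likely the trouble mentioned above.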
In your view, is it okay for BT (pre-)training to overfit on the train set or should we stop BT pre-training when BT val loss starts diverging?
Oh, that looks neat!
I'm aiming for task-agnostic and dataset-agnostic embeddings, so I cannot have a backbone that overfits my training set. Even if the downstream task is fixed to a specific one, I'm still worried the embeddings may not work for the same task done using a different downstream dataset. So I guess I need to continue debugging until I see healthy loss curves.
You're right, my dataset is much smaller: train #samples: 140k, val #samples: 54k.
In your val loop for this particular experiment, do you have the same augmentation process as in your training loop?
In your val loop for this particular experiment, do you have the same augmentation process as in your training loop?
Yes, the training and validation datasets have the same augmentation process.
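For concreteness, the usual pattern is to wrap the stochastic augmentation pipeline so that both the train and val loaders yield two independently distorted views per sample. A toy sketch (the wrapper name and the scalar "augmentation" are illustrative stand-ins for the repo's image transforms):

```python
import random

class TwoViewTransform:
    """Apply the same stochastic augmentation pipeline twice, producing
    two independently distorted views of one input. Reusing this wrapper
    in the validation DataLoader gives the augmentation parity discussed
    above."""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return self.transform(x), self.transform(x)

# Toy "augmentation": add uniform jitter to a number.
rng = random.Random(0)
jitter = lambda v: v + rng.uniform(-0.1, 0.1)
two_views = TwoViewTransform(jitter)
v1, v2 = two_views(1.0)
```

Because the transform is sampled independently per call, the two views differ even though they come from the same input.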
@neerajwagh were you able to get the BT loss to decrease on the validation set?