
Comments (6)

jzbontar commented on July 3, 2024

I trained BT on ImageNet training set (1.28M images) and computed the BT loss on the ImageNet validation set (50k images) after each training epoch. I used model.train() for both training and validation, because I had some trouble getting track_running_stats=False to work. I trained for 40 epochs using the following command:

python main.py <imagenet_path> --epochs 40 --batch-size 1024 --learning-rate-weights 0.2 --learning-rate-biases 0.005 --weight-decay 1e-06 --lambd 0.005
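For reference, the per-epoch validation pass can look roughly like the sketch below (not the exact script; it assumes the repo's BarlowTwins module, whose forward takes the two augmented views and returns the loss, and a hypothetical val_loader that yields two views per image):

import torch

def validate_bt_loss(model, val_loader, device):
    # Hypothetical helper: average the BT loss over the validation set once
    # per epoch. model.train() keeps BatchNorm in batch-statistics mode, as
    # described above; torch.no_grad() avoids building gradient graphs.
    model.train()
    total, batches = 0.0, 0
    with torch.no_grad():
        for (y1, y2), _ in val_loader:
            loss = model(y1.to(device), y2.to(device))
            total += loss.item()
            batches += 1
    return total / max(batches, 1)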

In my experiments the validation loss does not show the behavior that you observed, but rather follows the training loss pretty well. See figure below:

[Figure: BT train and validation loss curves over 40 epochs, tracking each other closely]

My best guess is that your training set is small and that BT overfits it.

In your view, is it okay for BT (pre-)training to overfit on the train set or should we stop BT pre-training when BT val loss starts diverging?

I guess it all depends on how you intend to use the representations learnt by BT. If you plan to use it for a downstream task, I would measure the accuracy on that task and stop training based on that.


jzbontar commented on July 3, 2024

We currently don't log the BT loss on the validation set. In the original experiments we trained a linear classifier in parallel with BT training and tracked its train and val top-1 accuracy.
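In outline, such an online probe can look like the following sketch (hypothetical names; the 2048-d feature size matches ResNet-50, and the probe sees detached features so its gradients never reach the BT objective):

import torch
import torch.nn as nn

classifier = nn.Linear(2048, 1000)  # ResNet-50 features -> ImageNet classes
probe_opt = torch.optim.SGD(classifier.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

def probe_step(backbone, x, labels):
    # Train the linear classifier on frozen (detached) backbone features,
    # tracking its loss and top-1 accuracy alongside BT training.
    with torch.no_grad():  # detach features: the probe must not affect BT
        feats = backbone(x)
    logits = classifier(feats)
    loss = criterion(logits, labels)
    probe_opt.zero_grad()
    loss.backward()
    probe_opt.step()
    top1 = (logits.argmax(dim=1) == labels).float().mean().item()
    return loss.item(), top1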

I don't know why the train and val BT loss would differ so much. I can run this on ImageNet and report back.


neerajwagh commented on July 3, 2024

Hi @jzbontar, thanks for clarifying your val setup.

It'd be great if you could run a quick experiment (possibly on 10% of the training data, with a smaller embedding size) to check the BT loss value on the val set in parallel with BT training. When you try this, you may want to set track_running_stats=False at

self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
(It makes sense for the val BT loss to use batch statistics computed from the val batch instead of using the train batch statistics, right? From what I understand, the last BN layer only exists to make the loss calculation simpler.) This change seems to yield a smoother val loss curve, but the gap between train and val remains unresolved. I'm still trying to debug that part.
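Concretely, the change I'm suggesting is just one extra argument on that layer (a sketch; the projector widths are the example values from the paper):

import torch.nn as nn

sizes = [2048, 8192, 8192, 8192]  # example projector widths from the paper

# With track_running_stats=False, BatchNorm1d always normalizes with the
# current batch's statistics, so a validation batch is normalized with its
# own mean/variance rather than running estimates accumulated in training.
bn = nn.BatchNorm1d(sizes[-1], affine=False, track_running_stats=False)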

In your view, is it okay for BT (pre-)training to overfit on the train set or should we stop BT pre-training when BT val loss starts diverging?


neerajwagh commented on July 3, 2024

Oh, that looks neat!

I'm aiming for task-agnostic and dataset-agnostic embeddings, so I can't afford a backbone that overfits my training set. Even if the downstream task is fixed to a specific one, I'm still worried the embeddings may not transfer to the same task on a different downstream dataset. So I guess I need to continue debugging until I see healthy loss curves.

You're right, my dataset is much smaller: train #samples: 140k, val #samples: 54k.

In your val loop for this particular experiment, do you have the same augmentation process as in your training loop?


jzbontar commented on July 3, 2024

In your val loop for this particular experiment, do you have the same augmentation process as in your training loop?

Yes, the training and validation datasets have the same augmentation process.
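In code, that amounts to reusing the training transform when building the validation set, e.g. (a sketch assuming the Transform class from the repo's main.py, which returns two augmented views per image; the batch size and path are placeholders):

import torch
import torchvision.datasets as datasets
from main import Transform  # the repo's two-view augmentation pipeline

# Build the validation set with the *same* two-view augmentation used for
# training, so the train and val BT losses are directly comparable.
val_dataset = datasets.ImageFolder('<imagenet_path>/val', Transform())
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1024, shuffle=False, num_workers=8)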


HareshKarnan commented on July 3, 2024

@neerajwagh were you able to get the BT loss to decrease on the validation set?

