Comments (8)

csvance commented on June 11, 2024

Running some experiments with this on an internal dataset using the Big Transfer ResNetV2 architecture. One of the other reasons I think GroupNorm might be promising is its transfer learning performance, as demonstrated in the Big Transfer paper. Even though sparse normalization counteracts some of the distribution shift, there is going to be a higher degree of feature interaction with unmasked input. GroupNorm could possibly be more robust to this than BatchNorm for pretraining -> training. If it shows promise I will do an ImageNet run and post the results here.

from spark.

keyu-tian commented on June 11, 2024

@csvance Yeah, sync batch norm and a big enough batch size are important for BN stability. We used 32 Tesla A100s with bs=128 per GPU (so a total bs of 4096) most of the time, and didn't use gradient accumulation. I think bs=64 is too small for BN, and 4x128=512 should be better.
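For anyone setting this up: PyTorch's `SyncBatchNorm.convert_sync_batchnorm` is the usual way to make BN pool its statistics across GPUs under DDP, so that 32 GPUs x bs=128 gives BN an effective batch of 4096 rather than 128. A minimal sketch below; the toy decoder stage is a hypothetical stand-in, not the actual SparK decoder architecture.

```python
import torch
import torch.nn as nn

# Toy decoder stage with BatchNorm (hypothetical stand-in for the SparK decoder).
decoder = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

# Replace every BatchNorm layer with SyncBatchNorm. Under DDP, mean/var are
# then computed across all processes, so per-GPU bs=128 on 32 GPUs gives BN
# an effective batch of 4096. (Without an initialized process group it just
# behaves like regular BatchNorm.)
decoder = nn.SyncBatchNorm.convert_sync_batchnorm(decoder)
```
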

keyu-tian commented on June 11, 2024

@csvance Very insightful thinking. I've also heard before that using a 3D sparse convolutional backbone can lead to insufficient global information interaction in 3D point cloud perception (the interaction actually only occurs within connected components).

Yeah, GroupNorm, LayerNorm, or some attention-like operator can alleviate this problem. It's a promising direction to explore.

csvance commented on June 11, 2024

Hello, I know the paper says you use a batch size of 4096, but I was curious how many GPUs that was split between. I'm having some stability issues, and I suspect they have to do with the effective batch size for batch norm in the decoder. Previously I was using a batch size of 64 and accumulating gradients 64 times on a single RTX 3090 24GB to reach 4096. Now I have access to 4x A6000 48GB and am trying batch size 128 + gradient accumulation 8 to reach 4096, using sync batch norm the same as the SparK decoder. I'm hoping that a much higher effective batch size for batch norm in the decoder will be the key to stopping training from diverging.
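The distinction driving the instability above is worth making concrete: gradient accumulation restores the *optimizer's* effective batch size, but BatchNorm statistics are still computed per micro-batch, so accumulation does not help BN at all. A minimal sketch of the accumulation pattern, with a toy model and made-up sizes standing in for the real setup:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model; hypothetical stand-in for the actual pretraining model.
model = nn.Linear(16, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

accum_steps = 8   # e.g. 4 GPUs x bs=128 x accum 8 -> optimizer batch of 4096
micro_bs = 4      # tiny here, just for illustration

opt.zero_grad()
for _ in range(accum_steps):
    x = torch.randn(micro_bs, 16)
    y = torch.randn(micro_bs, 1)
    # Scale the loss so the accumulated gradient averages over the full batch.
    loss = loss_fn(model(x), y) / accum_steps
    loss.backward()
# One optimizer step for the whole accumulated batch. Note: if the model
# contained BatchNorm, its statistics would still be computed over micro_bs
# samples per forward pass -- accumulation does not raise BN's effective batch.
opt.step()
opt.zero_grad()
```
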

csvance commented on June 11, 2024

Yeah, I'm definitely seeing a big difference between my new and old setups. There is still some instability with a 4x128 effective batch size for sync batch norm, but things converge much better than I have seen before. It looks like BatchNorm + a large batch size is crucial for the decoder here; I have tried a decoder with GroupNorm and convergence is significantly worse, without any improvement in stability.

Just as an experiment, I'm running with an image size of 128x128 and a batch size of 512 per GPU, giving me a 2048 sync BN batch size (accumulating gradients twice to get 4096 for the optimizer step). It will be interesting to see if there are still issues with constant gradient explosion. Here is what the divergence looks like in the loss curve; it pretty much always happens when I reach a certain loss, around 0.3 MSE or so. It doesn't matter even if I fine-tune gradient clipping, learning rate, etc.; it's like the loss landscape is extremely sharp/unstable without a sufficient batch size for batch norm.

[image: loss curve showing the divergence]

csvance commented on June 11, 2024

I was able to get SparK to converge with LayerNorm in the decoder instead of BatchNorm! I had forgotten to enable decoupled weight decay in the optimizer I was using, which was the source of the divergences (too much weight decay relative to the learning rate). There are still times during training where the loss spikes a bit, but it's not extreme and the loss starts to decrease again toward a better minimum.

I have no doubt that BatchNorm will still converge faster, but using LayerNorm in the decoder could be a good option for those who do not have access to a huge number of GPUs.

keyu-tian commented on June 11, 2024

@csvance Happy to hear that, and thanks for your effort. Substituting BN with LN or GN (GroupNorm) is indeed worth trying, and I guess BN isn't always essential. We initially adopted BN just because UNet used it, but I believe LN or GN could effectively replace BN without much of a performance drop, and yes, this could be particularly beneficial for those with limited GPU resources.

csvance commented on June 11, 2024

For using SparK with backscatter X-ray images, I found it was good to use a larger epsilon for tile normalization and to also normalize the x_hat tiles. The reason is that many tiles are mostly background, since a lot of X-rays are taller than they are wide and often have large segments of noisy background. This made the learned representation transfer better to downstream problems. Without the large epsilon, training is unstable at the start when normalizing x_hat tiles, which seems to negatively impact the learned representation. I suspect normalizing x_hat is a useful inductive bias, but I haven't tried any of this with ImageNet yet.
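The epsilon effect above is easy to see numerically: with per-tile normalization `(x - mean) / sqrt(var + eps)`, a near-constant background tile has tiny variance, so a small eps amplifies its noise to unit scale, while a larger eps damps it. A NumPy sketch with made-up tile data (the function and values are illustrative, not SparK's exact implementation):

```python
import numpy as np

def normalize_tiles(tiles, eps):
    # tiles: (N, P) array of flattened image tiles; normalize each tile to
    # zero mean / unit variance, as in per-tile target normalization.
    mean = tiles.mean(axis=1, keepdims=True)
    var = tiles.var(axis=1, keepdims=True)
    return (tiles - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
# One "content" tile and one near-constant noisy-background tile.
content = rng.normal(0.0, 1.0, size=(1, 256))
background = 0.01 * rng.normal(0.0, 1.0, size=(1, 256))
tiles = np.concatenate([content, background])

small_eps = normalize_tiles(tiles, eps=1e-6)
large_eps = normalize_tiles(tiles, eps=1e-2)

# With tiny eps, the background tile's noise gets blown up to unit variance;
# the larger eps keeps its normalized values small, so the target (and a
# normalized x_hat) is less dominated by background noise.
print(np.abs(small_eps[1]).max(), np.abs(large_eps[1]).max())
```
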

Until now I have been working with a relatively small subset of my dataset, roughly ~100k images. I'm going to scale things up by several orders of magnitude now. Results on downstream tasks are very promising even with so few images; downstream performance is already close to ImageNet-21K transfer performance.

