Comments (8)
Running some experiments with this on an internal dataset using the Big Transfer ResNetV2 architecture. One of the other reasons I think GroupNorm might be promising is its transfer-learning performance, as demonstrated in the Big Transfer paper. Even though sparse normalization counteracts some of the distribution shift, there is going to be a higher degree of feature interaction with unmasked input. GroupNorm could be more robust to this than BatchNorm for the pretraining -> fine-tuning transition. If it shows promise, I will do an ImageNet run and post the results here.
from spark.
@csvance Yeah, sync batch norm and a large enough batch size are important for BN stability. We used 32 Tesla A100s with bs=128 per GPU (so the total bs was 4096) most of the time, and didn't use gradient accumulation. I think bs=64 is too small for BN; 4x128=512 should work better.
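For anyone following along: in PyTorch, getting BN statistics computed over all GPUs (rather than per-GPU) is a one-line conversion. A minimal sketch, assuming a standard DDP setup; the `decoder` module here is a hypothetical stand-in, not SparK's actual decoder:

```python
import torch
import torch.nn as nn

# Hypothetical decoder block with BatchNorm (stand-in for the real model).
decoder = nn.Sequential(
    nn.Conv2d(64, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)

# Replace every BatchNorm with SyncBatchNorm so statistics are computed
# across the whole process group. The effective BN batch then becomes
# per-GPU batch * world size (e.g. 128 * 32 = 4096 in the setup above).
decoder = nn.SyncBatchNorm.convert_sync_batchnorm(decoder)
```

Note that `convert_sync_batchnorm` only changes the module types; the cross-GPU statistics kick in once the model runs under `torch.nn.parallel.DistributedDataParallel` with an initialized process group.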
@csvance Very insightful thinking. I've also heard that using a 3D sparse convolutional backbone can lead to insufficient global information interaction in 3D point cloud perception (the interaction actually only occurs within connected components).
Yeah, GroupNorm, LayerNorm, or an attention-like operator can alleviate this problem. It's a promising direction to explore.
Hello, I know the paper says you use a batch size of 4096, but I was curious how many GPUs that was split between? I'm having some stability issues, and I suspect it has to do with the effective batch size for batch norm in the decoder. Previously I was using a batch size of 64 with 64 gradient-accumulation steps on a single RTX 3090 24GB to get 4096. Now I have access to 4x A6000 48GB and am trying batch size 128 + gradient accumulation 8 to get 4096, using sync batch norm the same as the SparK decoder. I'm hoping that a much higher effective batch size for batch norm in the decoder will be the key to stopping training from diverging.
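For reference, the accumulation scheme described above looks roughly like this. A minimal sketch with a toy model; the names and shapes are placeholders, not SparK's actual training code:

```python
import torch
import torch.nn as nn

# Toy model and optimizer; placeholders for the real SparK setup.
model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
accum_steps = 8  # e.g. 4 GPUs * bs 128 * accum 8 = 4096 effective optimizer batch

opt.zero_grad()
for step in range(accum_steps):
    x = torch.randn(128, 8)                       # one micro-batch of bs=128
    loss = model(x).pow(2).mean() / accum_steps   # scale so grads average out
    loss.backward()                               # gradients accumulate in .grad
opt.step()                                        # one update per 8 micro-batches
opt.zero_grad()
```

The catch, which matches what this thread converges on: accumulation only enlarges the batch seen by the *optimizer*. BatchNorm statistics are still computed per micro-batch (times world size under SyncBatchNorm), so accumulation alone does not help BN stability.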
Yeah, I'm definitely seeing a big difference between my new and old setups. There is still some instability with a 4*128 effective batch size for sync batch norm, but things converge much better than before. It looks like BatchNorm + a large batch size is crucial for the decoder here; I have tried the decoder with GroupNorm, and convergence is significantly worse without any improvement in stability.
As an experiment, I'm running with an image size of 128x128 and a batch size of 512 per GPU, giving me a 2048 sync BN batch size (accumulating gradient twice to get 4096 for the optimizer step). It will be interesting to see whether the constant gradient explosion persists. In the loss curve, the divergence pretty much always happens once I reach a certain loss, around 0.3 MSE. It doesn't matter how I tune gradient clipping, learning rate, etc.; it's as if the loss landscape is extremely sharp/unstable without a sufficient batch size for batch norm.
I was able to get SparK to converge with LayerNorm in the decoder instead of BatchNorm! I had forgotten to enable decoupled weight decay in the optimizer I was using, which was the source of the divergences (too much weight decay relative to the learning rate). There are still times during training when the loss spikes a bit, but it's not extreme, and the loss starts decreasing again toward a better minimum.
I have no doubt that BatchNorm will still converge faster, but using LayerNorm in the decoder could be a good option for those who do not have access to a huge number of GPUs.
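In case it's useful to others: since `nn.LayerNorm` normalizes trailing dimensions, using it on NCHW conv features needs a small wrapper that normalizes over the channel dim (the same trick ConvNeXt uses). A minimal sketch; `LayerNorm2d` and the block below are illustrative, not SparK's actual decoder code:

```python
import torch
import torch.nn as nn

class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dim of NCHW feature maps; a batch-size-
    independent drop-in replacement for BatchNorm2d."""
    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.ln = nn.LayerNorm(num_channels, eps=eps)

    def forward(self, x):
        # N,C,H,W -> N,H,W,C, normalize over C, then back to N,C,H,W
        return self.ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

# Swapping the norm when building a decoder block:
block = nn.Sequential(
    nn.Conv2d(64, 64, 3, padding=1),
    LayerNorm2d(64),   # instead of nn.BatchNorm2d(64)
    nn.GELU(),
)
```

Because the statistics are computed per sample rather than per batch, this sidesteps the effective-BN-batch-size issue entirely, at the possible cost of the slower convergence noted above.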
@csvance Happy to hear that, and thanks for your effort! Substituting BN with LN or GN (GroupNorm) is indeed a valuable experiment, and I suspect BN isn't always essential. We initially adopted BN just because UNet used it, but I believe LN or GN could replace BN without much performance drop, and yes, this could be particularly beneficial for those with limited GPU resources.
For using SparK with backscatter X-ray images, I found it helpful to use a larger epsilon for tile normalization and also to normalize the x_hat tiles. The reason is that many tiles are mostly background, since a lot of X-rays are taller than they are wide and often have large segments of noisy background. This made the learned representation transfer better to downstream problems. Without the large epsilon, training is unstable at the start when normalizing x_hat tiles, which seems to negatively impact the learned representation. I suspect normalizing x_hat is a useful inductive bias, but I haven't tried any of this with ImageNet yet.
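A minimal sketch of what I mean by per-tile normalization with a larger epsilon; the function name, shapes, and the epsilon value are illustrative, not the exact values I used:

```python
import torch

def norm_tiles(tiles, eps=1e-2):
    """Normalize each tile (patch) to zero mean / unit variance.
    A larger eps than the usual 1e-6 keeps near-constant background
    tiles (common in tall X-ray scans) from producing huge values."""
    mean = tiles.mean(dim=-1, keepdim=True)
    var = tiles.var(dim=-1, keepdim=True)
    return (tiles - mean) / (var + eps).sqrt()

# tiles: (batch, num_tiles, pixels_per_tile); normalize target AND x_hat
B, N, P = 2, 16, 256
target = norm_tiles(torch.randn(B, N, P))
x_hat = norm_tiles(torch.randn(B, N, P))  # normalizing predictions too
loss = (x_hat - target).pow(2).mean()
```

With eps this large, an all-background (near-constant) tile normalizes to roughly zero instead of amplified noise, which is where the early-training stability comes from.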
Until now I have been working with a relatively small subset of my dataset, roughly ~100k images. I'm going to scale things up by several orders of magnitude now. Results on downstream tasks are very promising even with so few images; downstream performance is already close to ImageNet-21K transfer performance.
Related Issues (20)
- Comparison with ConvNeXt V2 HOT 1
- reducing pre-training to 200 epochs HOT 9
- Tutorial for finetune on my own dataset HOT 1
- Are there any plans to make a port to tensorflow and Keras? HOT 1
- ImageNet finetuning exploding HOT 9
- there is no requirements.txt file. HOT 1
- SparK for semantic segmentation HOT 3
- Resuming ImageNet fine-tuning HOT 2
- About sparse convolution HOT 4
- How to transfer this method to 3D situation. HOT 1
- ConvNext B for reconstruct images HOT 3
- recommend a great library designed for sparse tensors HOT 1
- Can SparK be used for few-shot learning? HOT 2
- SparseBatchNorm2d can not mask correctly ? HOT 3
- A Code Issue About “pretrain/main.py” HOT 2
- ConvNext implementation performance HOT 4
- Increasing batch size HOT 1
- Necessity of Mask Tokens
- The versions of mmdet and mmcv?