
Comments (16)

wmabebe avatar wmabebe commented on May 29, 2024 2

from pytorch-model-compare.

zjcqn avatar zjcqn commented on May 29, 2024 1

I found a workaround for this problem using the model_layers arguments (model1_layers and model2_layers). I temporarily removed the assert statement and plotted the figure, which showed that only some layers had NaN values. I then excluded those layers from the computation.
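A minimal sketch of how the affected layers could be picked out, assuming the matrix has been converted to nested Python lists (for example via `.tolist()` on the `hsic_matrix` tensor that appears in the traceback in this thread). The function name `split_nan_layers` and the layer names are hypothetical, and the rule of dropping only rows or columns that are entirely NaN is an assumption about how the NaNs are distributed:

```python
import math

def split_nan_layers(matrix, row_names, col_names):
    """Return (rows_to_keep, cols_to_keep), dropping layers that are all-NaN.

    matrix[i][j] is the similarity between layer row_names[i] of model 1
    and layer col_names[j] of model 2.
    """
    n_rows, n_cols = len(row_names), len(col_names)
    # A row is considered bad only if every entry in it is NaN.
    bad_rows = {i for i in range(n_rows)
                if all(math.isnan(v) for v in matrix[i])}
    # Same rule for columns.
    bad_cols = {j for j in range(n_cols)
                if all(math.isnan(matrix[i][j]) for i in range(n_rows))}
    keep_rows = [row_names[i] for i in range(n_rows) if i not in bad_rows]
    keep_cols = [col_names[j] for j in range(n_cols) if j not in bad_cols]
    return keep_rows, keep_cols

# Hypothetical 3x3 example: layer "b" of model 1 and layer "z" of model 2 are all-NaN.
nan = float("nan")
example = [
    [0.9, 0.5, nan],
    [nan, nan, nan],
    [0.4, 0.8, nan],
]
keep1, keep2 = split_nan_layers(example, ["a", "b", "c"], ["x", "y", "z"])
print(keep1, keep2)
```

The two returned lists could then be passed as model1_layers and model2_layers (the parameter names mentioned in the library's own warnings) so that the comparison is rerun on the remaining layers only.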


Yash-10 avatar Yash-10 commented on May 29, 2024

Not sure if this helps: for my case, I increased the batch_size from 1 to 16, and the error went away. Maybe you can try it? From 16 to, say, 32 or 64.
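One possible reason a batch size of 1 fails: if HSIC is computed with the standard unbiased minibatch estimator (an assumption about the implementation, not something confirmed in this thread), the formula divides by (n - 1), (n - 2) and n(n - 3), so it is undefined for batches of fewer than 4 samples. A minimal pure-Python sketch of that estimator, with hypothetical helper names `linear_gram` and `unbiased_hsic`:

```python
def linear_gram(features):
    """Gram matrix K[i][j] = <x_i, x_j> for a list of feature vectors."""
    return [[sum(a * b for a, b in zip(x, y)) for y in features]
            for x in features]

def unbiased_hsic(K, L):
    """Unbiased HSIC estimate for two n x n Gram matrices (nested lists)."""
    n = len(K)
    if n < 4:
        # The denominators (n - 1), (n - 2) and n * (n - 3) vanish for n <= 3.
        raise ValueError(f"unbiased HSIC needs at least 4 samples, got {n}")
    # The estimator works on Gram matrices with zeroed diagonals.
    Kt = [[0.0 if i == j else K[i][j] for j in range(n)] for i in range(n)]
    Lt = [[0.0 if i == j else L[i][j] for j in range(n)] for i in range(n)]
    # trace(Kt @ Lt)
    trace_kl = sum(Kt[i][j] * Lt[j][i] for i in range(n) for j in range(n))
    # 1^T Kt 1 and 1^T Lt 1
    sum_k = sum(map(sum, Kt))
    sum_l = sum(map(sum, Lt))
    # 1^T Kt Lt 1 = sum_k (column sum of Kt at k) * (row sum of Lt at k)
    col_sums_k = [sum(Kt[i][k] for i in range(n)) for k in range(n)]
    row_sums_l = [sum(Lt[k]) for k in range(n)]
    cross = sum(c * r for c, r in zip(col_sums_k, row_sums_l))
    total = (trace_kl
             + sum_k * sum_l / ((n - 1) * (n - 2))
             - 2.0 * cross / (n - 2))
    return total / (n * (n - 3))

# Hypothetical batch of 4 two-dimensional feature vectors.
feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
K = linear_gram(feats)
print(unbiased_hsic(K, K))
```

Because the estimator is unbiased rather than non-negative by construction, it can also return small negative values on some batches, so a larger batch size is not guaranteed to remove NaNs.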


bryanbocao avatar bryanbocao commented on May 29, 2024

Hi @Yash-10, thanks for your reply! I used batch_size = 16 in the previous example.


Yash-10 avatar Yash-10 commented on May 29, 2024

I am sorry; I meant that for my own application (different from yours), I increased the batch size and the error disappeared. Since you used batch_size = 16, I wondered if increasing it to 32/64 might remove the error.


bryanbocao avatar bryanbocao commented on May 29, 2024

No worries! Thanks for your help! Running with batch_size = 32, 64 and 128 now. Will post the results when finished.

BTW, each run seems to take hours to finish. I am using an RTX 3090, and the above scripts take 19, 20 and 21 GB of GPU memory. I hope the time and memory usage are normal here.


bryanbocao avatar bryanbocao commented on May 29, 2024

@Yash-10 Sorry, I tried batch sizes 32, 64 and 128 but still got the same result:

python3 eff_b0b2_compare.py 
/home/brcao/.local/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/brcao/.local/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=EfficientNet_B0_Weights.IMAGENET1K_V1`. You can also use `weights=EfficientNet_B0_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
/home/brcao/.local/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=EfficientNet_B2_Weights.IMAGENET1K_V1`. You can also use `weights=EfficientNet_B2_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Files already downloaded and verified
/home/brcao/.local/lib/python3.8/site-packages/torch_cka/cka.py:62: UserWarning: Model 1 seems to have a lot of layers. Consider giving a list of layers whose features you are concerned with through the 'model1_layers' parameter. Your CPU/GPU will thank you :)
  warn("Model 1 seems to have a lot of layers. " \
/home/brcao/.local/lib/python3.8/site-packages/torch_cka/cka.py:69: UserWarning: Model 2 seems to have a lot of layers. Consider giving a list of layers whose features you are concerned with through the 'model2_layers' parameter. Your CPU/GPU will thank you :)
  warn("Model 2 seems to have a lot of layers. " \
/home/brcao/.local/lib/python3.8/site-packages/torch_cka/cka.py:145: UserWarning: Dataloader for Model 2 is not given. Using the same dataloader for both models.
  warn("Dataloader for Model 2 is not given. Using the same dataloader for both models.")
| Comparing features |: 100%|█| 79/79 [15:07:33<00:00, 689
Traceback (most recent call last):
  File "eff_b0b2_compare.py", line 45, in <module>
    cka.compare(dataloader)
  File "/home/brcao/.local/lib/python3.8/site-packages/torch_cka/cka.py", line 183, in compare
    assert not torch.isnan(self.hsic_matrix).any(), "HSIC computation resulted in NANs"
AssertionError: HSIC computation resulted in NANs


bryanbocao avatar bryanbocao commented on May 29, 2024

@Yash-10 Sorry, still got the nan errors with batch_size = 32, 64 and 128.

PyTorch 1.13.1+cu117
NVIDIA-SMI 470.161.03 Driver Version: 470.161.03 CUDA Version: 11.4


wmabebe avatar wmabebe commented on May 29, 2024

I'm not sure what's causing it, but the problem seems to occur whenever I use EfficientNet, MobileNet or custom-implemented ResNets. It works fine with the torchvision.models.resnet models.


ImmortalSdm avatar ImmortalSdm commented on May 29, 2024

I find that the HSIC computation sometimes produces negative values, which can make the final sqrt computation return NaN. I've tried L1 and L2 normalization, but I still get the error.
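To illustrate the failure mode: CKA is HSIC(K, L) / sqrt(HSIC(K, K) * HSIC(L, L)), so a negative self-HSIC estimate makes the square root undefined, and tensor arithmetic reports that as NaN. The sketch below is one possible guard, not the library's behaviour; the function name `safe_cka` and the `eps` threshold are hypothetical:

```python
import math

def safe_cka(hsic_xy, hsic_xx, hsic_yy, eps=1e-12):
    """CKA from three HSIC estimates, with an explicit check on the self terms.

    Returns NaN deliberately when either self-HSIC is not clearly positive,
    so the caller can tell which layer pairs are undefined.
    """
    # Check each self term separately: two negative values would multiply
    # to a positive product and hide the problem.
    if hsic_xx <= eps or hsic_yy <= eps:
        return math.nan
    return hsic_xy / math.sqrt(hsic_xx * hsic_yy)

print(safe_cka(0.5, 1.0, 1.0))    # well-defined case
print(safe_cka(0.5, -1e-3, 1.0))  # negative self-HSIC: reported as NaN
```

Checking the self terms per layer also shows which layers produce the negative estimates, which is the information needed to exclude them.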
