Comments (4)
Tried:
import os
import torch
from torchvision.models import resnet18 # edit
from torchvision.models.detection import fasterrcnn_resnet50_fpn # edit
# from torchvision.models.detection.SSD import ssd300_vgg16 # edit
from torchvision.datasets import CIFAR10 # edit
# import torchvision.datasets as dataset
# print('\n dir(dataset): ', dir(dataset))
# from torchvision.datasets import CocoDetection # edit
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
import random
from torch_cka import CKA
# print('\n dir(CocoDetection): ', dir(CocoDetection))
if not os.path.exists('../exps'): os.makedirs('../exps')
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
g = torch.Generator()
g.manual_seed(0)
np.random.seed(0)
random.seed(0)
model1_name, model2_name = 'resnet18', 'F-RCNN' # edit
model1 = resnet18(pretrained=True) # edit
model2 = fasterrcnn_resnet50_fpn(pretrained=True) # edit
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
batch_size = 16 # 64 # 256
dataset = CIFAR10(root='../data/',
                  train=False,
                  download=True,
                  transform=transform)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        worker_init_fn=seed_worker,
                        generator=g,)
cka = CKA(model1, model2,
          model1_name=model1_name, model2_name=model2_name,
          device='cuda')
cka.compare(dataloader)
cka.plot_results(save_path="../exps/{}_{}.jpg".format(model1_name, model2_name))
but got:
python3 resnet18_FRCNN.py
/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1`. You can also use `weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Files already downloaded and verified
/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torch_cka/cka.py:69: UserWarning: Model 2 seems to have a lot of layers. Consider giving a list of layers whose features you are concerned with through the 'model2_layers' parameter. Your CPU/GPU will thank you :)
  warn("Model 2 seems to have a lot of layers. " \
/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torch_cka/cka.py:145: UserWarning: Dataloader for Model 2 is not given. Using the same dataloader for both models.
  warn("Dataloader for Model 2 is not given. Using the same dataloader for both models.")
| Comparing features |: 0%| | 0/625 [00:10<?, ?it/s]
Traceback (most recent call last):
  File "resnet18_FRCNN.py", line 55, in <module>
    cka.compare(dataloader)
  File "/home/brcao/Apps/anaconda3/envs/effi/lib/python3.8/site-packages/torch_cka/cka.py", line 172, in compare
    Y = feat2.flatten(1)
AttributeError: 'tuple' object has no attribute 'flatten'
Any help would be appreciated @AntixK. Thanks!
from pytorch-model-compare.
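The traceback shows torch_cka calling .flatten(1) on a captured Model 2 layer output that turned out to be a tuple. Because model2_layers was not given, torch_cka appears to hook nearly every submodule of Faster R-CNN, and several of those return non-tensors: in torchvision's implementation, for example, the transform, the RPN and the ROI heads return tuples, and the FPN returns an OrderedDict. One possible way out, hinted at by torch_cka's own warning, is to pass model2_layers explicitly. A minimal sketch for finding candidate names, assuming one forward pass on a real batch is affordable: tensor_output_layers is a hypothetical helper (not part of torch_cka) that hooks every submodule once and keeps only those whose output is a plain tensor with a batch dimension. The toy model just reproduces the situation with nn.LSTM, which returns a tuple.

```python
from typing import Dict, List

import torch
import torch.nn as nn


# Hypothetical helper, not part of torch_cka.
def tensor_output_layers(model: nn.Module, sample: torch.Tensor) -> List[str]:
    """Return names of submodules whose forward output is a single tensor with
    the same batch size as `sample`, i.e. outputs that `.flatten(1)` can handle.

    Modules returning tuples, dicts or other objects are skipped, as are
    modules that never run during this forward pass.
    """
    outputs: Dict[str, object] = {}
    handles = []
    for name, module in model.named_modules():
        if name == "":  # skip the root module itself
            continue

        # Bind `name` as a default argument so each hook records under its own key.
        def hook(mod, inp, out, name=name):
            outputs[name] = out

        handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(sample)
    finally:
        for handle in handles:  # always detach the probing hooks
            handle.remove()

    batch_size = sample.shape[0]
    return [
        name
        for name, out in outputs.items()
        if isinstance(out, torch.Tensor) and out.dim() >= 2 and out.shape[0] == batch_size
    ]


# Toy model: nn.LSTM returns a tuple (output, (h_n, c_n)), like the tuple in the traceback.
class ToyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.lstm = nn.LSTM(input_size=4, hidden_size=2, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.pool(self.conv(x)).flatten(1)  # (N, 4)
        seq, _ = self.lstm(feats.unsqueeze(1))      # (N, 1, 2)
        return seq[:, -1]


print(tensor_output_layers(ToyNet().eval(), torch.randn(2, 3, 8, 8)))  # ['conv', 'pool']
```

For the script in the question, something like images, _ = next(iter(dataloader)) followed by tensor_output_layers(model2.eval(), images) would list candidates; the backbone stages (backbone.body.layer1 to backbone.body.layer4) should pass the filter. Handing a short selection of them to CKA(...) via model2_layers=[...] is an untested suggestion. Probing in eval mode matters, since torchvision's detection models expect targets in training mode.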
@bryanbocao, have you solved the error AttributeError: 'tuple' object has no attribute 'flatten'?
I got this error when I tried to compare ResNet and ViT.
If there is any solution, please let me know.
@ratom Unfortunately, not yet. I ended up trying other projects.
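For the ResNet vs ViT case, one plausible source of the tuple is nn.MultiheadAttention, which returns (attn_output, attn_output_weights); torchvision's vit_b_16, for example, builds its encoder blocks on it, so hooking every submodule runs into the same .flatten(1) failure. If restricting the layers is not enough, a hedged workaround is to compare two chosen layers outside torch_cka. The sketch below assumes features for both models were captured on the same batch; first_tensor and linear_cka are hypothetical helpers, and linear_cka is plain linear CKA on a single batch, not torch_cka's minibatch HSIC estimator, so its values will not match the library's plots exactly.

```python
from typing import Optional

import torch
import torch.nn as nn


# Hypothetical helper: pick the first tensor out of a tuple/list/dict module output.
def first_tensor(out: object) -> Optional[torch.Tensor]:
    if isinstance(out, torch.Tensor):
        return out
    if isinstance(out, (tuple, list)):
        items = list(out)
    elif isinstance(out, dict):  # also covers OrderedDict
        items = list(out.values())
    else:
        return None  # e.g. None or a custom container
    for item in items:
        found = first_tensor(item)
        if found is not None:
            return found
    return None


# Hypothetical helper: linear CKA = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
# on batch-centered features; x and y must share the batch dimension.
def linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x = x.flatten(1).double()
    y = y.flatten(1).double()
    x = x - x.mean(dim=0, keepdim=True)  # center each feature over the batch
    y = y - y.mean(dim=0, keepdim=True)
    cross = torch.linalg.norm(y.T @ x) ** 2  # Frobenius norm, squared
    return cross / (torch.linalg.norm(x.T @ x) * torch.linalg.norm(y.T @ y))


# nn.MultiheadAttention returns a tuple, the kind of output torch_cka cannot flatten.
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
tokens = torch.randn(4, 5, 8)
with torch.no_grad():
    out = mha(tokens, tokens, tokens)  # (attn_output, attn_output_weights)
    feat = first_tensor(out)           # attn_output, shape (4, 5, 8)
    print(float(linear_cka(feat, tokens)))
```

Taking the first tensor is an assumption: it picks the attention output for nn.MultiheadAttention, but for other tuple-returning modules the first element may not be the representation worth comparing.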
Related Issues
- AssertionError: Input image size (32*32) doesn't match model (224*224).
- AssertionError: HSIC computation resulted in NANs
- Comparision between ResNet50 and ViT gives error
- ValueError: Input image size (32*32) doesn't match model (224*224).
- Bug with num_batches?
- getting spurious "HSIC computation resulted in NANs"
- The X means what in the formulation of HSIC
- Works fine with the whole model but raise "NANs" on selected layers.
- "HSIC computation resulted in NANs"
- Example to compare datasets
- Comparing Two Similar Architectures