Hi,
Thanks a lot for sharing your code. When I build the 'R50-ViT-B_16' model, I get the following error:
# Reproduction: build the hybrid 'R50-ViT-B_16' TransUNet model and run a
# torchsummary pass on a single-channel 224x224 input.
vit_name = 'R50-ViT-B_16'
config_vit = CONFIGS[vit_name]
config_vit.n_classes = 2
config_vit.n_skip = 3
# Hybrid ResNet50 variants need the patch grid set explicitly
# (img_size // vit_patches_size along each spatial axis).
if vit_name.find('R50') != -1:
    # config_vit.patches.grid = (int(args.img_size / args.vit_patches_size), int(args.img_size / args.vit_patches_size))
    config_vit.patches.grid = (224 // 16, 224 // 16)
net = ViT_seg(config_vit, img_size=224, num_classes=2)
# NOTE(review): torchsummary's forward hook assumes every module output is a
# Tensor (it calls .size() on each element of a tuple/list output); the hybrid
# encoder here returns a Python *list* of skip-connection feature maps, which
# is what raises AttributeError below — printing shapes manually (or patching
# the hook to skip non-Tensor elements) avoids the crash.
summary(net, (1, 224, 224), batch_size=1)
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchsummary/torchsummary.py in summary(model, input_size, batch_size, device)
70 # make a forward pass
71 # print(x.shape)
---> 72 model(*x)
73
74 # remove these hooks
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
138 def forward(self, *inputs, **kwargs):
139 if not self.device_ids:
--> 140 return self.module(*inputs, **kwargs)
141
142 for t in chain(self.module.parameters(), self.module.buffers()):
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)
~/Desktop/TransUNet-main/networks/vit_seg_modeling.py in forward(self, x)
393 x = x.repeat(1,3,1,1)
394 print("x:", x.shape)
--> 395 x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
396 print("x:{}, attn_weights:{}, features:{}".format(x.shape, attn_weights.shape, features.shape))
397 x = self.decoder(x, features)
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)
~/Desktop/TransUNet-main/networks/vit_seg_modeling.py in forward(self, input_ids)
255 def forward(self, input_ids):
256 print("input_ids:", input_ids.shape)
--> 257 embedding_output, features = self.embeddings(input_ids)
258 #print("embedding_output:", embedding_output.shape)
259 encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)
~/Desktop/TransUNet-main/networks/vit_seg_modeling.py in forward(self, x)
157 def forward(self, x):
158 if self.hybrid:
--> 159 x, features = self.hybrid_model(x)
160 else:
161 features = None
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
--> 543 hook_result = hook(self, input, result)
544 if hook_result is not None:
545 result = hook_result
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchsummary/torchsummary.py in hook(module, input, output)
21 if isinstance(output, (list, tuple)):
22 summary[m_key]["output_shape"] = [
---> 23 [-1] + list(o.size())[1:] for o in output
24 ]
25 else:
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchsummary/torchsummary.py in (.0)
21 if isinstance(output, (list, tuple)):
22 summary[m_key]["output_shape"] = [
---> 23 [-1] + list(o.size())[1:] for o in output
24 ]
25 else:
AttributeError: 'list' object has no attribute 'size'
Looking at `x, attn_weights, features = self.transformer(x)` — is `features` a Python list rather than a tensor? That would explain why torchsummary's shape-recording hook fails with `AttributeError: 'list' object has no attribute 'size'`.