Hi Zhitong,
Thanks a lot for releasing the code. I am trying to train the model on LIDC.
As soon as training starts, the following error occurs:
/storage/homefs/lz20w714/anaconda3/envs/mose/lib/python3.8/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.3
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
Traceback (most recent call last):
  File "main.py", line 50, in <module>
    model.train(data)
  File "/storage/homefs/lz20w714/git/mose-auseg/engine.py", line 52, in train
    self.validate(data)
  File "/storage/homefs/lz20w714/git/mose-auseg/engine.py", line 112, in validate
    metrics,prediction,prob = self.net.forward(patch_arrangement, masks_arrangement, prob_gt, val = True)
  File "/storage/homefs/lz20w714/anaconda3/envs/mose/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 169, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/storage/homefs/lz20w714/anaconda3/envs/mose/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/storage/homefs/lz20w714/git/mose-auseg/models/MoSE.py", line 146, in forward
    metric = metrics.cal_metrics_batch((pred.argmax(2)).long(), (label).long(), sample_probs, prob_gt,
  File "/storage/homefs/lz20w714/git/mose-auseg/utils/metrics.py", line 102, in cal_metrics_batch
    d_sy = get_cost_matrix(sample_arr, gt_arr, M, N, d_sy, label_range=label_range)
  File "/storage/homefs/lz20w714/git/mose-auseg/utils/metrics.py", line 46, in get_cost_matrix
    cij = (dist_fct(sample_arr[:, i, ...], gt_arr[:, j, ...], label_range=label_range))
  File "/storage/homefs/lz20w714/git/mose-auseg/utils/metrics.py", line 18, in iou_dist
    intersection = torch.sum(m1 * m2, dim=[-1, -2]) # keep batch and class dimension
RuntimeError: The size of tensor a (128) must match the size of tensor b (4) at non-singleton dimension 2
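In case it helps narrow things down: the failure is a PyTorch broadcasting error inside `iou_dist`, where the element-wise product `m1 * m2` receives two tensors whose dimension 2 differs (128 vs. 4). The snippet below is only a minimal sketch that reproduces the same message; the shapes are assumptions picked to match the sizes in the error, not the actual shapes produced by `metrics.py`.

```python
import torch

# Hypothetical shapes, chosen only to mimic the sizes in the error message.
# The real shapes of m1 and m2 inside iou_dist are not known here.
m1 = torch.ones(2, 4, 128, 128)  # e.g. (batch, class, H, W)
m2 = torch.ones(2, 4, 4, 128)    # dimension 2 has size 4 instead of 128

message = ""
try:
    # Same reduction as in iou_dist: multiply, then sum over the last two dims.
    intersection = torch.sum(m1 * m2, dim=[-1, -2])
except RuntimeError as err:
    # Broadcasting fails because 128 != 4 and neither size is 1.
    message = str(err)

print(message)
```

If that reading is right, one of the two masks reaching `iou_dist` seems to carry a size-4 axis where the other has a spatial axis of size 128.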
Thanks in advance for your help.