ouenal / scribblekitti

141 stars · 11 watchers · 17 forks · 11.95 MB

Scribble-Supervised LiDAR Semantic Segmentation, CVPR 2022 (ORAL)

Python 100.00%

Topics: lidar, semantic-segmentation, weakly-supervised-segmentation, scribble-supervised-segmentation, scribblekitti

scribblekitti's People

Contributors: ouenal

scribblekitti's Issues

Where can I find the model files for MinkowskiNet and SPVCNN?

Thanks for your amazing work.

In your paper, you mention that the algorithm was tested with Cylinder3D, MinkowskiNet, and SPVCNN, but I only found the code for Cylinder3D. Could you please provide the code for the other two segmentation models?

Thank you.

Consistency loss

Hi, thanks for your great work! I have a question about the consistency loss between the teacher and the student on the unlabeled points. In the code you use the KL divergence, but in the paper (Equation 3) it is something different. To me, Equation 3 looks like a soft version of cross-entropy, but the minus sign is missing. Should it instead be the KL divergence (https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) with part of it omitted from the paper? Or am I missing something?
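For reference, the two losses are closely related. A minimal numeric sketch (my own illustration, not the repository's code) shows that KL(p || q) equals the soft cross-entropy CE(p, q) minus the teacher entropy H(p); since H(p) is constant with respect to the student, both losses give the student the same gradients:

```python
import numpy as np

def entropy(p):
    """Shannon entropy H(p) = -sum p log p."""
    return -np.sum(p * np.log(p))

def cross_entropy(p, q):
    """Soft cross-entropy CE(p, q) = -sum p log q."""
    return -np.sum(p * np.log(q))

def kl_divergence(p, q):
    """KL(p || q) = sum p log(p / q)."""
    return np.sum(p * np.log(p / q))

# Hypothetical teacher (p) and student (q) class distributions for one point.
p = np.array([0.7, 0.2, 0.1])
q = np.array([0.5, 0.3, 0.2])

# KL(p || q) = CE(p, q) - H(p): the two differ only by the teacher
# entropy, which does not depend on the student's parameters.
assert np.isclose(kl_divergence(p, q), cross_entropy(p, q) - entropy(p))
```

So a paper writing soft cross-entropy and code computing KL divergence can describe the same optimization problem, up to a constant offset.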

Segmentation Performance with Partially Annotated Data

Thank you for open-sourcing your annotated data and code!

Regarding Table 2 in your paper, I have a question about the segmentation performance of Cylinder3D and SparseConv-UNet (Ref. [18] in your references).

The results under the 10% frame split are 46.8% for Cylinder3D and 43.9% for SparseConv-UNet. I recently ran experiments on Cylinder3D with the same number of labeled training frames (1913 out of 19130) and got much higher results (55%+). I am using the latest version of Cylinder3D from here, with the exact configuration provided by the authors except for init_size, which I changed from 32 to 16. Could you explain how exactly you implemented this and what might cause such a large performance difference? Thanks!

Batch size bigger than 1

Thanks so much for your excellent work and code. I have a question about the dataloader code: is it possible to set the batch size bigger than 1 (e.g., 4, 8, or 16)? When I set a bigger batch_size in training.yaml, the code failed with the error "RuntimeError: stack expects each tensor to be equal size, but got [124266, 3] at entry 0 and [112695, 3] at entry 1".

# in `training.yaml`
train_dataloader:
  batch_size: 4       # default is 1
  shuffle: True
  num_workers: 4

Thanks in advance!
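The error comes from the default collation step trying to stack point clouds with different point counts into one tensor. One common workaround (a sketch under my own assumptions; the sample keys "points"/"labels" are hypothetical and the model then has to consume a list or concatenated points with batch indices) is a custom collate function that keeps samples in a list instead of stacking them:

```python
import numpy as np

def list_collate(batch):
    """Hypothetical collate_fn sketch: instead of stacking point clouds of
    different sizes (which raises 'stack expects each tensor to be equal
    size'), keep each sample in a Python list."""
    points = [sample["points"] for sample in batch]
    labels = [sample["labels"] for sample in batch]
    return {"points": points, "labels": labels}

# Two scans with different numbers of points, as in the error message.
batch = [
    {"points": np.zeros((124266, 3)), "labels": np.zeros(124266)},
    {"points": np.zeros((112695, 3)), "labels": np.zeros(112695)},
]
out = list_collate(batch)
assert len(out["points"]) == 2
assert out["points"][0].shape == (124266, 3)  # sizes preserved, no stacking
```

With PyTorch, such a function would be passed as the `collate_fn` argument of the DataLoader; whether the rest of this codebase supports batched sparse inputs is a separate question for the authors.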

dataloader issue

Hi,

Thanks for sharing your work! I am currently running scribblekitti on my own data and am experiencing the following issue:

When I run step 1 with train.py, the progress bar shows exactly twice the number of frames in the training folders. Shouldn't it display the sum of the frames in the training and validation folders? It also seems that the model validates on the labels in the training folders, not on the labels in the validation folder.

In dataloader/semantickitti.py, the 'Baseline' class has a method 'load_file_paths', which is called during the initialization of 'Baseline' as well as during the initialization of 'PLSCylindricalMT'.

Could the issue be that the method is called twice, or am I missing something here?

Baseline:

    self.split, self.config = split, config
    self.root_dir = self.config['root_dir']
    assert(os.path.isdir(self.root_dir))
    label_directory = 'scribbles' if 'label_directory' not in config.keys() \
                                  else config['label_directory']
    self.label_directory =  label_directory if split == 'train' else 'labels'
    self.load_file_paths(split, self.label_directory)

PLSCylindricalMT:

    self.load_file_paths('train', self.label_directory)
    self.nclasses = nclasses
    self.bin_sizes = self.config['bin_size']
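If load_file_paths extends an existing list rather than resetting it, calling it from both initializers would indeed double the frame count. A toy reconstruction (my own sketch, not the repository's actual implementation) of that failure mode and the idempotent fix:

```python
class LoaderSketch:
    """Toy reconstruction of how calling a path-loading method twice can
    double the frame count if it appends instead of resetting."""

    def __init__(self):
        self.frame_paths = []

    def load_file_paths_buggy(self, paths):
        # Appends without clearing: a second call duplicates every entry.
        self.frame_paths += list(paths)

    def load_file_paths_fixed(self, paths):
        # Reset first, so repeated calls are idempotent.
        self.frame_paths = list(paths)

loader = LoaderSketch()
loader.load_file_paths_buggy(["000000.bin", "000001.bin"])
loader.load_file_paths_buggy(["000000.bin", "000001.bin"])  # second init path
assert len(loader.frame_paths) == 4  # progress bar would show 2x the frames

loader2 = LoaderSketch()
loader2.load_file_paths_fixed(["000000.bin", "000001.bin"])
loader2.load_file_paths_fixed(["000000.bin", "000001.bin"])
assert len(loader2.frame_paths) == 2  # count stays correct
```

Whether the actual method appends or overwrites determines if this explains the doubled progress bar; it would not by itself explain validating on the training labels.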

Sparse annotations for other datasets

Thanks for your great work! It really helps me a lot.
Do you have any plans to provide sparse annotations for other datasets such as nuScenes or SemanticPOSS?

Training speed

Thanks for your amazing work. I am concerned about the training time.

My GPU is a single Tesla V100 (32 GB). Under your training settings, each iteration takes around 1.5-2 s in STEP 1, and each epoch takes close to 15-16 h; with 75 training epochs, this seems very time-consuming.

Could you share your hardware setup and the training time details (e.g., iteration time and the duration of the whole training pipeline)?

Error while creating Pseudo-Labels

Hi,
I would like to get ScribbleKITTI running on my own data, and for this I have already created my own scribble labels. Because I don't have point labels, I downloaded the checkpoint of the first step and ran save.py with it, which worked smoothly. But when I run crb.py with the generated h5 file, I get the following error:

  Determining global threshold k^(c,r)...
  0%| | 0/19130 [00:00<?, ?it/s]
  Traceback (most recent call last):
  File "crb.py", line 65, in <module>
    mask = pred[bin_mask] == j
  IndexError: boolean index did not match indexed array along dimension 0; dimension is 1 but corresponding boolean dimension is 124668

At first I thought it might be because of my data, but the error also occurs when I use your scribbles, and I have not changed anything in the code. Here is a link to the h5 file I created with your labels after downloading the step-1 checkpoint. I just thought I would ask; maybe it's something trivial that I haven't noticed.
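The traceback says the boolean mask has one entry per point (124668) while the indexed prediction array has length 1, so the predictions read from the h5 file are not per-point. A minimal reproduction of the failing pattern (my own sketch of the symptom, not the actual crb.py logic) with an explicit shape guard:

```python
import numpy as np

# Hypothetical reconstruction: a boolean mask can only index an array
# of the same length along that dimension.
pred = np.zeros(1, dtype=np.int64)       # predictions as read from the h5 file
bin_mask = np.ones(124668, dtype=bool)   # one entry per point in the scan

mismatch = False
try:
    _ = pred[bin_mask]                   # reproduces the IndexError from crb.py
except IndexError:
    mismatch = True
assert mismatch

# An explicit guard before masking makes the assumption visible and
# points at how the h5 predictions were written.
assert pred.shape[0] != bin_mask.shape[0]
```

If this matches your case, it may be worth checking the shapes of the datasets save.py wrote into the h5 file (e.g. with h5py) before running crb.py.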

Thanks in advance!

Best Regards
Leon

Size Mismatch

When I try to evaluate on my own files, without labels, I get the following error message:

Traceback (most recent call last):
  File "evaluate.py", line 38, in <module>
    model = LightningEvaluator.load_from_checkpoint(args.ckpt_path, config=config)
  File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 159, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 205, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LightningEvaluator:
	size mismatch for student.unet.contextBlock.conv1.weight: copying a param with shape torch.Size([1, 3, 3, 16, 32]) from checkpoint, the shape in current model is torch.Size([32, 1, 3, 3, 16]).
	size mismatch for student.unet.contextBlock.conv1_2.weight: copying a param with shape torch.Size([3, 1, 3, 32, 32]) from checkpoint, the shape in current model is torch.Size([32, 3, 1, 3, 32]).
	size mismatch for student.unet.contextBlock.conv2.weight: copying a param with shape torch.Size([3, 1, 3, 16, 32]) from checkpoint, the shape in current model is torch.Size([32, 3, 1, 3, 16]).
	size mismatch for student.unet.contextBlock.conv3.weight: copying a param with shape torch.Size([1, 3, 3, 32, 32]) from checkpoint, the shape in current model is torch.Size([32, 1, 3, 3, 32]).
	(... analogous size mismatches for every remaining student.unet and teacher.unet layer, through teacher.unet.logits.weight: in each case the checkpoint stores the weight as [K, K, K, C_in, C_out] while the current model expects [C_out, K, K, K, C_in].)

It seems like the dimensions of the weight arrays are permuted?
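Every mismatch above follows the same pattern: the checkpoint stores weights as [K, K, K, C_in, C_out] while the model expects [C_out, K, K, K, C_in], a single axis permutation. This is my assumption, but it usually points at a sparse-convolution library version difference (the weight layout changed between spconv 1.x and 2.x) rather than corrupt weights; a sketch of the permutation for one of the layers:

```python
import numpy as np

# e.g. contextBlock.conv1 as stored in the checkpoint: [K1, K2, K3, C_in, C_out]
ckpt_weight = np.zeros((1, 3, 3, 16, 32))

# Move C_out to the front: [C_out, K1, K2, K3, C_in], the layout the
# current model expects according to the error message.
model_weight = np.transpose(ckpt_weight, (4, 0, 1, 2, 3))
assert model_weight.shape == (32, 1, 3, 3, 16)
```

The cleanest fix is likely to install the sparse-conv library version the checkpoint was trained with; permuting every weight in the state dict as above is a workaround that would need verifying against the evaluation metrics.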

Test on own data

Hello,

How can I test one of your models on my own data?
Thanks!
