Comments (15)
Thank you so much for this! It is an excellent framework, thanks for your contribution!
from torchdistill.
Hi @AndyFrancesco29 ,
I just released a new version of torchdistill (ver. 0.2.4). You can now upgrade the package on your machine with pip.
Hi @AndyFrancesco29 ,
Definitely, you can use different KD methods for a specific task since the methods in torchdistill are implemented in a model-agnostic way.
Do you have a specific configuration in mind (e.g., a set of models, a method, or a task that you're interested in)?
If so, I can give you more concrete advice.
Hi @yoshitomo-matsubara ,
Thanks for your reply. I mean for example in RepDistiller repo, it has an input parameter "--distill" to indicate which method to use, like Fitnet or vanilla KD. Is that an option in config file? Thank you very much.
Hi @AndyFrancesco29 ,
A main concept in torchdistill is "one configuration (file) -> one method (experiment)".
In other words, you switch methods by using different config files. For example, take a look at these sample configs for ImageNet, including vanilla KD and FitNet; you will notice that the main differences are the loss function and the layers whose input/output are extracted to compute the loss values (e.g., layer3 and layer4 in ResNet-34 and ResNet-18 for the Attention Transfer sample config).
Basically, (almost) all the hyperparameters and configs are defined in the yaml file, which lets you easily see the differences between methods by comparing the yaml config files.
You could think of --config <config file> in torchdistill as --distill <method name> in RepDistiller.
While RepDistiller requires hardcoding everything (models, extraction of intermediate input/output, losses, datasets, etc.), torchdistill provides more flexibility by allowing you to design these things in a yaml config file, i.e., without writing Python code in many cases.
With torchdistill, you can also easily use models beyond torchvision by specifying PyTorch hub params in the yaml config file e.g., pretrained ResNeSt-50 as a teacher imported from pytorch-image-models.
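To make the "compare the config files" idea concrete, here is a minimal sketch that lists where two parsed configs differ. The two dictionaries are hypothetical, heavily simplified stand-ins for parsed yaml files: their keys and loss names are illustrative only and are not torchdistill's actual config schema. Only the layer names layer3 and layer4 come from the thread.

```python
def diff_configs(a: dict, b: dict, prefix: str = "") -> list:
    """Return the dotted paths whose values differ between two nested dicts."""
    paths = []
    for key in sorted(set(a) | set(b)):
        path = f"{prefix}.{key}" if prefix else key
        if key not in a or key not in b:
            paths.append(path)  # key exists in only one config
        elif isinstance(a[key], dict) and isinstance(b[key], dict):
            paths.extend(diff_configs(a[key], b[key], path))  # recurse into sections
        elif a[key] != b[key]:
            paths.append(path)
    return paths


# Hypothetical, simplified configs (NOT the real torchdistill schema).
kd_config = {
    "teacher_model": "resnet34",
    "student_model": "resnet18",
    "criterion": {"type": "vanilla_kd_loss"},
    "hooked_layers": [],
}
at_config = {
    "teacher_model": "resnet34",
    "student_model": "resnet18",
    "criterion": {"type": "attention_transfer_loss"},
    "hooked_layers": ["layer3", "layer4"],
}

differences = diff_configs(kd_config, at_config)
print(differences)
```

In this toy setup the models are identical and only the loss and the hooked layers differ, which mirrors the point made above about what changes between methods.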
Thanks! I think I get that.
Hi @yoshitomo-matsubara ,
Thank you very much for your patience! Right now I am working on SSKD (KD meets self-supervised learning). I have seen that you implemented it for ImageNet, and I want to apply it to the CIFAR datasets. The paper requires each image to be rotated and then transformed, so for one original image there will be four transformed images (rotated by 0, 90, 180, and 270 degrees). I am wondering how you implemented this part.
Also, stage one should only update the SS module of the teacher model, and the second stage updates the student network. Could you tell me where you implemented these two stages?
Again thank you very much.
Hi @AndyFrancesco29 ,
For SSKD, you can refer to this yaml config file and README in this directory.
The image rotation part is implemented as part of DataLoader wrapper here, and you will find the SSKDDatasetWrapper defined in the above yaml config file as well.
For the multi-stage setting, I defined the stages in the same yaml config file (as the stage1 and stage2 entries).
At the 1st stage, part of the teacher model (the SS module) is trained via SSWrapper4SSKD, which is defined in both the yaml config file and the Python code, while the student model is "empty" to save training time, since the student model is not needed at the 1st stage.
At the 2nd stage, both the teacher and student models use SSWrapper4SSKD, while the teacher model is frozen at this point.
If you would like to apply it to CIFAR datasets, you can update the dataset entry with CIFAR as described in other CIFAR sample configs (e.g., this one) and replace the teacher and student models with some of those here (designed for CIFAR datasets with pretrained weights for teacher).
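As a rough illustration of the rotation step discussed above, the sketch below builds the four rotated views of one image with torch.rot90. It is not the SSKDDatasetWrapper code from torchdistill; the function name four_rotations is made up for this example, and it assumes a square C x H x W tensor (such as a 3x32x32 CIFAR image) so that all four views can be stacked.

```python
import torch


def four_rotations(image: torch.Tensor) -> torch.Tensor:
    """Stack the 0, 90, 180 and 270 degree rotations of a square C x H x W image."""
    # torch.rot90 rotates in the plane spanned by `dims`; k counts 90-degree turns.
    return torch.stack([torch.rot90(image, k, dims=(1, 2)) for k in range(4)])


# Dummy CIFAR-sized image; real code would use a transformed dataset sample instead.
image = torch.arange(3 * 32 * 32, dtype=torch.float32).reshape(3, 32, 32)
views = four_rotations(image)
print(views.shape)
```

Each original sample thus contributes four entries to the batch, which is worth keeping in mind when comparing the configured batch size with the tensor shapes seen inside the loss.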
Hi @yoshitomo-matsubara ,
Thanks for your reply. I have followed your instructions and written my own config. Right now I am stuck in stage 1. I use resnet18 as the student and resnet32x4 as the teacher (implemented as in SSKD/models). According to my understanding, in stage 1 although we only update teacher SS module, we still need to validate student network (in example/image_classification.py line 122). However, when I get into the evaluate function and run output = model(image) (line 87), it shows this error:
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
I have set the batch size to 16, so it should not have a memory issue (training the SS module for resnet32x4 succeeds). Before running model(image), it also shows several warnings:
Warning: Leaking Caffe2 thread-pool after fork.
For the student network I still use the same configuration as here, and I believe resnet18 is from torchvision so I am not sure why this happens.
I attach my yaml file here if it is convenient.
sskd-resnet18_from_resnet32x4.zip. I store my data in a path like "./data/cifar-100-python/train". Thanks a lot if you can help me with that.
PS: The code is run with a single GPU. However, it shows the warning, creates processes on a different GPU, and then fails.
Hi @AndyFrancesco29 ,
> According to my understanding, in stage 1 although we only update teacher SS module, we still need to validate student network (in example/image_classification.py line 122).

Yes, but the student model validated at line 122 is the global student model defined as student_model in models, not the local student model (EmptyModule at stage1). Thus, the above RuntimeError is not caused by the EmptyModule.
In your yaml file,
models:
  teacher_model:
    name: &teacher_model_name 'resnet32x4'
    params:
      num_classes: 100
    experiment: &teacher_experiment !join [*dataset_name, '-', *teacher_model_name]
    ckpt: './save/models/resnet32x4_vanilla/ckpt_epoch_240.pth'
  student_model:
    name: &student_model_name 'resnet18'
    params:
      num_classes: 100
    experiment: &student_experiment !join [*dataset_name, '-', *student_model_name, '_from_', *teacher_model_name]
    ckpt: !join ['./cifar/sskd/', *student_experiment, '.pt']
you attempt to define the global teacher and student models as resnet32x4 (should not be in torchdistill by default) and resnet18 (from torchvision by default) respectively.
But, did you edit some of my python code and add resnet32x4 to the registry of model by adding @register_model_func decorator like this?
Also, resnet18 in torchvision is designed for the ImageNet dataset (input shape: 3x224x224) and should not work for CIFAR datasets (input shape: 3x32x32), even with num_classes: 100, due to the smaller input. This is why I suggested here that you use some of the models designed for CIFAR datasets.
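One way to see the effect of the smaller input is to track the spatial size of the feature maps through the downsampling steps of an ImageNet-style ResNet-18. The sketch below is plain arithmetic, not torchdistill or torchvision code; the kernel/stride/padding values are the ones I recall for torchvision's ResNet-18 (7x7 stride-2 stem, 3x3 stride-2 max pooling, and stride-2 3x3 convolutions entering layer2 to layer4), so please check them against your installed version.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding) of the downsampling steps; layer1 keeps the size.
# Values assumed from torchvision's ResNet-18 definition.
DOWNSAMPLING = [
    ("conv1", 7, 2, 3),
    ("maxpool", 3, 2, 1),
    ("layer2", 3, 2, 1),
    ("layer3", 3, 2, 1),
    ("layer4", 3, 2, 1),
]


def feature_map_sizes(input_size: int) -> dict:
    """Return the feature-map side length after each downsampling step."""
    sizes = {}
    size = input_size
    for name, kernel, stride, padding in DOWNSAMPLING:
        size = conv_out(size, kernel, stride, padding)
        sizes[name] = size
    return sizes


print(feature_map_sizes(224))  # ImageNet-sized input
print(feature_map_sizes(32))   # CIFAR-sized input
```

With a 224x224 input the last stage still sees a 7x7 map, whereas a 32x32 input has already shrunk to 8x8 after the stem and to 1x1 by layer4, which is why CIFAR-specific ResNet variants use a much lighter stem.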
To test SSKD for CIFAR-100, try this model config instead:
models:
  teacher_model:
    name: &teacher_model_name 'resnet56'
    params:
      pretrained: True
      num_classes: 100
    experiment: &teacher_experiment !join [*dataset_name, '-', *teacher_model_name]
    ckpt: './save/models/resnet56_vanilla/ckpt_epoch_240.pth'
  student_model:
    name: &student_model_name 'resnet20'
    params:
      pretrained: False
      num_classes: 100
    experiment: &student_experiment !join [*dataset_name, '-', *student_model_name, '_from_', *teacher_model_name]
    ckpt: !join ['./cifar/sskd/', *student_experiment, '.pt']
With this config, you will use resnet56 in torchdistill as the teacher, automatically download its pretrained weights (pretrained: True), and train resnet20 in torchdistill as the student model.
I think the warning message about Caffe2 appears because of this, which used to be normal behavior, but a recent PyTorch update somehow started printing such messages.
Next time, please provide the command you executed to get the error message.
Hi @yoshitomo-matsubara ,
Thanks for your advice.
> you attempt to define the global teacher and student models as resnet32x4 (should not be in torchdistill by default) and resnet18 (from torchvision by default) respectively.
> But, did you edit some of my python code and add resnet32x4 to the registry of model by adding @register_model_func decorator like this?

Yes, I have done that. Stage 1 training is fine; the problem happens when it enters the global validation step.
After I restarted the machine, the out-of-memory problem was gone. I have also tried the suggested models, but some problems remain (they happen with both resnet32x4+resnet18 and resnet56+resnet20):
- the warning is still here. I use single GPU command to run and it is still there. The interesting part is that it happens in stage1 val and stage 2 train+val. I check GPU usage it shows that it creates annoying process on different cards (PID 11329):
- In stage 2, the first epoch runs successfully (both train and val, with the warning above). However, in the second epoch it shows an error like this:
`Traceback (most recent call last):
  File "/data/xiajingfei/torchdistill/examples/image_classification.py", line 206, in <module>
    main(argparser.parse_args())
  File "/data/xiajingfei/torchdistill/examples/image_classification.py", line 188, in main
    train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args)
  File "/data/xiajingfei/torchdistill/examples/image_classification.py", line 141, in train
    train_one_epoch(training_box, device, epoch, log_freq)
  File "/data/xiajingfei/torchdistill/examples/image_classification.py", line 72, in train_one_epoch
    loss = training_box(sample_batch, targets, supp_dict)
  File "/home/deepwisdom/anaconda3/envs/xjf-distiller/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/xiajingfei/torchdistill/../torchdistill/torchdistill/core/distillation.py", line 324, in forward
    total_loss = self.criterion(output_dict, org_loss_dict, targets)
  File "/home/deepwisdom/anaconda3/envs/xjf-distiller/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/xiajingfei/torchdistill/../torchdistill/torchdistill/losses/custom.py", line 48, in forward
    loss_dict[loss_name] = factor * criterion(student_output_dict, teacher_output_dict, targets)
  File "/home/deepwisdom/anaconda3/envs/xjf-distiller/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/xiajingfei/torchdistill/../torchdistill/torchdistill/losses/single.py", line 725, in forward
    ce_loss = self.cross_entropy_loss(student_linear_outputs[normal_indices], targets)
  File "/home/deepwisdom/anaconda3/envs/xjf-distiller/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/deepwisdom/anaconda3/envs/xjf-distiller/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 1120, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/deepwisdom/anaconda3/envs/xjf-distiller/lib/python3.9/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (19) to match target batch_size (16).`
If I remove evaluate function in example/image_classification.py line 122, this error disappears. So maybe something wrong happens here? I am still checking that.
This is the config file for resnet56+resnet20:
sskd-resnet20_from_resnet56.zip
Data path is the same as before. The command I used is this:
python3 examples/image_classification.py --config configs/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/sskd-resnet20_from_resnet56.yaml --log log/ilsvrc2012/yoshitomo-matsubara/rrpr2020/sskd-resnet20_from_resnet56.txt
Updates:
I think I found the problem:
In examples/image_classification.py line 77, you use DataParallel even if a single GPU is assigned to the model, so the validation process creates other processes on different GPUs. However, this seems to change self.student_io_dict: during training, the dict then contains tensors from different CUDA devices (cuda:0, cuda:1), which changes the data shape. So I commented out this line, and the validation part now also runs on a single GPU. The code runs well.
I am a little confused about why this happens, and I could not find how self.student_io_dict and self.teacher_io_dict are updated in the code. Could you help me with that? Maybe changing how they are updated could improve this part.
Hi @AndyFrancesco29 ,
> - the warning is still here. I use single GPU command to run and it is still there. The interesting part is that it happens in stage1 val and stage 2 train+val.

Yes, as I said, it should be normal behavior, but somehow PyTorch began warning about it (and maybe the next PyTorch release will resolve it): pytorch/pytorch#57273 (comment)

> ValueError: Expected input batch_size (19) to match target batch_size (16).

This looks like it is caused by different numbers of GPUs being used when training and validating student models (if it's not distributed training, the student model will be on gpu0 in training but on multiple GPUs in validation when multiple CUDA devices are visible). You could avoid this problem by adding the following prefix. Just in case, I will update the package to avoid the issue without the prefix.

> I check GPU usage it shows that it creates annoying process on different cards (PID 11329):

If you want to force the code to use a specific GPU only, add a prefix like CUDA_VISIBLE_DEVICES=0 before the python3 command,
e.g., CUDA_VISIBLE_DEVICES=0 python3 examples/image_classification.py --config configs/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/sskd-resnet20_from_resnet56.yaml --log log/ilsvrc2012/yoshitomo-matsubara/rrpr2020/sskd-resnet20_from_resnet56.txt
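If you launch runs from a Python script rather than a shell, the same restriction can be expressed by passing a modified environment to the child process. This is only a sketch: the helper name single_gpu_env is made up for this example, and the command is copied from the thread. The variable has to be set before the training process initializes CUDA, which is why it is placed in the environment of a new process here.

```python
import os


def single_gpu_env(gpu_index: int = 0) -> dict:
    """Copy the current environment and expose only one GPU to a child process."""
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return env


# Command copied from the thread; adjust the paths to your own config and log files.
command = [
    "python3", "examples/image_classification.py",
    "--config", "configs/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/sskd-resnet20_from_resnet56.yaml",
    "--log", "log/ilsvrc2012/yoshitomo-matsubara/rrpr2020/sskd-resnet20_from_resnet56.txt",
]
env = single_gpu_env(0)
print(env["CUDA_VISIBLE_DEVICES"])

# To actually start the run (not executed in this sketch):
#   import subprocess
#   subprocess.run(command, env=env, check=True)
```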
Thank you very much @yoshitomo-matsubara !
BTW, how do self.student_io_dict and self.teacher_io_dict in DistillationBox get updated? I think they record the outputs of the assigned layers, but I wonder how it works.
My pleasure, thank you for reporting the issue.
Yes, an IO dict stores the input and/or output of each nn.Module you specify in a yaml file, and this is achieved by leveraging forward hooks in PyTorch. This notebook (or open it with Google Colab) demonstrates how torchdistill leverages forward hooks to extract the input/output of layers in a model without requiring users to edit the original model implementation.
FYI, here is the low-level code
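Independently of torchdistill's own implementation linked above, the general mechanism can be sketched with plain PyTorch: a forward hook registered on a module is called after every forward pass and can copy that module's input and output into a dictionary. The toy model, the key "fc1", and the helper make_hook below are made up for illustration.

```python
import torch
from torch import nn

# Toy model standing in for a student or teacher network.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
io_dict = {}


def make_hook(name: str):
    """Build a forward hook that records a module's input and output under `name`."""
    def hook(module, inputs, output):
        # `inputs` is the tuple of positional arguments passed to module.forward().
        # Tensors are stored without detaching so a loss computed on them can backpropagate.
        io_dict[name] = {"input": inputs[0], "output": output}
    return hook


handle = model[0].register_forward_hook(make_hook("fc1"))

batch = torch.randn(5, 8)
_ = model(batch)   # the hook fires during this forward pass
handle.remove()    # unregister the hook when it is no longer needed

print(io_dict["fc1"]["input"].shape, io_dict["fc1"]["output"].shape)
```

Because the dictionary is overwritten on every forward pass, whatever ran the model last determines its contents, so it is worth checking which forward pass (and on which devices) filled it before a loss reads from it.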