Comments (9)
I am sorry, I didn't notice this in the README. I will move to Discussions next time.
And thanks for the link. I just want to confirm that the following config means

FinalLoss = 2.0 * KDLoss + 0.8 * CRDLoss (using the two factor values),

where

KDLoss = alpha * CrossEntropy(logits, label) + (1 - alpha) * KLDiv(logits, teacher_predictions).

Is my understanding correct?
criterion:
  type: 'GeneralizedCustomLoss'
  org_term:
    criterion:
      type: 'KDLoss'
      params:
        temperature: 4.0
        alpha: 0.5
        reduction: 'batchmean'
    factor: 2.0
  sub_terms:
    crd:
      criterion:
        type: 'CRDLoss'
        params:
          teacher_norm_module_path: 'normalizer'
          student_norm_module_path: 'normalizer'
          student_empty_module_path: 'empty'
          input_size: *feature_dim
          output_size: &num_samples 1281167
          num_negative_samples: *num_negative_samples
          num_samples: *num_samples
          temperature: 0.07
          momentum: 0.5
          eps: 0.0000001
      factor: 0.8
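To illustrate my understanding, here is a minimal standalone sketch of that combination (hypothetical helper, not torchdistill's actual KDLoss/CRDLoss implementation; whether the T^2 scaling is applied to the soft term is my assumption):

import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.5):
    # Soft-target term: KL divergence between temperature-scaled distributions.
    # The T^2 factor follows the common Hinton-style convention (an assumption here).
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction='batchmean',
    ) * (temperature ** 2)
    # Hard-label term: standard cross entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    return alpha * hard + (1.0 - alpha) * soft

# With the factors from the config above, the final loss would then be
# final_loss = 2.0 * kd_loss(student_logits, teacher_logits, labels) + 0.8 * crd_loss
# where crd_loss is the contrastive representation distillation term.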
Yeah sure, thanks again
Hi @shwinshaker
For questions, please use the Discussions tab above. As explained in the README, I want to keep Issues mainly for bug reports.
For CRD + KD, you can refer to this part:
https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/sample/ilsvrc2012/single_stage/crd/resnet18_from_resnet50.yaml#L128-L151
If I remember correctly, the CRD paper doesn't describe the implementation of CRD+KD well (e.g., is it multi-stage? If it is single-stage and just a linear combination, how did they weight each loss value?), so I assume it is single-stage training with a linear combination of the CRD and KD losses.
You can refer to more examples available in the sample configs, such as
https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/sample/ilsvrc2012/single_stage/ghnd/custom_resnet152_from_resnet152.yaml#L105-L152
Hi @shwinshaker
Yes, that is correct. Note that org_term uses the output from the original (teacher, student) models, while sub_terms does not necessarily (e.g., loss terms defined under sub_terms usually use intermediate features of the models).
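To make the distinction concrete, here is a rough conceptual sketch, not the actual GeneralizedCustomLoss code (the io_dict structure and the argument names are hypothetical):

def generalized_loss(org_criterion, org_factor, sub_entries,
                     student_output, teacher_output, targets, io_dict):
    # org term: computed on the original models' outputs (e.g., logits)
    total = org_factor * org_criterion(student_output, teacher_output, targets)
    # sub terms: typically computed on intermediate features recorded by
    # forward hooks, looked up here by module path
    for criterion, factor, student_key, teacher_key in sub_entries:
        total = total + factor * criterion(io_dict[student_key], io_dict[teacher_key])
    return total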
Also, you will need to tune the parameters in configs under samples/ as most of them were added for testing purposes.
By the way, I personally do not recommend using CRD as it's pretty time-consuming, and other methods could achieve better accuracy with a much lower training cost, as reported in my paper.
Many thanks for the explanation. I have read your paper carefully, and I think it would probably be better if there were a document explaining the keywords here, as I found only a few explanations in your paper.
As for the efficiency issue: I noticed that in your paper you mentioned the batch size for the augmented dataset is shrunk to fit into GPU memory, which costs efficiency for methods like CRD and SSKD. I just wonder, if I'd like to use their original settings, how can I set the config here properly? I looked carefully but didn't find a keyword corresponding to the batch size for the augmented samples.
Hi @shwinshaker
I have had documentation for torchdistill in mind for a long time, but couldn't make time for it during my PhD. So it is still a work in progress, but at least I did my best to make most of the parameters in the yaml files self-explanatory.
By augmented samples, if you mean the negative samples in CRD, then it's here: https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/crd-resnet18_from_resnet34.yaml#L67
For SSKD, the augmented samples are 3x as big as the batch size.
If you want to use a different batch size, you can just update batch_size: in the yaml config files:
https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/crd-resnet18_from_resnet34.yaml#L62
Note that this is the batch size per process. If you use distributed training mode (e.g., DDP instead of DP), the batch size in the yaml should be total_batch_size / num_distributed_processes, i.e., if the total batch size in some paper is 256 and you use 8 GPUs in parallel in distributed training mode, then you'd put batch_size: 32.
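For example, the arithmetic (assuming one process per GPU):

total_batch_size = 256          # effective batch size reported in a paper
num_distributed_processes = 8   # e.g., 8 GPUs, one process each
batch_size_per_process = total_batch_size // num_distributed_processes
print(batch_size_per_process)   # 32 -> put batch_size: 32 in the yaml config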
Personally, I do not think reverting the batch size will dramatically improve accuracy or help you reproduce the original numbers in their papers, since in both of those repos the scripts for ImageNet are not available and people complain that the results are not reproducible (some people, like me, even tried, but reported that they failed to reproduce the results).
Thanks for the very detailed explanation.
I very much have the same feeling that time is limited during a PhD, so I can understand that documentation is really hard work. But you are responding very fast, so I feel it is even better than a doc! Thanks again!
I am now experimenting with your code and it runs smoothly. Great work!
But I probably have one final question. I am trying just the plain kd, but I found there is a difference between the following configs.
For the teacher model, the first one sets wrapper = DistributedDataParallel while the second one sets wrapper = DataParallel. If I would like to set resnet34 as the teacher, which one should I use to maximize the training speed?
Hi @shwinshaker
Thank you for the words :)
> But I probably have one final question. I am trying just the plain kd, but I found there is a difference between the following configs. For the teacher model, the first one sets wrapper = DistributedDataParallel while the second one sets wrapper = DataParallel. If I would like to set resnet34 as the teacher, which one should I use to maximize the training speed?
To speed up training, I'd highly suggest using distributed training mode like this if you have multiple GPUs.
In general, if your teacher model has no updatable weights (i.e., it is completely frozen), then you should use DataParallel, because DistributedDataParallel (DDP) does not work with such models.
Even if you choose DDP in the yaml files, torchdistill internally replaces DistributedDataParallel with DataParallel if 1) your training is in non-distributed training mode or 2) your model has no updatable weights.
The first file was created a long time ago, and maybe I forgot to replace DistributedDataParallel with DataParallel for the teacher.
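Roughly, the wrapper selection described above amounts to something like this (a simplified sketch of the logic, not torchdistill's actual code):

from torch.nn.parallel import DataParallel, DistributedDataParallel

def wrap_model(model, device_ids, distributed):
    # DDP needs at least one parameter with requires_grad=True, so a fully
    # frozen teacher (or a non-distributed run) falls back to DataParallel.
    has_updatable_params = any(p.requires_grad for p in model.parameters())
    if distributed and has_updatable_params:
        return DistributedDataParallel(model, device_ids=device_ids)
    return DataParallel(model, device_ids=device_ids)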
@shwinshaker Could you close this issue and open a new Discussion if you still have questions?