
Comments (9)

shwinshaker commented on June 4, 2024

Hi @yoshitomo-matsubara,

I'm sorry I didn't notice this in the README. I will move to Discussions next time.

And thanks for the link. I just want to confirm that the following config means

FinalLoss = factor(2.0) * KDLoss + factor(0.8) * CRDLoss,
where
KDLoss = alpha * CrossEntropy(logits, label) + (1 - alpha) * KLDiv(logits, teacher_predictions).

Is my understanding correct?

criterion:
    type: 'GeneralizedCustomLoss'
    org_term:
      criterion:
        type: 'KDLoss'
        params:
          temperature: 4.0
          alpha: 0.5
          reduction: 'batchmean'
      factor: 2.0
    sub_terms:
      crd:
        criterion:
          type: 'CRDLoss'
          params:
            teacher_norm_module_path: 'normalizer'
            student_norm_module_path: 'normalizer'
            student_empty_module_path: 'empty'
            input_size: *feature_dim
            output_size: &num_samples 1281167
            num_negative_samples: *num_negative_samples
            num_samples: *num_samples
            temperature: 0.07
            momentum: 0.5
            eps: 0.0000001
        factor: 0.8
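
Spelled out as a rough sketch with dummy tensors (my own reading, not the actual torchdistill code; the real KDLoss may also scale the soft term, e.g. by temperature ** 2):

    import torch
    import torch.nn.functional as F

    # Dummy tensors standing in for one batch of model outputs
    alpha, temperature = 0.5, 4.0
    student_logits = torch.randn(8, 10)
    teacher_logits = torch.randn(8, 10)
    labels = torch.randint(0, 10, (8,))

    # org_term: KDLoss = hard-label cross entropy blended with soft-label KL divergence
    hard_loss = F.cross_entropy(student_logits, labels)
    soft_loss = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),
                         F.softmax(teacher_logits / temperature, dim=1),
                         reduction='batchmean')
    kd_loss = alpha * hard_loss + (1 - alpha) * soft_loss

    # sub_term: CRDLoss would be computed from intermediate embeddings; placeholder here
    crd_loss = torch.tensor(0.0)

    final_loss = 2.0 * kd_loss + 0.8 * crd_loss  # org_term factor and sub_term factor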


shwinshaker commented on June 4, 2024

Yeah sure, thanks again


yoshitomo-matsubara commented on June 4, 2024

Hi @shwinshaker

For questions, please use the Discussions tab above. As explained in the README, I want to keep Issues mainly for bug reports.

For CRD + KD, you can refer to this part
https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/sample/ilsvrc2012/single_stage/crd/resnet18_from_resnet50.yaml#L128-L151

If I remember correctly, the CRD paper doesn't clearly describe the implementation of CRD+KD (e.g., is it multi-stage? If it's single-stage and just a linear combination, how did they weight each loss term?), so I assume it's single-stage training with a linear combination of the CRD and KD losses.

You can refer to more examples available in sample configs such as https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/sample/ilsvrc2012/single_stage/ghnd/custom_resnet152_from_resnet152.yaml#L105-L152


yoshitomo-matsubara commented on June 4, 2024

Hi @shwinshaker

Yes, that is correct. Note that org_term uses the outputs of the original (teacher, student) models, while sub_terms does not necessarily (e.g., loss terms defined under sub_terms usually use intermediate features of the models).

Also, you will need to tune the parameters in the configs under samples/, as most of them were added for testing purposes.

By the way, I personally do not recommend using CRD, as it's pretty time-consuming and other methods can achieve better accuracy with much lower training cost, as reported in my paper.


shwinshaker commented on June 4, 2024

Hi @yoshitomo-matsubara,

Many thanks for the explanation. I have read your paper carefully, and I think it would be better if there were a document explaining the keywords used here, as I found only a few of them explained in the paper.

As for the efficiency issue: I noticed that in your paper you mention the batch size for the augmented dataset is shrunk to fit into GPU memory, which hurts efficiency for methods like CRD and SSKD. If I'd like to use their original settings, how can I set the config properly? I looked carefully but couldn't find a keyword corresponding to the batch size for the augmented samples.


yoshitomo-matsubara commented on June 4, 2024

Hi @shwinshaker

I have had documentation for torchdistill in mind for a long time, but I couldn't make time for it during my PhD. So it is still a work in progress, but at least I did my best to make most of the parameters in the yaml configs self-explanatory.

If by augmented samples you mean the negative samples in CRD, then it's here: https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/crd-resnet18_from_resnet34.yaml#L67

For SSKD, the number of augmented samples is 3x the batch size.
If you want to use a different batch size, you can just update batch_size: in the yaml config files:
https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/crd-resnet18_from_resnet34.yaml#L62

Note that this is the batch size per process. If you use distributed training mode (e.g., DDP instead of DP), the batch size in the yaml should be total_batch_size / num_distributed_processes,
i.e., if the total batch size in some paper is 256 and you use 8 GPUs in parallel in distributed training mode, then you'd put batch_size: 32.
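
As a quick sanity check, the arithmetic looks like this (a hypothetical helper, not part of torchdistill):

    def per_process_batch_size(total_batch_size: int, num_processes: int) -> int:
        # Convert a paper's total batch size into the per-process value for the yaml
        # when running in distributed (DDP) training mode.
        assert total_batch_size % num_processes == 0, 'total batch size should split evenly'
        return total_batch_size // num_processes

    print(per_process_batch_size(256, 8))  # 32 -> batch_size: 32 in the yaml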

Personally, I do not think restoring the original batch size dramatically improves accuracy or helps you reproduce the numbers in their papers, since scripts for ImageNet are not available in either repo and people complain that the results are not reproducible (some people, myself included, tried but failed to reproduce the results).


shwinshaker commented on June 4, 2024

Hi @yoshitomo-matsubara,

Thanks for the very detailed explanation.
I very much share the feeling that time is limited during a PhD, so I understand that documentation is really hard work. But you respond so quickly that I feel it is even better than a doc! Thanks again!

I am now experimenting with your code and it runs smoothly. Great work!
But I have one final question. I am trying plain KD, but I found a difference between the following configs.
For the teacher model, the first one sets wrapper = DistributedDataParallel while the second one sets wrapper = DataParallel. If I would like to use resnet34 as the teacher, which one should I use to maximize training speed?


yoshitomo-matsubara commented on June 4, 2024

Hi @shwinshaker

Thank you for the words :)

But I have one final question. I am trying plain KD, but I found a difference between the following configs. For the teacher model, the first one sets wrapper = DistributedDataParallel while the second one sets wrapper = DataParallel. If I would like to use resnet34 as the teacher, which one should I use to maximize training speed?

To speed up training, I'd highly suggest using distributed training mode like this if you have multiple GPUs.

In general, if your teacher models have no updatable weights (i.e., completely frozen), then you should use DataParallel because DistributedDataParallel (or DDP) does not work with such models.

Even if you choose DDP in yaml files, torchdistill internally replaces DistributedDataParallel with DataParallel if 1) your training is in non-distributed training mode or 2) your model has no updatable weights.
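
Roughly, the wrapper selection behaves like this sketch (my simplification for illustration, not torchdistill's actual code; it assumes the model is already on the right device and, for DDP, that the process group is initialized):

    import torch
    from torch.nn.parallel import DataParallel, DistributedDataParallel

    def wrap_model(model: torch.nn.Module, distributed: bool) -> torch.nn.Module:
        # DDP refuses models without any trainable parameters, so a completely
        # frozen teacher falls back to DataParallel even if DDP was requested.
        has_updatable_weights = any(p.requires_grad for p in model.parameters())
        if distributed and has_updatable_weights:
            return DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])
        return DataParallel(model)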

The first file was created a long time ago, and maybe I forgot to replace DistributedDataParallel with DataParallel for the teacher.


yoshitomo-matsubara commented on June 4, 2024

@shwinshaker Could you close this issue and open a new Discussion if you still have questions?

