
georgecazenavette / mtt-distillation

378 stars · 9 watchers · 51 forks · 39.52 MB

Official code for our CVPR '22 paper "Dataset Distillation by Matching Training Trajectories"

Home Page: https://georgecazenavette.github.io/mtt-distillation/

License: Other

Python 100.00%
computer-vision machine-learning artificial-intelligence synthetic-data

mtt-distillation's People

Contributors

georgecazenavette, zhaoguangxiang


mtt-distillation's Issues

A question for the paper

I am very interested in your work, but I have a question: can you directly train a randomly initialized network on the synthetic dataset?
If 10-500 images can train a robust network, that would be incredible. Or do you have to use the raw dataset to help distill the images while training the network? Can you give me a direct answer?

about the hyperparameter lr_img for updating condensed samples

Hello, George! First of all, I must say that this is very nice work.

I have some doubts about the hyperparameter lr_img used for updating the condensed samples. The paper does not mention how to choose lr_img. Besides, I ran the experiment with 10 images per class on CIFAR-10 following Table 6 and only obtained 58.50% accuracy. Should I modify other hyperparameters?

Why does grand_loss gradually increase and become NaN?

Hello, George!
First of all, thank you for your wonderful work!
I used the parameters in your article (lr_img=1000, ipc=10, etc.) and initialized the synthetic images from sampled noise.
The network is ConvNetD3 and ZCA is not used. I ran synthetic training on the CIFAR-10 dataset, but the loss became NaN. After careful observation, I found that the pixel values of every synthesized image had become NaN. What is the reason for this? I would like to know more details. Thanks.
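A minimal debugging sketch for this situation (my own aid, not repo code; image_syn and grand_loss are the distill.py-style names assumed here): fail fast the moment anything goes non-finite, so you can see the iteration where the blow-up starts instead of silently continuing into NaNs.

```python
import torch

# My own debugging guard (not part of the repo): call once per iteration.
def check_finite(grand_loss: torch.Tensor, image_syn: torch.Tensor) -> None:
    if not torch.isfinite(grand_loss):
        raise RuntimeError("grand loss went non-finite; try a smaller lr_img or lr_lr")
    if not torch.isfinite(image_syn).all():
        raise RuntimeError("synthetic pixels went non-finite")
```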

Negative LR

Hi! Thank you for your great work.

When I was distilling with my own dataset, the loss became very large (iter = 0490) and the learning rate went negative.

Could you help me figure out what is happening here?
What hyperparameters should be adjusted in such a case?
Can we implement anything in code to prevent a negative LR? (See the sketch after the log below.)

Thank you!

Evaluate 5 random ConvNetD4, mean = 0.2429 std = 0.0080
-------------------------
[2022-08-14 00:29:04] iter = 0400, loss = 1.2390
[2022-08-14 00:29:12] iter = 0410, loss = 1.3564
[2022-08-14 00:29:19] iter = 0420, loss = 1.5845
[2022-08-14 00:29:27] iter = 0430, loss = 0.9945
[2022-08-14 00:29:35] iter = 0440, loss = 1.4876
[2022-08-14 00:29:43] iter = 0450, loss = 1.0734
[2022-08-14 00:29:51] iter = 0460, loss = 1.9312
[2022-08-14 00:29:58] iter = 0470, loss = 1.0497
[2022-08-14 00:30:06] iter = 0480, loss = 16.3134
[2022-08-14 00:30:14] iter = 0490, loss = 23.7197
-------------------------
Evaluation
model_train = ConvNetD4, model_eval = ConvNetD4, iteration = 500
DSA augmentation strategy: color_crop_cutout_flip_scale_rotate
DSA augmentation parameters:
{'aug_mode': 'S', 'prob_flip': 0.5, 'ratio_scale': 1.2, 'ratio_rotate': 15.0, 'ratio_crop_pad': 0.125, 'ratio_cutout': 0.5, 'ratio_noise': 0.05, 'brightness': 1.0, 'saturation': 2.0, 'contrast': 0.5, 'batchmode': False, 'latestseed': -1}
Traceback (most recent call last):
  File "/media/ntu/volume1/home/s121md302_06/workspace/code/mtt-distillation/distill.py", line 496, in <module>
    main(args)
  File "/media/ntu/volume1/home/s121md302_06/workspace/code/mtt-distillation/distill.py", line 227, in main
    _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args, texture=args.texture)
  File "/media/ntu/volume1/home/s121md302_06/workspace/code/mtt-distillation/utils.py", line 400, in evaluate_synset
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
  File "/media/ntu/volume1/home/s121md302_06/anaconda3/envs/distillation/lib/python3.9/site-packages/torch/optim/sgd.py", line 91, in __init__
    raise ValueError("Invalid learning rate: {}".format(lr))
ValueError: Invalid learning rate: -0.00048201243043877184
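One way to implement the guard asked about above is a hedged sketch of my own (not the repo's code; the names mirror the learnable syn_lr seen in distill.py): clamp the learned LR after each update, or reparameterize it so positivity holds by construction.

```python
import torch

# Hedged sketch (my own guard, not repo code): keep the learnable synthetic
# LR positive after each optimizer step.
syn_lr = torch.tensor(0.01, requires_grad=True)
optimizer_lr = torch.optim.SGD([syn_lr], lr=1e-5, momentum=0.5)

# ... compute the trajectory-matching loss and call loss.backward() ...
optimizer_lr.step()
with torch.no_grad():
    syn_lr.clamp_(min=1e-6)  # never let the evaluated LR go negative

# Alternative: learn log_syn_lr and use log_syn_lr.exp() in the matching
# loss, so the effective LR is positive by construction.
```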

Grand loss curve

Hi,

I tried to reproduce your method on ResNet18, and I set a proper lr (lr_img=100, lr_lr=1e-5, lr_teacher=0.01) to avoid exploding/vanishing gradients. However, I observed that the grand loss fluctuates around 0.9. Is this normal? Could you please share your grand loss curve for reference?

BR,
Xuyang

How did you get x̄ ± s in Table 1

Hi George,

Thanks for your great work, and sorry to bother you again.

I have another question regarding the accuracy values shown in Table 1. I assume there are two possible ways to get those numbers:

1. Train on the synthetic data for a certain number of steps (e.g., 9,000 steps), then measure accuracy on the test set.
2. Test on the test set every 100 steps of training on the synthetic dataset, then take the maximum accuracy.

The second way is not valid, since the test dataset should only be used once at the end.
So did you use the first method to get the accuracy? If so, how many steps did you take?

Thank you, and hope you have a great day!
Dai

What is ZCA

Hi, thanks for your rigorous experiment code on mtt-distillation.

But there is an undefined abbreviation that I can't figure out. I see an argument "--zca", and the dataset handling is split into two cases depending on whether ZCA is used. What is ZCA?
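For context: ZCA ("zero-phase component analysis") whitening decorrelates the pixels of the training set while keeping the whitened images as close as possible to the originals; --zca applies this preprocessing before distillation. A minimal sketch of the transform (my own illustration; the repo's implementation and its eps may differ):

```python
import torch

def zca_whiten(images: torch.Tensor, eps: float = 0.1) -> torch.Tensor:
    """ZCA-whiten a batch of images of shape (N, C, H, W)."""
    n = images.shape[0]
    flat = images.reshape(n, -1)
    flat = flat - flat.mean(dim=0, keepdim=True)   # center each pixel feature
    cov = flat.T @ flat / (n - 1)                  # (D, D) pixel covariance
    U, S, _ = torch.linalg.svd(cov)                # cov = U diag(S) U^T
    W = U @ torch.diag(1.0 / torch.sqrt(S + eps)) @ U.T   # ZCA matrix
    return (flat @ W.T).reshape(images.shape)
```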

Experience on hyper-parameters

Dear author,

Thank you for your great solution for dataset distillation! Recently I have been working on my own datasets, but I find that the performance is somewhat sensitive to the hyper-parameters. Could you please provide some insights on how to choose hyper-parameters like syn_steps, expert_epochs, max_start_epoch, the learning rates, etc.? Thanks in advance!

Trouble distilling with VGG networks

Hi,

I encountered a bloating Synthetic-LR and a zero grand loss when using VGG models.

  • including VGG11, VGG13, VGG16

Similar issues have been reported before.

But my experiments are fine with ConvNet and ResNet18, using scripts similar to those given below.

Here is a snippet of the scripts:

SCRIPT_NAME=VGG13
MODEL=VGG13
DATASET=CIFAR10
IMAGE_PER_CLASS=1

python buffer.py --dataset=$DATASET --model=$MODEL --train_epochs=50 --num_experts=100 --buffer_path=$BUFFER_PATH --data_path=$DATA_PATH >> ./results/buffer_$SCRIPT_NAME.txt

python distill.py --dataset=$DATASET --model=$MODEL --ipc=$IMAGE_PER_CLASS --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path=$BUFFER_PATH --data_path=$DATA_PATH >> ./results/distill_$SCRIPT_NAME.txt
  • VGG11 uses CIFAR100 and --ipc=10
  • none of the models use --zca whitening

[screenshots: Synthetic-LR and grand loss curves]

The Synthetic-LR goes either extremely positive or extremely negative at the very beginning.

Thank you.

GPU requirement

Thanks for your great work on distilled datasets. I was wondering about your hardware setup for CIFAR-100, Tiny ImageNet, and ImageNet (subsets). How long did it take to produce your results (generating the experts and the distillation step)?

Reproduce cross-architecture performance

Hi George,
Thanks for your inspiring and great work.

I would like to reproduce the cross-architecture accuracy, but I'm having difficulty reaching an accuracy comparable to the one listed in the paper. I think I might be missing some details. Could you please type out the command you used to produce the cross-architecture performance on CIFAR-10 with 10 img/cls?

Here is the command I used:
First step: python buffer.py --dataset=CIFAR10 --model=ConvNet --train_epochs=50 --num_experts=100 --zca --buffer_path=buffer --data_path=data
Second step: python3 distill.py --dataset=CIFAR10 --ipc=10 --syn_steps=30 --expert_epochs=2 --max_start_epoch=15 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path=buffer --data_path=data --eval_mode='M' --eval_it=1 --Iteration=300

Is there something I'm missing here?
In addition, did you change the parser argument epoch_eval_train when you produced the SOTA cross-architecture results?

Thank you!
Looking forward to your reply!

how to do distillation for a model other than VGG?

Hi & good morning!

I have two models and would like to see what the distilled images look like if I use them. So, let's say models A and B. Is it possible to use A and B to generate and save 100 distilled CIFAR-10 images? If so, what command should I run?

Thank you very much

values for max_start_epoch

Hi there, I can see that max_start_epoch is set to 20. However, during the generation of the expert trajectories, train_epochs is 50. This means that during distillation we don't use most of the saved checkpoints (beyond epoch 20+3). My questions:

  1. Is there any reason to choose max_start_epoch as 20 rather than 50?
  2. Can we set train_epochs to a lower value to reduce training time?

how to use images?

Hello, I want to know how to use the distilled images. I used them to train a new network, but the accuracy was terrible (10% on CIFAR-10). So, can these images be used to train a new network? If not, what is the meaning of these images? If they can, could you share the network architecture you used?

details about the full real subsets of ImageNet

Hi,

I am interested in your work, and you use special datasets called subsets of ImageNet. From the scarce information in the paper and this repo, I know that in each ImageNet subset:

  • the number of classes is 10
  • the resolution is 112x112

But what is the number of images in each class?

Running this project with PCAM instead of CIFAR

Hello, I ran this project with CIFAR and everything works as it should. I am new to machine learning and am not sure what to change to run this code with PCAM instead of CIFAR. I need this change for my university project. Could you suggest what I should change to run this with PCAM?

A question about backbone networks

Hi, I've taken great interest in your work and am trying to experiment in various environments.

From Table 1 in the paper, the ConvNet used as a baseline only reaches a maximum of 56.2% accuracy even when trained on the full CIFAR-100 training set, which is considerably lower than SOTA classification models with higher than 90% accuracy.

Since the performance of the baseline model, i.e., of the expert trajectories trained on the full dataset, serves as an upper bound on the performance of the student network trained on the synthetic dataset, I was wondering if you ever experimented with more complex networks like WideResNet-50 for training the expert networks. If you haven't, do you have any naive guesses as to what the outcome would be?

Thanks a bunch.

About Hyper-parameters

Hi, sorry to bother you.
I really appreciate this wonderful work, but I'm wondering how you chose the hyper-parameters for distill.py,
e.g., lr_img, lr_lr, lr_teacher, max_start_epoch, and so on.
Did you use a grid search over the hyper-parameters? If so, how did you implement it? Since I notice there are a number of hyper-parameters to set, I think it would take quite a lot of time to find a good setting.
Thanks!
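For what it's worth, a naive grid search can be scripted as an outer driver that launches one distill.py run per configuration. A hypothetical sketch (my own; only the flag names come from the repo's documented commands, and the value ranges are illustrative):

```python
import itertools
import subprocess

# Hypothetical grid-search driver (not part of the repo): one run per combo.
grid = {
    "--lr_img": [100, 1000, 10000],
    "--lr_lr": [1e-7, 1e-5],
    "--lr_teacher": [0.01, 0.03],
}
keys = list(grid)
for combo in itertools.product(*grid.values()):
    flags = [f"{k}={v}" for k, v in zip(keys, combo)]
    subprocess.run(["python", "distill.py", "--dataset=CIFAR10", "--ipc=10",
                    *flags], check=True)
```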

The clip value

Hi thanks for your great work!
I am curious about clip_val. Why did you choose 2.5? Why is clipping needed? Could you please explain a little? Thanks!
And when training with the distilled data, we don't need clipping, right?

```python
for clip_val in [2.5]:
    std = torch.std(images_train)
    mean = torch.mean(images_train)
    # clip every pixel to within clip_val standard deviations of the mean
    upsampled = torch.clip(images_train,
                           min=mean - clip_val * std,
                           max=mean + clip_val * std)
```

how does it work on a new model arch?

Hi, have a question:

"learn a small number of synthetic images such that a model trained on this set alone will have similar test performance as a model trained on the full real dataset."

If I train a totally different model architecture on the distilled dataset, will it also work for any architecture? Will the performance still be the same as training on the original dataset? How?

Model training with the released synthetic CIFAR-10 (ipc=50)

Hi George,

I attempted to directly use the released synthetic dataset of CIFAR-10 (ipc=50) to train the default ConvNet, but the final test accuracy is only 55.67% on the original test set of CIFAR-10. The training details are as follows:

  • I did not apply any transformation to the synthetic images (btw, the CIFAR-10 ipc=50 tensors do not seem to require ZCA preprocessing). The dataset was downloaded from https://georgecazenavette.github.io/mtt-distillation/tensors/index.html#tensors/cifar10_50
    • Augmentation is applied to the synthetic images to train the network, the same as in the augment() function
  • architecture: default ConvNet
  • test set preprocessing: transforms.ToTensor() and transforms.Normalize(mean, std)
  • learning rate: initial lr=0.01, decayed to 0.001 at the 500th epoch
  • number of training epochs: 1000
  • optimizer: SGD with momentum=0.9 and weight_decay=5e-4
  • loss function: cross entropy
  • GPU: NVIDIA TITAN V

Am I missing other implementation details? (A minimal sketch of this recipe is included below.)
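A minimal sketch of the recipe described above (my own code: the tiny model is a stand-in, NOT the repo's ConvNet, and the tensor filenames are hypothetical):

```python
import torch
import torch.nn as nn

images = torch.load("cifar10_50/images_best.pt")          # expected (500, 3, 32, 32)
labels = torch.load("cifar10_50/labels_best.pt").long()   # expected (500,)

net = nn.Sequential(                  # stand-in model, not the repo's ConvNet
    nn.Conv2d(3, 128, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(128, 10))
opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[500], gamma=0.1)

for epoch in range(1000):             # full-batch: only 500 synthetic images
    opt.zero_grad()
    loss = nn.functional.cross_entropy(net(images), labels)  # DSA aug omitted
    loss.backward()
    opt.step()
    sched.step()
```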

Best,
Shiye

question about learning rate

Hi, sorry to bother you.
I want to know why the lr_img learning rate is set to 1000. How did you determine 1000?
Usually the learning rate would be 0.1, 0.01, or 0.001.

And if I just change lr_img to 100 or a smaller value, the loss becomes NaN.

Can you tell me how this works and give some advice about setting the hyperparameters? (I know the hyperparameters given now can reproduce the results in the paper, but I want to use other networks to distill datasets, and I think the hyperparameters must be modified when using different networks.)

ReparamModule Usage

Hi George,

Thanks for your amazing work!

I found a module called 'ReparamModule' in your code, without much explanation. It seems to flatten and unflatten parameters. I think it's a supporting module for training the student models without using 'loss.backward()'. Could you give more explanation of this part?

Can I use this ReparamModule in language models like Transformers or BERT? Or which parts would I need to change? I am a new grad student in NLP and RS; your answer would be a great help.
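A minimal sketch of the idea (my own illustration, not the repo's implementation): supply the module's weights as one flat tensor, so the student parameters produced by each inner SGD step remain non-leaf nodes of the autograd graph instead of fixed nn.Parameters.

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # PyTorch 2.x

def flatten_params(model: nn.Module) -> torch.Tensor:
    # one flat vector holding every parameter of the model
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])

def forward_with_flat(model: nn.Module, flat: torch.Tensor, x: torch.Tensor):
    # slice the flat vector back into named tensors, keeping the graph intact
    params, offset = {}, 0
    for name, p in model.named_parameters():
        n = p.numel()
        params[name] = flat[offset:offset + n].view_as(p)
        offset += n
    return functional_call(model, params, (x,))
```

In principle this works for any nn.Module whose forward pass only touches parameters (buffers such as BatchNorm running stats would need the same treatment), so a Transformer-style model is not excluded by the mechanism itself.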

Much Thanks.

Expert trajectory performance

Thanks for your work! I've got a question. When training the expert trajectories on CIFAR-10 according to buffer.py, I only got test accuracies around 0.79 and 0.77 w/o --zca after 50 epochs. However, Table 1 in your paper reports that the full dataset can reach 0.84 accuracy on CIFAR-10. Is there any mistake I've made here?

N << M or N >> M

Hi,

I notice that in the paper's trajectory-matching figure, you note that N << M.
But in the code, the hyperparameters you use mostly have N >> M: N is often 20 or 30, while M is often 2 or 3.
Is this a typo, or have I misunderstood?

Sincerely,
Haowen Guan

Normalization of dataset

Hi @GeorgeCazenavette

Hope all is well. I'm trying to load images via images = torch.load("imagefruit/images_best.pt"). However, when I try to plot them, I get the error:

WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

How can I normalize the data? I tried min-max normalization, but the outputs are opaque and not similar to the ones that you show here.
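A visualization-only rescaling sketch (my own code, assuming the saved tensors live in normalized/ZCA space rather than [0, 1]): clip to a few standard deviations first, as the clip_val snippet quoted earlier in these issues does, so outlier pixels don't wash out the contrast, then map to [0, 1] for imshow.

```python
import torch

images = torch.load("imagefruit/images_best.pt")   # (N, 3, H, W), not in [0, 1]
m, s = images.mean().item(), images.std().item()
clipped = images.clamp(m - 2.5 * s, m + 2.5 * s)   # drop extreme outliers
vis = (clipped - clipped.min()) / (clipped.max() - clipped.min() + 1e-8)
# vis is now in [0, 1]; imshow will no longer warn about clipping
```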

Thanks

CIFAR 10 ipc 10 Hyperparameters

Hi George,

Thanks for your amazing work!

I am testing your code in the CIFAR-10 ipc=10 setting. I adopted the hyperparameters listed in the readme page. However, a gradient explosion occurs on the parameter 'syn_lr'. By changing the (pixel) learning rate to 100 or 1000, I can get a stable result. I wonder if there is a typo in the table.

Your answer would be really helpful!

Much Thanks.

Where did you get the acc 36.1% from the paper Dataset Distillation with Infinitely Wide Convolutional Networks

Thanks for your great idea and detailed work, and I hope you are enjoying your day so far.

I have a question regarding your paper "Dataset Distillation by Matching Training Trajectories". In the third sentence from the bottom of the Introduction, you state that you beat the SOTA "Dataset Distillation with Infinitely Wide Convolutional Networks" at its accuracy of 36.1%/46.5%. However, the accuracy stated in that paper is actually 64.7%/80.6%.

Is that a small mistake? If it's not, could you help me locate where in the paper you found that accuracy?
Thank you and best regards!

Checkpoints of models

Good day @GeorgeCazenavette. I was reading this "Dataset Distillation by Matching Training Trajectories" paper and a more recent paper of yours, and looked into both repos. I was wondering if by any chance you have stored the checkpoints of models trained from scratch on the whole dataset, like the Table 8 model checkpoints from the GLaD paper?
Thanks in advance.
Flora

Unrolled optimization

Hi!

Do I understand correctly that the grand loss at the end backprops through grad of grad of grad, e.g., not double backward but 20th-order backward?

I.e., student_params[5] depends on student_params[4] and on grad(loss(target; student_params[4])), and the same holds further back, so the computation graph contains a path that goes through all 5 gradient computations.
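That is the structure of unrolled optimization. A minimal sketch of it (my own illustration; forward_fn stands in for a flat-parameter forward like ReparamModule's): because each inner gradient is taken with create_graph=True, the final matching loss backpropagates through every one of the N inner steps.

```python
import torch
import torch.nn.functional as F

def unroll(theta0, syn_x, syn_y, syn_lr, forward_fn, n_steps):
    theta = theta0
    for _ in range(n_steps):
        inner_loss = F.cross_entropy(forward_fn(theta, syn_x), syn_y)
        g, = torch.autograd.grad(inner_loss, theta, create_graph=True)
        theta = theta - syn_lr * g   # new graph node; keeps the higher-order path
    return theta  # a loss on this final theta sees all n_steps gradient ops
```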

args.mom in buffer.py

Hi, I found there is a momentum term in teacher_optim named args.mom in buffer.py, but I didn't find its value. Could you please provide it?
Many thanks.

distill.py loss = nan

Hello, author. Thank you for your work!
When running distill.py, the loss is always NaN. What parameters do you suggest adjusting? Or did I overlook something that caused the error?
In addition: I use my own dataset. The experimental settings and dataset settings are shown in the figures below.
[screenshots: experimental settings and dataset settings]
