TL;DR: Simply run:
bash run.sh

- Create the conda environment:
conda env create -f dd.yml
- Adjust the training settings in:
diffusion_distillation/config/mnist_base.py
- Run the notebook:
diffusion_distillation.ipynb
The output checkpoint will be saved in the /tmp/flax_ckpt/checkpoint/ directory.
- Copy the checkpoint to a safe folder, since /tmp is typically cleared on system reboot.
- Adjust the sampling settings in:
diffusion_distillation/config/mnist_distill.py
- Run the following command to generate images:
python sample_origin.py --num_imgs 1024 --batchsize 64 --startbatch 0 --db_path data/mnist_origin_debug --ckpt_path /path/to/ckpt
db_path: the output dataset, which will be used in DSNO-pytorch.
num_imgs: the number of images to generate.
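For the checkpoint-copy step, here is a minimal standard-library sketch. The source is the /tmp/flax_ckpt/checkpoint/ directory named above. The helper name backup_checkpoint and the destination checkpoints/mnist_base are hypothetical examples, not part of this repository. It assumes Python 3.8+ for the dirs_exist_ok argument.

```python
import shutil
from pathlib import Path


def backup_checkpoint(src: str, dst: str) -> Path:
    """Copy a checkpoint directory (e.g. /tmp/flax_ckpt/checkpoint/) to a persistent folder.

    Hypothetical helper, not part of this repository.
    """
    src_path = Path(src)
    if not src_path.is_dir():
        raise FileNotFoundError(f"checkpoint directory not found: {src_path}")
    dst_path = Path(dst)
    # dirs_exist_ok=True (Python 3.8+) allows copying into an existing destination,
    # overwriting files that have the same names, so the backup can be rerun.
    shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
    return dst_path


# Example (the destination is a hypothetical path; choose any folder outside /tmp):
# backup_checkpoint("/tmp/flax_ckpt/checkpoint/", "checkpoints/mnist_base")
```

A plain `cp -r` works just as well. The Python version adds a clear error when the checkpoint directory is missing, for example if the machine was rebooted before the copy.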
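In the sampling command, the numeric flags interact through simple batch arithmetic. This README does not confirm the details, so the sketch below makes two assumptions. First, sample_origin.py splits num_imgs into ceil(num_imgs / batchsize) batches. Second, startbatch is the index of the first batch to generate, for example when resuming an interrupted run. batch_plan is a hypothetical helper for reasoning about these flags, not code from the repository.

```python
import math
from typing import List, Tuple


def batch_plan(num_imgs: int, batchsize: int, startbatch: int = 0) -> List[Tuple[int, int]]:
    """Return (batch_index, images_in_batch) pairs.

    Hypothetical helper: assumes num_imgs is split into ceil(num_imgs / batchsize)
    batches and that generation starts at batch index `startbatch`.
    """
    if num_imgs <= 0 or batchsize <= 0:
        raise ValueError("num_imgs and batchsize must be positive")
    num_batches = math.ceil(num_imgs / batchsize)
    plan = []
    for b in range(startbatch, num_batches):
        # The final batch may be smaller when num_imgs is not a multiple of batchsize.
        plan.append((b, min(batchsize, num_imgs - b * batchsize)))
    return plan
```

With the values from the command (1024 images, batch size 64), batch_plan(1024, 64, 0) yields 16 batches of 64 images. With num_imgs=1000, the last entry would be (15, 40). With startbatch=10, only batches 10 to 15 would remain.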