
dsno-pytorch

Fast sampling of diffusion models via operator learning

TL;DR:

Simply run

bash run.sh

Env

conda env create -f new.yml

Train

  1. Adjust the training config configs/mnist10-dsno-t4.yaml: change datapath to point at your data and make sure it ends with /lmdb (generated with diffusion_distillation).
  2. Run: python3 train_mnist.py --config configs/mnist10-dsno-t4.yaml --num_gpus 1
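A wrong datapath is the easiest way to break step 1, so a small pre-flight check can help. The sketch below assumes the YAML has already been parsed into nested dicts (for example with PyYAML's yaml.safe_load). Where datapath sits inside the config is not documented here, so the sketch searches every level. The function names are hypothetical and not part of this repo.

```python
from typing import Any, List, Tuple


def find_datapaths(cfg: Any, prefix: str = "") -> List[Tuple[str, str]]:
    """Recursively collect every (dotted_key, value) pair whose key is 'datapath'.

    Assumption: the config is nested dicts/lists; the exact location of
    'datapath' is unknown, so every level is searched.
    """
    found: List[Tuple[str, str]] = []
    if isinstance(cfg, dict):
        for key, value in cfg.items():
            dotted = f"{prefix}.{key}" if prefix else str(key)
            if key == "datapath" and isinstance(value, str):
                found.append((dotted, value))
            else:
                found.extend(find_datapaths(value, dotted))
    elif isinstance(cfg, list):
        for i, item in enumerate(cfg):
            found.extend(find_datapaths(item, f"{prefix}[{i}]"))
    return found


def check_lmdb_datapath(cfg: dict) -> None:
    """Raise ValueError unless every 'datapath' entry ends with '/lmdb'."""
    paths = find_datapaths(cfg)
    if not paths:
        raise ValueError("no 'datapath' key found in config")
    for dotted, value in paths:
        # Tolerate a trailing slash, e.g. '/data/mnist/lmdb/'.
        if not value.rstrip("/").endswith("/lmdb"):
            raise ValueError(f"{dotted}={value!r} should end with /lmdb")
```

Usage, assuming PyYAML is installed: check_lmdb_datapath(yaml.safe_load(open("configs/mnist10-dsno-t4.yaml"))) raises before training starts if the path is wrong.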

Sampling

  1. Run: python3 generate_mnist.py --config configs/mnist10-dsno-t4.yaml --ckpt /path/to/ckpt --num_imgs 10

Code for Preprint


Requirements

Docker:

  • Ensure that NVIDIA container runtime is installed.
  • Use the Dockerfile in this repo.

Generated trajectory data

Results

                CIFAR10-original data   CIFAR10-DSNO   ImageNet64-original data   ImageNet64-DSNO
FID             2.51                    3.78           2.70                       7.83
Recall          -                       -              -                          6.119

Train

Unconditional generation model. The following run gave the best FID of 3.78 on CIFAR10 within 400k iterations (72 hours on 8 A100 GPUs).

python3 train_cifar.py --config configs/cifar10-dsno-t4.yaml --num_gpus 8

Conditional generation model. The experiments on ImageNet64 require multi-node training (6-7 days on 64 V100 GPUs with mixed precision training). The following script is only an example for a single node with mixed precision.

python3 train_imagenet.py --config configs/imagenet64-dsno-t4.yaml --num_gpus_per_node 8 --amp

PS: to turn on wandb logging, add --log and set the log.entity key in the corresponding YAML file.
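The training budgets quoted above can be turned into GPU-hours and an average step rate with a few lines of arithmetic. This back-of-the-envelope sketch assumes the entire wall-clock time went into training steps (evaluation and checkpointing are ignored); the helper names are illustrative.

```python
def gpu_hours(num_gpus: int, hours: float) -> float:
    """Total accelerator-hours consumed by a run."""
    return num_gpus * hours


def iters_per_second(iterations: int, hours: float) -> float:
    """Average optimizer steps per second of wall-clock time."""
    return iterations / (hours * 3600.0)


# CIFAR10: 400k iterations in 72 hours on 8 A100 GPUs (figures from this README).
cifar_gpu_hours = gpu_hours(8, 72)            # 576 A100-hours
cifar_rate = iters_per_second(400_000, 72)    # about 1.54 it/s

# ImageNet64: 6-7 days on 64 V100 GPUs (figures from this README).
imagenet_low = gpu_hours(64, 6 * 24)          # 9216 V100-hours
imagenet_high = gpu_hours(64, 7 * 24)         # 10752 V100-hours

print(f"CIFAR10: {cifar_gpu_hours:.0f} A100-hours, {cifar_rate:.2f} it/s")
print(f"ImageNet64: {imagenet_low:.0f}-{imagenet_high:.0f} V100-hours")
```

Under an optimistic linear-scaling assumption, the single-node (8 GPU) ImageNet64 example would need roughly 8x the wall-clock time of the 64-GPU run, i.e. on the order of 48-56 days; real scaling is rarely exactly linear.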

Evaluation

We use EDM's evaluation code to report FID and ADM's evaluation code to report Recall.
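For readers unfamiliar with the two metrics, here is a minimal numpy sketch of what they compute. It is not a replacement for EDM's or ADM's evaluators: feature extraction (e.g. Inception embeddings) is omitted and the inputs are assumed to be precomputed feature statistics or arrays. FID is the Fréchet distance between Gaussians fitted to real and generated features; recall follows the k-nearest-neighbour formulation of improved precision and recall that ADM's evaluator is built on. The choice k = 3 and all function names here are assumptions.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})."""
    mu1, mu2 = np.asarray(mu1, float), np.asarray(mu2, float)
    sigma1, sigma2 = np.asarray(sigma1, float), np.asarray(sigma2, float)
    # Symmetric square root of sigma1 via its eigendecomposition.
    w, v = np.linalg.eigh(sigma1)
    sqrt_s1 = (v * np.sqrt(np.clip(w, 0, None))) @ v.T
    # Tr((s1 s2)^{1/2}) equals Tr((s1^{1/2} s2 s1^{1/2})^{1/2}); the inner
    # matrix is symmetric PSD, so eigvalsh gives real, non-negative eigenvalues.
    inner = sqrt_s1 @ sigma2 @ sqrt_s1
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def knn_radii(feats: np.ndarray, k: int = 3) -> np.ndarray:
    """Distance from each point to its k-th nearest neighbour (needs more than k points)."""
    d = np.linalg.norm(feats[:, None, :] - feats[None, :, :], axis=-1)
    # After sorting, column 0 is the zero self-distance, so column k is the k-th neighbour.
    return np.sort(d, axis=1)[:, k]


def recall(real: np.ndarray, fake: np.ndarray, k: int = 3) -> float:
    """Fraction of real features that fall inside at least one generated sample's k-NN ball."""
    radii = knn_radii(fake, k)
    d = np.linalg.norm(real[:, None, :] - fake[None, :, :], axis=-1)  # (n_real, n_fake)
    return float((d <= radii[None, :]).any(axis=1).mean())
```

The pairwise-distance matrices make this sketch quadratic in memory, so it only suits small feature sets; the official evaluators batch these computations.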

Example of class-conditional generation:

python3 generate_imagenet.py --config configs/imagenet64-dsno-t4.yaml --ckpt [path to checkpoint]

Code structure

│   Dockerfile
│   README.md
│   train_cifar.py
│   train_imagenet.py
│   generate_imagenet.py          # example of class-conditional generation
│   generate_cifar.py             # example of unconditional generation
│   
├───configs
│       cifar10-dsno-t4.yaml      # example of configuration file for unconditional dsno on cifar10
│       imagenet64-dsno-t4.yaml   # example of configuration file for class-conditional dsno on ImageNet64
│       
├───models
│       layers.py
│       layersmt.py
│       tddpmm.py       # architecture of dsno
│       up_or_down_sampling.py
│       utils.py
│
└───utils
        dataset.py
        data_helper.py
        distributed.py
        helper.py
        loss.py
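The project title frames DSNO as operator learning. A building block commonly used in neural operators (for example the Fourier Neural Operator) is a convolution applied in Fourier space. The sketch below shows that generic operation along a time axis in numpy, purely as an illustration: it is not the code in models/tddpmm.py, and the function name, shapes, and the `modes` parameter are hypothetical.

```python
import numpy as np


def spectral_conv_1d(x: np.ndarray, weights: np.ndarray, modes: int) -> np.ndarray:
    """Generic FNO-style spectral convolution along the first (time) axis.

    x:       real array of shape (T, C_in), e.g. features at T time steps.
    weights: complex array of shape (modes, C_in, C_out), one channel-mixing
             matrix per retained frequency.
    modes:   number of lowest Fourier modes kept; higher modes are zeroed.
    """
    T = x.shape[0]
    x_ft = np.fft.rfft(x, axis=0)                      # (T // 2 + 1, C_in)
    modes = min(modes, x_ft.shape[0])
    out_ft = np.zeros((x_ft.shape[0], weights.shape[2]), dtype=complex)
    # Mix channels independently at each retained frequency.
    out_ft[:modes] = np.einsum("mi,mio->mo", x_ft[:modes], weights[:modes])
    return np.fft.irfft(out_ft, n=T, axis=0)           # back to real (T, C_out)
```

With identity weights and all T // 2 + 1 modes kept, the layer returns its input unchanged; keeping only mode 0 reduces it to the per-channel mean over time, which shows how truncating modes acts as a low-pass filter.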

Contributors

javazeroo, zongyi-li, devzhk