
nuistji / improved-ddpm-pytorch

This project forked from vedantroy/improved-ddpm-pytorch

Implementation of "Improved Denoising Diffusion Probabilistic Models" in PyTorch

License: MIT License

Shell 0.10% Python 55.56% Jupyter Notebook 44.34%

Improved Denoising Diffusion Probabilistic Models

This is my implementation of Improved Denoising Diffusion Probabilistic Models. It also includes an implementation of Denoising Diffusion Implicit Models (DDIM).

Install

Create the conda environment from environment.yml, then install the pip dependencies:

pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu116

The pip step is separate because conda's environment.yml cannot pass extra arguments (such as --extra-index-url) to pip.

Training Statistics

Commit 052985162b85ebf28903c79dc781703507ae99e3

Device      Batch size   Batches/s (ba/s)   Samples/s
A100-80GB   302          1.00               300
A100-80GB   128          2.12               271.36
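The samples/s figures follow from multiplying the batch size by the batches per second (ba/s). A minimal sketch of that arithmetic, where the function name `samples_per_second` is made up for illustration and is not part of this repository:

```python
def samples_per_second(batch_size: int, batches_per_second: float) -> float:
    """Throughput in samples/s: batch size times batches per second."""
    return batch_size * batches_per_second


# Rows from the training statistics: (device, batch size, ba/s)
rows = [("A100-80GB", 302, 1.00), ("A100-80GB", 128, 2.12)]

for device, batch_size, ba_per_s in rows:
    throughput = round(samples_per_second(batch_size, ba_per_s), 2)
    print(f"{device} batch={batch_size}: {throughput} samples/s")
```

The second row reproduces the reported 271.36. The first row comes out at 302, so the reported 300 is presumably rounded.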

Samples

Trained on CelebA-HQ for 36,000 batches with a batch size of 302, reaching a loss of 0.02573 with the $L_\text{simple}$ objective.

Features

Implemented:

  • $L_\text{simple}$ objective
  • Cosine schedule
  • Training + Generating
  • $L_\text{hybrid}$ objective / learned variance
    • TODO: Verify sample quality
  • Faster sampling
  • Sampling from DDIM
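As a rough illustration of two of the items above, here is a dependency-free sketch of the cosine noise schedule from the Improved DDPM paper and of picking an evenly spaced subset of timesteps for faster sampling. It is not the code in this repository: the function names are hypothetical, the offset `s = 0.008` and the clip at `0.999` are the paper's values, and the 0-indexed timestep convention is an assumption.

```python
import math
from typing import List


def alpha_bar(t: float, T: int, s: float = 0.008) -> float:
    """Unnormalized cumulative signal level f(t) = cos^2(((t/T + s) / (1 + s)) * pi/2)."""
    return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2


def cosine_betas(T: int, s: float = 0.008, max_beta: float = 0.999) -> List[float]:
    """Per-step noise variances beta_t = 1 - f(t) / f(t-1), clipped at max_beta.

    The normalization by f(0) cancels in the ratio, so it is omitted.
    """
    betas = []
    for t in range(1, T + 1):
        ratio = alpha_bar(t, T, s) / alpha_bar(t - 1, T, s)
        betas.append(min(1.0 - ratio, max_beta))
    return betas


def strided_timesteps(T: int, K: int) -> List[int]:
    """K roughly evenly spaced timesteps from 0 to T-1 (assumes 2 <= K <= T)."""
    if not 2 <= K <= T:
        raise ValueError("K must satisfy 2 <= K <= T")
    return [round(i * (T - 1) / (K - 1)) for i in range(K)]


betas = cosine_betas(1000)
steps = strided_timesteps(1000, 50)
print(len(betas), len(steps), steps[0], steps[-1])
```

Sampling then walks through `steps` in reverse instead of visiting all `T` timesteps. The clip keeps the final steps, where the ratio approaches zero, from producing a variance of 1.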

Unplanned:

  • $L_\text{vlb}$ objective with loss-aware sampler
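For reference, the three objectives named above relate as follows in the Improved DDPM paper. These are the paper's definitions, not a description of this repository's code, and the value of $\lambda$ used here is not stated in this README:

$$L_\text{simple} = \mathbb{E}_{t, x_0, \epsilon}\left[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\right]$$

$$L_\text{hybrid} = L_\text{simple} + \lambda L_\text{vlb}, \quad \lambda = 0.001 \text{ in the paper}$$

$$\Sigma_\theta(x_t, t) = \exp\left(v \log \beta_t + (1 - v) \log \tilde{\beta}_t\right)$$

Here $v$ is the extra model output that interpolates, in the log domain, between the upper bound $\beta_t$ and the lower bound $\tilde{\beta}_t$ on the reverse-process variance. $L_\text{simple}$ gives no training signal for $v$, which is why the learned variance is paired with $L_\text{hybrid}$.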

Repository Guide

This repository is super messy. It has a lot of scratch files from my attempts at figuring things out and debugging. Nevertheless, here's the overview:

  • openai-ddm/ contains a forked version of improved-diffusion. I spent a while stripping out features (fp16, checkpointing) to get down to a bare-minimum training loop.
  • All training is done with MosaicML's Composer, which supports fp16, checkpointing, automatic wandb logging, etc.
  • tests/ has some tests that verify my UNet and diffusion code are identical to OpenAI's UNet and diffusion code
    • I haven't written tests for sampling yet
  • run_unet.py prints out the architecture of the UNet in a friendly form and verifies that the UNet works
  • make_torchdata.py is a script that generates a dataset from a folder of image files. It outputs parquet files.
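To give a feel for the dataset step, here is a minimal sketch of the first half of such a script: walking a folder and collecting image files as records. It is not the contents of make_torchdata.py. The function name, the record fields, and the set of file extensions are all assumptions, and the Parquet writing step is left out because it needs a Parquet library.

```python
from pathlib import Path
from typing import Dict, List

# Assumed set of image extensions; adjust to the dataset at hand.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def collect_image_records(folder: str) -> List[Dict[str, object]]:
    """Return one record per image file under `folder`, in sorted path order.

    Each record holds the file name and the raw, still-encoded image bytes.
    """
    records = []
    for path in sorted(Path(folder).rglob("*")):
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
            records.append({"name": path.name, "data": path.read_bytes()})
    return records
```

Each record could then be written out as one row of a Parquet file, with the image bytes stored in a binary column.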

Why?

This repository is for my personal use, but there are a few nice things you might get from it:

  • The diffusion code is written simply, without support for a ton of different options in a single class, so it's easy to understand
  • Same goes for the UNet
  • (TODO) I will port some comments from my ml-experiments repository to the UNet, so you can see why certain things are done
  • Using MosaicML's trainer means you don't get bogged down by FP16, checkpointing, or logging and can focus on the important stuff

Credit

  • The "losses.py" and "nn.py" files inside the "diffusion" folder are copy-pasted from the OpenAI codebase. I haven't had time to re-implement them yet.
  • The for_timesteps function is heavily based on a function in lucidrains' imagen-pytorch repository.

Contributors

vedantroy
