DiracDiffusion: Denoising and Incremental Reconstruction with Assured Data-Consistency (ICML 2024)

This is the official repository for the paper DiracDiffusion: Denoising and Incremental Reconstruction with Assured Data-Consistency.

DiracDiffusion: Denoising and Incremental Reconstruction with Assured Data-Consistency,
Zalan Fabian, Berk Tınaz, Mahdi Soltanolkotabi
ICML 2024

We introduce a novel framework for solving inverse problems using a generalized notion of diffusion that mimics the corruption process of a particular image degradation. Our method maintains data consistency throughout the reverse diffusion process allowing us to early-stop during reconstruction and flexibly trade off perceptual quality for lower distortion.

Requirements

A CUDA-enabled GPU is required to run the code. We tested the code with the following setup:

  • A6000 GPUs
  • Ubuntu 20.04
  • CUDA 12.2
  • Python 3.10
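
As a quick sanity check before installing, the tested setup above can be compared against the local machine. The sketch below is not part of the repository: the function name `check_environment` is hypothetical, it treats the tested Python version as a lower bound (an assumption, since other versions may also work), and it only looks for the `nvidia-smi` tool on the PATH rather than querying the GPU itself.

```python
import shutil
import sys


def check_environment(min_python=(3, 10)):
    """Return a list of human-readable warnings; an empty list means nothing obvious is wrong.

    Assumption: the tested Python version (3.10) is used as a lower bound.
    """
    warnings = []
    if sys.version_info[:2] < min_python:
        found = ".".join(str(v) for v in sys.version_info[:2])
        wanted = ".".join(str(v) for v in min_python)
        warnings.append(f"Python {found} is older than the tested version {wanted}")
    # nvidia-smi ships with the NVIDIA driver; its absence usually means no usable GPU driver.
    if shutil.which("nvidia-smi") is None:
        warnings.append("nvidia-smi not found on PATH; a CUDA-enabled GPU may be unavailable")
    return warnings


if __name__ == "__main__":
    for message in check_environment():
        print("warning:", message)
```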

Setup

  1. In a clean virtual environment, run
       git clone https://github.com/z-fabian/dirac-diffusion
       cd dirac-diffusion
       pip install -r requirements.txt
  2. Configure the root folder for each dataset in the dataset config file.
  3. (Optional) Download pre-trained checkpoints into the checkpoints directory within the repository.

Data

We evaluate Dirac on CelebA-HQ (256x256), FFHQ (256x256) and ImageNet. Our code expects the datasets in the following directory structure:

  • CelebA-HQ: the root folder contains all files directly, numbered 00000.jpg - 30000.jpg.
  • FFHQ: the root folder contains subfolders 00000 - 00069 each with 1000 images.
  • ImageNet: the root folder contains train and val folders.
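
The layout described above can be checked before launching a long job. The following is a minimal sketch, not code from the repository: `check_layout` is a hypothetical helper, the dataset names are borrowed from the reconstruction script's options, and the checks cover only what is stated above (no assumption is made about the FFHQ image file extension or the exact number of CelebA-HQ files).

```python
from pathlib import Path


def check_layout(dataset, root):
    """Return a list of problems found under `root` for the named dataset."""
    root = Path(root)
    if not root.is_dir():
        return [f"{root} is not a directory"]
    problems = []
    if dataset == "celeba256":
        # CelebA-HQ: all .jpg files sit directly in the root folder.
        if not any(root.glob("*.jpg")):
            problems.append("no .jpg files directly under the root folder")
    elif dataset == "ffhq":
        # FFHQ: subfolders 00000 through 00069.
        for index in range(70):
            name = f"{index:05d}"
            if not (root / name).is_dir():
                problems.append(f"missing subfolder {name}")
    elif dataset == "imagenet":
        # ImageNet: train and val folders.
        for name in ("train", "val"):
            if not (root / name).is_dir():
                problems.append(f"missing folder {name}")
    else:
        raise ValueError(f"unknown dataset: {dataset}")
    return problems
```

An empty list means the folder matches the expected structure at this coarse level; it does not verify that every image is present or readable.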

Pre-trained models

| Train dataset | Operator | Training loss | Checkpoint size | Link |
|---|---|---|---|---|
| CelebA-HQ | Gaussian blur | $\mathcal{L}_{IR}(\Delta t=0.0; \theta)$ | 776M | Download |
| CelebA-HQ | Gaussian blur | $\mathcal{L}_{IR}(\Delta t=1.0; \theta)$ | 776M | Download |
| CelebA-HQ | Inpainting | $\mathcal{L}_{IR}(\Delta t=0.0; \theta)$ | 776M | Download |
| ImageNet | Gaussian blur | $\mathcal{L}_{IR}(\Delta t=1.0; \theta)$ | 776M | Download |
| ImageNet | Inpainting | $\mathcal{L}_{IR}(\Delta t=0.0; \theta)$ | 5.8G | Download |

Model training

To train an incremental reconstruction model from scratch, run

python scripts/train_dirac.py fit --config PATH_TO_CONFIG

and replace PATH_TO_CONFIG with the trainer config file. See trainer configs here for deblurring and inpainting experiments. In case your GPU doesn't support mixed-precision training, change precision to fp32 in the config file. You can train models with different architectural hyperparameters by adding a new key to the model config, and specifying the same key in the trainer config under model_arch.

Reconstruction

To reconstruct images, run

python scripts/recon_dirac.py --config_path PATH_TO_CONFIG --dataset DATASET_NAME

and replace PATH_TO_CONFIG with the reconstruction config file and DATASET_NAME with the name of the dataset to be reconstructed (either 'celeba256', 'ffhq' or 'imagenet'). You can find config files for both perception-optimized and distortion-optimized reconstructions here. Take a look at the annotated reference config to see all the options.
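
To compare perception-optimized and distortion-optimized reconstructions on the same dataset, the command above can be run once per config file. The sketch below only assembles and optionally launches that command; the helper names `recon_command` and `run_all` are hypothetical, as are any config paths passed in. It assumes it is started from the repository root and uses only the `--config_path` and `--dataset` options shown above.

```python
import subprocess
import sys

# Dataset names accepted by the reconstruction script, as listed above.
DATASETS = ("celeba256", "ffhq", "imagenet")


def recon_command(config_path, dataset):
    """Build the argument list for one reconstruction run."""
    if dataset not in DATASETS:
        raise ValueError(f"dataset must be one of {DATASETS}, got {dataset!r}")
    return [
        sys.executable,
        "scripts/recon_dirac.py",
        "--config_path", str(config_path),
        "--dataset", dataset,
    ]


def run_all(config_paths, dataset, dry_run=True):
    """Run (or, with dry_run=True, just print) one reconstruction per config file."""
    commands = [recon_command(path, dataset) for path in config_paths]
    for command in commands:
        if dry_run:
            print(" ".join(command))
        else:
            # check=True stops the loop at the first failing run.
            subprocess.run(command, check=True)
    return commands
```

Calling `run_all([...], "ffhq")` with the default `dry_run=True` prints the commands without executing them, which makes it easy to confirm the paths before committing GPU time.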

Citation

If you find our paper useful, please cite

@inproceedings{fabian2023diracdiffusion,
  title={Diracdiffusion: Denoising and incremental reconstruction with assured data-consistency},
  author={Fabian, Zalan and Tinaz, Berk and Soltanolkotabi, Mahdi},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024}
}

Acknowledgments

This repository builds upon code from
