elijahahianyo / residual-flows

This project forked from rtqichen/residual-flows

Code for "Residual Flows for Invertible Generative Modeling".

Home Page: https://arxiv.org/abs/1906.02735

License: MIT License

Residual Flows for Invertible Generative Modeling [arxiv]

Building on the use of Invertible Residual Networks in generative modeling, we propose:

  • Unbiased estimation of the log-density of samples.
  • Memory-efficient reformulation of the gradients.
  • LipSwish activation function.

As a result, Residual Flows scale to much larger networks and datasets.
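The first and third points above can be illustrated together on a toy problem. The sketch below is not the repository's implementation: it assumes NumPy, a single hypothetical residual branch g(x) = W2 · LipSwish(W1 x) whose weights are rescaled to spectral norm 0.6 so that Lip(g) < 1, and a fixed β = 1 (β could also be treated as a learnable parameter). LipSwish is Swish divided by 1.1, which keeps its slope at most 1. The log-determinant log det(I + J_g) is expanded as the power series Σ_k (-1)^(k+1) tr(J_g^k) / k; each trace is estimated with a random ±1 probe vector (Hutchinson's estimator), and the series is cut at a random length with every kept term divided by the probability of reaching it (a "Russian roulette" estimator), which keeps the estimate unbiased. The names `logdet_roulette`, `n_exact` and `p_stop` are illustrative.

```python
import numpy as np


def lipswish(z, beta=1.0):
    # Swish(z) = z * sigmoid(beta * z); its slope peaks near 1.0998,
    # so dividing by 1.1 keeps LipSwish 1-Lipschitz.
    return z / (1.0 + np.exp(-beta * z)) / 1.1


def lipswish_grad(z, beta=1.0):
    # Derivative of lipswish with respect to z.
    s = 1.0 / (1.0 + np.exp(-beta * z))
    return (s + beta * z * s * (1.0 - s)) / 1.1


def spectral_normalize(w, coeff):
    # Rescale so the largest singular value of w equals coeff (< 1).
    return w * (coeff / np.linalg.norm(w, ord=2))


def jacobian_g(x, w1, w2):
    # Toy residual branch g(x) = w2 @ lipswish(w1 @ x), so
    # J_g(x) = w2 @ diag(lipswish'(w1 @ x)) @ w1.
    return w2 @ np.diag(lipswish_grad(w1 @ x)) @ w1


def logdet_roulette(jac, rng, n_exact=2, p_stop=0.5):
    """One unbiased sample of log det(I + jac), assuming ||jac||_2 < 1."""
    v = rng.choice([-1.0, 1.0], size=jac.shape[0])  # Hutchinson probe
    # The first n_exact terms are always kept; beyond that, the probability
    # of reaching term k is (1 - p_stop) ** (k - n_exact).
    n_terms = n_exact + rng.geometric(p_stop) - 1
    w = v.copy()
    total = 0.0
    for k in range(1, n_terms + 1):
        w = jac @ w  # w = jac^k v
        prob_k = 1.0 if k <= n_exact else (1.0 - p_stop) ** (k - n_exact)
        total += (-1) ** (k + 1) / k * (v @ w) / prob_k
    return total
```

Averaging many calls to `logdet_roulette(J, rng)` should approach `np.linalg.slogdet(np.eye(d) + J)[1]`, whereas cutting the series at a fixed length would give a biased estimate.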

Requirements

  • PyTorch 1.0+
  • Python 3.6+

Preprocessing

ImageNet:

  1. Follow the instructions in preprocessing/create_imagenet_benchmark_datasets.
  2. Convert the .npy files to .pth using preprocessing/convert_to_pth.
  3. Place the results in data/imagenet32 and data/imagenet64.

CelebAHQ 64x64 5bit:

  1. Download from https://github.com/aravindsrinivas/flowpp/tree/master/flows_celeba.
  2. Convert the .npy files to .pth using preprocessing/convert_to_pth.
  3. Place the results in data/celebahq64_5bit.

CelebAHQ 256x256:

# Download Glow's preprocessed dataset.
wget https://storage.googleapis.com/glow-demo/data/celeba-tfr.tar
mkdir -p data/celebahq
tar -C data/celebahq -xvf celeba-tfr.tar
python extract_celeba_from_tfrecords

Density Estimation Experiments

NOTE: By default, O(1)-memory gradients are enabled. In this mode, the bits/dim logged during training is not an actual estimate of bits/dim but whatever scalar was used to generate the unbiased gradients. If you want to monitor the actual bits/dim during training (and have sufficient GPU memory), set --neumann-grad=False. Note, however, that with this flag set to False the memory cost can vary stochastically during training, because the number of power-series terms that must be stored for backpropagation is sampled at random.
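The O(1)-memory gradients rely on the fact that the gradient of the log-determinant is itself a power series: the derivative of log det(I + J) with respect to J is (I + J)^(-T), and (I + J)^(-1) = Σ_{k≥0} (-1)^k J^k when ||J||_2 < 1. Below is a minimal NumPy sketch of an unbiased, roulette-truncated estimate of that gradient, taken directly with respect to a small explicit matrix J rather than the network parameters. The function name and the `n_exact`/`p_stop` parameters are hypothetical, and this is not the repository's code, which works through vector-Jacobian products instead of explicit matrices.

```python
import numpy as np


def neumann_logdet_grad(jac, rng, n_exact=2, p_stop=0.5):
    """One unbiased sample of d log det(I + jac) / d jac = (I + jac)^(-T).

    Uses (I + jac)^(-1) = sum_{k>=0} (-1)^k jac^k (valid for ||jac||_2 < 1),
    a random +/-1 probe v, and a Russian-roulette truncation of the series.
    Only the current vector (jac^T)^k v is kept between terms.
    """
    v = rng.choice([-1.0, 1.0], size=jac.shape[0])
    n_terms = n_exact + rng.geometric(p_stop) - 1
    u = v.copy()            # u = (jac^T)^k v, starting at k = 0
    acc = np.zeros_like(v)  # running sum of reweighted series terms
    for k in range(n_terms + 1):
        prob_k = 1.0 if k <= n_exact else (1.0 - p_stop) ** (k - n_exact)
        acc += (-1) ** k * u / prob_k
        u = jac.T @ u
    # E[v v^T] = I, so E[outer(acc, v)] = sum_k (-1)^k (jac^T)^k = (I + jac)^(-T).
    return np.outer(acc, v)
```

Only one vector is carried from term to term, so nothing has to be stored per term for a backward pass. The scalar `acc @ (jac @ v)`, with `acc` held fixed, has exactly this gradient but is not log det(I + J), which is consistent with the warning above that the logged value is not an actual bits/dim estimate.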

MNIST:

python train_img.py --data mnist --imagesize 28 --actnorm True --wd 0 --save experiments/mnist

CIFAR10:

python train_img.py --data cifar10 --actnorm True --save experiments/cifar10

ImageNet 32x32:

python train_img.py --data imagenet32 --actnorm True --nblocks 32-32-32 --save experiments/imagenet32

ImageNet 64x64:

python train_img.py --data imagenet64 --imagesize 64 --actnorm True --nblocks 32-32-32 --factor-out True --squeeze-first True --save experiments/imagenet64
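This command enables `--squeeze-first` and `--factor-out`. Judging by their names and by common practice in multi-scale flows, these presumably refer to the space-to-depth "squeeze" (each 2x2 spatial patch is folded into channels) and to setting aside half of the channels at each scale; that reading is an assumption, and the channel ordering below may differ from the repository's. A minimal NumPy sketch on a single (C, H, W) image, with hypothetical helper names:

```python
import numpy as np


def squeeze(x):
    # (C, H, W) -> (4C, H/2, W/2): each 2x2 spatial patch is moved into channels.
    c, h, w = x.shape
    x = x.reshape(c, h // 2, 2, w // 2, 2)  # axes: c, i, a, j, b
    return x.transpose(0, 2, 4, 1, 3).reshape(4 * c, h // 2, w // 2)


def unsqueeze(x):
    # Exact inverse of squeeze: (4C, H, W) -> (C, 2H, 2W).
    c4, h, w = x.shape
    x = x.reshape(c4 // 4, 2, 2, h, w)  # axes: c, a, b, i, j
    return x.transpose(0, 3, 1, 4, 2).reshape(c4 // 4, 2 * h, 2 * w)


def factor_out(z):
    # Hypothetical split: half the channels continue to the next scale,
    # the other half are set aside as latent variables at this scale.
    half = z.shape[0] // 2
    return z[:half], z[half:]
```

Because `squeeze` only permutes entries, it is exactly invertible and contributes nothing to the log-determinant.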

CelebAHQ 256x256:

python train_img.py --data celebahq --imagesize 256 --nbits 5 --actnorm True --act elu --batchsize 8 --update-freq 5 --n-exact-terms 8 --fc-end False --factor-out True --squeeze-first True --nblocks 16-16-16-16-16-16 --save experiments/celebahq256
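This command trains on 5-bit data (`--nbits 5`). A common convention for converting a continuous log-density into bits/dim, assuming uniformly dequantized data rescaled to [0, 1], is sketched below. `bits_per_dim` is a hypothetical helper, and the repository's own accounting (for example, any logit or affine preprocessing and its log-determinant) may differ.

```python
import math


def bits_per_dim(logpx, num_dims, nbits=8):
    """Convert log p(x) in nats, for dequantized data rescaled to [0, 1],
    into bits/dim of the underlying nbits-bit discrete data.

    Each discrete value occupies a bin of width 2**-nbits per dimension, so
    log P(x_discrete) is estimated by logpx - num_dims * nbits * log(2).
    """
    log_prob_discrete = logpx - num_dims * nbits * math.log(2)
    return -log_prob_discrete / (num_dims * math.log(2))
```

For example, a uniform density on [0, 1]^D (logpx = 0) gives exactly nbits bits/dim, the cost of storing the raw pixel values.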

Pretrained Models

Model checkpoints can be downloaded from releases.

Use the argument --resume [checkpt.pth] to evaluate or sample from the model.

Each checkpoint contains two sets of parameters, one from training and one containing the exponential moving average (EMA) accumulated over the course of training. Scripts will automatically use the EMA parameters for evaluation and sampling.
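An exponential moving average of parameters is typically maintained by updating ema = decay * ema + (1 - decay) * param after each optimizer step. The plain-Python sketch below assumes parameters stored in a dict of floats, a hypothetical decay value, and a hypothetical `grad_fn`; the repository's decay value and bookkeeping are not specified here.

```python
def ema_update(ema_params, params, decay=0.999):
    # After each optimizer step: ema <- decay * ema + (1 - decay) * param.
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


def train_with_ema(grad_fn, params, lr=0.1, steps=100, decay=0.9):
    """Toy loop: plain gradient descent on `params`, with a separate EMA copy.
    grad_fn (hypothetical) maps the parameter dict to a dict of gradients."""
    ema = dict(params)  # EMA starts from the initial parameters
    for _ in range(steps):
        grads = grad_fn(params)
        params = {k: v - lr * grads[k] for k, v in params.items()}
        ema_update(ema, params, decay)
    return params, ema  # training copy and EMA copy; evaluate/sample with ema
```

Keeping both copies mirrors the two parameter sets stored in each checkpoint: training continues from the raw parameters, while the smoother EMA copy is used for evaluation and sampling.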

BibTeX

@inproceedings{chen2019residualflows,
  title={Residual Flows for Invertible Generative Modeling},
  author={Chen, Ricky T. Q. and Behrmann, Jens and Duvenaud, David and Jacobsen, J{\"{o}}rn{-}Henrik},
  booktitle = {Advances in Neural Information Processing Systems},
  year={2019}
}

Contributors

rtqichen, luchengthu