gan's Introduction

Code accompanying our Dist-GAN paper and our GN-GAN paper (code to be updated soon).

Setup

Dependencies

Python 2.7 or 3.x, NumPy, scikit-learn, TensorFlow, Keras (for the 1D demo)
For Python 3.x, use import pickle in the file distgan_image/modules/dataset.py
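A version-compatible form of the change described above (a sketch, assuming the file originally imports cPickle, which exists only on Python 2):

```python
try:
    import cPickle as pickle  # Python 2
except ImportError:
    import pickle             # Python 3

# Works identically under both interpreters.
data = pickle.loads(pickle.dumps({"z_dim": 100}))
```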

Getting Started

We conduct experiments with our model on 1D/2D synthetic data and the MNIST, CelebA, CIFAR-10 and STL-10 datasets.

1D demo

In addition to Dist-GAN, our 1D code provides other methods, such as GAN, MDGAN, VAEGAN and WGAN-GP.

>> cd distgan_toy1d
>> python gan_toy1d.py

Quick video demos; you can easily reproduce these videos with our code:

GAN
WGAN-GP (WGAN-GP can match the data distribution for a while, but diverges later)
VAEGAN
Dist-GAN

The visualization part of our 1D code is re-used from here:

2D synthetic data

>> cd distgan_toy2d
>> python distgan_toy2d.py

We provide three different data layouts you can test on: 'SINE', 'PLUS' and 'SQUARE'. Just change the parameter testcase in distgan_toy2d.py. For example:

testcase      = 'SQUARE'

Image data (MNIST, CelebA and CIFAR-10)

We provide our code for the image datasets MNIST, CelebA and CIFAR-10.

MNIST
>> cd distgan_image
>> python distgan_mnist.py


From left to right: Samples generated by our Dist-GAN model (DCGAN for MNIST) and real samples.

CIFAR-10

Download CIFAR-10 from http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz and extract it into the correct folder, e.g. ./data/cifar10/

>> cd distgan_image
>> python distgan_cifar.py


From left to right: Samples generated by our Dist-GAN model (DCGAN) and real samples.

CelebA

Download CelebA from https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg and extract it into the correct folder, e.g. ./data/celeba/

>> cd distgan_image
>> python distgan_celeba.py


From left to right: Samples generated by our Dist-GAN model (DCGAN) and real samples.

STL-10

Download STL-10 from http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz and extract it into a folder. Then use ./modules/create_stl10.py to read the binary file unlabeled_X.bin and save the images into a specific folder, e.g. ./data/stl-10/

>> cd distgan_image
>> python distgan_stl10.py


From left to right: Samples generated by our Dist-GAN model (with standard CNN architecture [2] + hinge loss) and real samples.

Results

FID scores of Dist-GAN on the CIFAR-10 and STL-10 datasets are summarized below, following the experimental setup of [2] with the standard CNN architecture (FID is computed with 10K real images and 5K generated images). Dist-GAN is trained for 300K iterations.

| Method                   | CIFAR-10     | STL-10       |
|--------------------------|--------------|--------------|
| WGAN-GP [1]              | 40.2         | 55.1         |
| SN-GAN [2]               | 29.3         | 53.1         |
| SN-GAN (hinge loss) [2]  | 25.5         | 43.2         |
| SN-GAN (ResNet) [2]      | 21.70 ± .21  | -            |
| Dist-GAN                 | 28.23        | -            |
| Dist-GAN (hinge loss)    | 22.95        | 36.19 (100K) |
| Dist-GAN (ResNet)        | 17.61 ± .30  | 28.50 ± .49  |
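FID measures the Fréchet distance between Gaussians fitted to Inception features of real and generated images. As a reference for the formula only (feature extraction omitted), a NumPy sketch of the distance given precomputed means and covariances:

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    # d^2 = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2*Tr((sigma1 sigma2)^{1/2})
    diff = mu1 - mu2
    # Tr((S1 S2)^{1/2}) equals the sum of square roots of the eigenvalues
    # of S1 @ S2, which are real and non-negative for valid covariances.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt
```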

Citation

If you find this work useful in your research, please consider citing:

@InProceedings{Tran_2018_ECCV,
  author = {Tran, Ngoc-Trung and Bui, Tuan-Anh and Cheung, Ngai-Man},
  title = {Dist-GAN: An Improved GAN using Distance Constraints},
  booktitle = {The European Conference on Computer Vision (ECCV)},
  month = {September},
  year = {2018}
}

or

@article{tran-2018-abs-1803-08887,
  author    = {Ngoc{-}Trung Tran and
               Tuan{-}Anh Bui and
               Ngai{-}Man Cheung},
  title     = {Dist-GAN: An Improved GAN using Distance Constraints},
  journal   = {CoRR},
  volume    = {abs/1803.08887},
  year      = {2018}
}

Updates

  • 2018/06/18: Dist-GAN supports standard CNN architecture like SN-GAN [2]. New FID results of standard CNN (+ hinge loss) are added.
  • 2018/12/15: Dist-GAN supports ResNet architecture like WGAN-GP [1], SN-GAN [2] (modified from ResNet code of WGAN-GP [1]).
  • 2019/04/05: Dist-GAN supports another ResNet architecture (modified from ResNet code of SAGAN [3]).

References

[1] Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, Aaron Courville, "Improved Training of Wasserstein GANs", NIPS 2017.
[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida, "Spectral Normalization for Generative Adversarial Networks", ICLR 2018.
[3] Han Zhang, Ian Goodfellow, Dimitris Metaxas, Augustus Odena, "Self-Attention Generative Adversarial Networks", arXiv preprint arXiv:1805.08318, 2018.

gan's People

Contributors

tntrung


gan's Issues

about lambda_w

Hello,
Correct me if I'm wrong: in the paper, you have f(x, G(z)) and lambda_w = sqrt(dim_z/dim_x),
but glancing at the code, you use features of x and G(z). Should lambda_w instead be
lambda_w = sqrt(dim_z/dim_ft_of_x)?
Thank you!

how to generate new data after fit(data)?

In the 1D demo, after model = GAN(...).fit(data),
how can I use this model to generate new data?
I have tried feeding z to model._create_generator(z), but I got
AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'

f(x, G(z)) computation

Hi, Trung,

In gaan.py,
self.md_x = tf.reduce_mean(self.f_recon - self.f_fake)
According to Eq. (7) in your paper, maybe
self.md_x = tf.reduce_mean(self.f_real - self.f_fake)

Is it correct?
Thanks,
Sungwoong.

gradient penalty computation

Hi, Ngoc-Trung,

Thanks for your code sharing.
I have a question regarding a computation of gradient penalty in your code.

In gaan.py,
epsilon = tf.random_uniform(shape=tf.shape(self.X), minval=0., maxval=1.)

I think for a convex combination per sample (the same epsilon should be applied to all dims of each sample), it should be:
epsilon = tf.random_uniform(shape=[tf.shape(self.X)[0], 1], minval=0., maxval=1.)

Is it correct?

Thanks,
Sungwoong.
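For reference, the convention raised above (one epsilon per sample, broadcast over all dimensions) is the one used in the WGAN-GP paper [1]: each interpolate lies on the straight line between one real and one fake sample. A NumPy sketch of the difference, illustrative only and not the repo's TensorFlow code:

```python
import numpy as np

rng = np.random.default_rng(0)
x_real = rng.normal(size=(4, 3))   # a batch of "real" samples
x_fake = rng.normal(size=(4, 3))   # a batch of "fake" samples

# Per-sample epsilon (shape (batch, 1)): one mixing weight per sample,
# broadcast over every dimension, so x_hat lies on the real-fake segment.
eps = rng.uniform(size=(4, 1))
x_hat = eps * x_real + (1.0 - eps) * x_fake

# Per-element epsilon (same shape as x): each coordinate mixed independently;
# the result lies inside the axis-aligned box, not on the connecting line.
eps_elem = rng.uniform(size=x_real.shape)
x_box = eps_elem * x_real + (1.0 - eps_elem) * x_fake
```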
