jiamings / wgan

TensorFlow implementation of Wasserstein GAN (and improved version in wgan_v2)

wgan's Introduction

Wasserstein GAN

TensorFlow implementation of Wasserstein GAN.

Two versions:

  • wgan.py: the original weight-clipping method.
  • wgan_v2.py: the gradient penalty method (Improved Training of Wasserstein GANs); a sketch of the penalty term follows the run example below.

How to run (an example):

python wgan_v2.py --data mnist --model mlp --gpus 0
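For reference, here is a minimal sketch of the gradient penalty term from the improved-training paper. It is illustrative only: gradient_penalty and critic are hypothetical names, not identifiers from this repository, and samples are assumed to be flattened to shape [batch, dim] as in the mlp model.

    import tensorflow as tf

    def gradient_penalty(critic, real, fake, lam=10.0):
        # Interpolate randomly between real and generated samples.
        eps = tf.random_uniform([tf.shape(real)[0], 1], 0.0, 1.0)
        x_hat = eps * real + (1.0 - eps) * fake
        # Gradient of the critic output with respect to the interpolates.
        ddx = tf.gradients(critic(x_hat), x_hat)[0]
        ddx_norm = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1))
        # Penalize deviations of the gradient norm from 1 (lambda = 10 in the paper).
        return lam * tf.reduce_mean(tf.square(ddx_norm - 1.0))

The returned penalty is added to the critic's Wasserstein loss before the optimizer step.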

wgan's People

Contributors: jiamings

wgan's Issues

Weight clipping should occur AFTER critic update

Thanks again for providing this code. One thing I wanted to point out: in the pseudo-code provided in the paper, the weights are clipped after the gradient update for the critic. You would just need to move sess.run(d_clip) after the RMSProp update.
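A minimal sketch of the suggested ordering. d_clip is the clipping op mentioned above; d_rmsprop, n_critic, and the feed tensors are hypothetical names standing in for the repo's actual variables.

    # Hypothetical training-loop excerpt: clip AFTER the critic's gradient step,
    # matching Algorithm 1 of the WGAN paper.
    for _ in range(n_critic):
        # bx, bz: a real-data batch and a noise batch (placeholders here)
        sess.run(d_rmsprop, feed_dict={x: bx, z: bz})  # critic gradient step first
        sess.run(d_clip)                               # then clip the critic weights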

Batch Normalization

Thanks for sharing the code. It is elegant and well structured, and I'm going to use it as a starting point.
But there might be a glitch in the batch normalization.
Are you using batch normalization in the same way during training and testing?
Or am I missing something in the paper that suggests using batch normalization in this way?

If you're using BN in the usual sense, it seems that in your training the moving_mean and moving_variance are not updated.
Check this link: tensorflow/tensorflow#1122
And perhaps also this: http://r2rt.com/implementing-batch-normalization-in-tensorflow.html
A way to update the variables might be the following:

    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        update_ops = tf.no_op()  # running this op triggers the collected BN updates

Please let me know if I missed anything.
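For context, the more common TF1 pattern is to wrap the train op itself, so the BN statistics update on every step. A sketch, where optimizer, d_loss, and d_train_op are illustrative names rather than the repo's actual ops:

    # optimizer could be e.g. tf.train.RMSPropOptimizer(5e-5); d_loss is the critic loss.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        d_train_op = optimizer.minimize(d_loss)  # moving averages update with each step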

Wrong objective

self.g_loss = tf.reduce_mean(self.d_)

I think that, according to the paper, a minus sign is missing in this line: the paper optimizes the generator by minimizing -D(G(z)).

Why is the loss in wgan.py different from the original paper?

I am confused about which one is correct.
As implemented in wgan.py, we have
self.g_loss = tf.reduce_mean(self.d_)
self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_)
However, according to the original WGAN paper, it seems that we should minimize (-1)*self.g_loss instead of self.g_loss. Could you tell me why the losses are implemented in the above form? In any case, I can still get reasonable results with either wgan.py or wgan_v2.py, which confuses me even more.

How about the following losses instead?
self.g_loss = tf.reduce_mean(tf.scalar_mul(-1, self.d_))
self.d_loss = tf.reduce_mean(self.d_) - tf.reduce_mean(self.d)

Thank you!
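For anyone hitting the same question: the two conventions differ only by a global sign flip of the critic, so gradient descent yields the same dynamics either way. A minimal sketch (my own illustration, not the author's code; d and d_ stand for self.d and self.d_):

    # Convention in wgan.py: minimizing d_loss drives D(real) down and D(fake) up,
    # so the network effectively learns f = -D, where f is the paper's critic.
    g_loss_repo  = tf.reduce_mean(d_)                      # = -E[f(G(z))]
    d_loss_repo  = tf.reduce_mean(d) - tf.reduce_mean(d_)  # = -(E[f(x)] - E[f(G(z))])

    # Convention as written in the paper (signs flipped back):
    g_loss_paper = -tf.reduce_mean(d_)
    d_loss_paper = tf.reduce_mean(d_) - tf.reduce_mean(d)

    # Either pair trains correctly; the critic's sign convention is arbitrary,
    # which is why both versions produce reasonable samples.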

License question

Hey there, thanks a lot for this TensorFlow implementation of WGAN. I was just wondering whether it's actually free to use. Without an explicit license in the repo, the code defaults to all rights reserved, which makes it effectively unusable by other people (including potential contributors): https://choosealicense.com/no-permission/

It's fine if that is actually your intent; just making sure!

Thank you for posting this. I have some TensorFlow language candy to share.

Thank you for sharing your code. It actually helped me a lot!

Instead of ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1)),
you can use ddx = tf.norm(ddx, axis=1); I tried it and it gives the same result.
For the discriminator in the mlp model, I use tf.layers, which saves a lot of lines of code, as below.

def discriminator(x):
    with tf.variable_scope('discriminator'):
        # Reshape flat MNIST vectors into NHWC images.
        nn_x  = tf.reshape(x, [tf.shape(x)[0], 28, 28, 1])
        conv1 = tf.layers.conv2d(nn_x, filters=64, kernel_size=4, strides=2, activation=leaky_relu)
        conv2 = tf.layers.conv2d(conv1, filters=128, kernel_size=4, strides=2, activation=leaky_relu)
        bn    = tf.layers.batch_normalization(conv2, training=True)
        flt   = tf.contrib.layers.flatten(bn)
        dense = tf.layers.dense(flt, 1024, activation=leaky_relu)  # leaky_relu defined elsewhere, e.g. tf.nn.leaky_relu
        logits = tf.layers.dense(dense, 1)  # single unbounded critic score, no activation
        return logits

Generator loss Interpretation

In standard GANs, the generator is optimized so that D mistakes a generated sample for a real one, and the discriminator's output (sigmoid activation) represents the probability that a sample comes from the real (1) rather than the generated (0) distribution.

However, in WGANs the losses are defined as Dloss = D(real) - D(fake) and Gloss = D(fake). We minimize both losses in an alternating fashion, with more iterations for the critic.

What does minimizing Gloss = D(fake) mean? Since the critic's output does not correspond to a probability or any other direct physical quantity, what does minimizing Gloss actually do?
