
draw's Introduction

draw

TensorFlow implementation of DRAW: A Recurrent Neural Network For Image Generation on the MNIST generation task.

*(Sample generation animations: with attention / without attention)*

Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as possible.

For a gentle walkthrough of the paper and implementation, see the writeup here: http://blog.evjang.com/2016/06/understanding-and-implementing.html.

Usage

python draw.py --data_dir=/tmp/draw downloads the binarized MNIST dataset to /tmp/draw/mnist and trains the DRAW model with attention enabled for both reading and writing. After training, output data is written to /tmp/draw/draw_data.npy.

You can visualize the results by running the script python plot_data.py <prefix> <output_data>

For example,

python plot_data.py myattn /tmp/draw/draw_data.npy

To run training without attention, do:

python draw.py --data_dir=/tmp/draw --read_attn=False --write_attn=False

Restoring from Pre-trained Model

Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in draw.py and editing the path to your checkpoint file as needed. Save electricity!

saver.restore(sess, "/tmp/draw/drawmodel.ckpt")

This git repository contains the following pre-trained models in the data/ folder:

| Filename | Description |
| --- | --- |
| draw_data_attn.npy | Training outputs for DRAW with attention |
| drawmodel_attn.ckpt | Saved weights for DRAW with attention |
| draw_data_noattn.npy | Training outputs for DRAW without attention |
| drawmodel_noattn.ckpt | Saved weights for DRAW without attention |

These were trained for 10,000 iterations with a minibatch size of 100 on a GTX 970 GPU.

Useful Resources

draw's People

Contributors

ericjang, guohengkai, iamgroot42


draw's Issues

kl divergence computation

Thanks for sharing this elegant DRAW model!

However, I found that in the KL divergence computation in draw.py (line 191), the last term should be 0.5 instead of 0.5*T, according to the paper's equation 11.

Even though this constant term won't affect the optimization process, you may get a different (but still reasonable) loss curve. In the current form it is possible to get a negative KL divergence.

Thanks!
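A quick NumPy check illustrates the point (a sketch with hypothetical values, not the repo's tensors): the true Gaussian-to-standard-normal KL is always nonnegative, while subtracting 0.5*T per timestep can push the computed value below zero when z_size is smaller than T.

```python
import numpy as np

np.random.seed(0)
T, batch, z_size = 10, 4, 2  # illustrative sizes with z_size < T
mu = 0.1 * np.random.randn(T, batch, z_size)
logsigma = 0.05 * np.random.randn(T, batch, z_size)
sigma2 = np.exp(2 * logsigma)
mu2 = mu ** 2

# Correct per-timestep KL to the standard normal prior (always >= 0):
kl_ok = 0.5 * np.sum(mu2 + sigma2 - 2 * logsigma - 1, axis=-1)

# Variant from draw.py line 191, subtracting 0.5*T instead of 0.5 per element:
kl_buggy = 0.5 * np.sum(mu2 + sigma2 - 2 * logsigma, axis=-1) - 0.5 * T

print(kl_ok.min() >= 0)      # True: the analytic KL cannot go negative
print((kl_buggy < 0).any())  # True: the extra constant makes "KL" negative
```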

Results could be better?

I have been playing with this code for a few days. I can reproduce the GIF animation shown at the top of this repository. However, another implementation based on Theano (https://github.com/jbornschein/draw) achieves subjectively much nicer results (look at their GIF animation). I tried using their parameters (such as T=64 and read_window=2) in the TensorFlow code, but I was unable to reproduce results that look that nice. Do you have any idea why there is such a difference, and how we could achieve comparable results with this TensorFlow code?

By niceness I mean the animation looks more realistic, which probably means what the model learns is closer to the actual causal process that happens in human handwriting.

The noise sampling and the Lz loss curve issue

Line 44:
e=tf.random_normal((batch_size,z_size), mean=0, stddev=1) # Qsampler noise
I think this should be moved into the function sampleQ; otherwise inference will fail, presumably because the same noise tensor is reused at every timestep.

However, when I made this modification, the Lz loss curve increased instead of declining, as shown in the attached plot (image omitted).

Why?
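For reference, one common arrangement (a NumPy sketch of the reparameterization trick with hypothetical weight matrices, not the repo's TensorFlow code) draws fresh noise inside sampleQ on every call:

```python
import numpy as np

def sampleQ(h_enc, W_mu, W_logsigma, rng):
    """Reparameterized sample z = mu + sigma * e, with fresh noise per call.

    W_mu and W_logsigma stand in for the learned linear projections.
    """
    mu = h_enc @ W_mu
    logsigma = h_enc @ W_logsigma
    e = rng.standard_normal(mu.shape)  # new noise drawn at every timestep
    return mu + np.exp(logsigma) * e, mu, logsigma

rng = np.random.default_rng(0)
h_enc = rng.standard_normal((100, 256))        # batch of encoder states
W_mu = rng.standard_normal((256, 10)) * 0.01
W_ls = rng.standard_normal((256, 10)) * 0.01
z, mu, logsigma = sampleQ(h_enc, W_mu, W_ls, rng)
print(z.shape)  # (100, 10)
```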

mu_x typo?

Not an issue, just curious.

grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1])    
mu_x = gx + (grid_i - N / 2 - 0.5) * delta # eq 19

If N = 3 and delta = 1:
grid_i will be [0 1 2]
then grid_i - N / 2 is grid_i - 1 (integer division), i.e. [-1 0 1]
then grid_i - N / 2 - 0.5 is [-1.5 -0.5 0.5]

But I think [-1 0 1] is the reasonable value, so that the mean locations are [gx-1, gx, gx+1].
Why subtract 0.5? Is this just following the paper, or am I missing something?
thanks
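The two indexing conventions can be compared numerically (a NumPy sketch; the paper's eq. 19 indexes i from 1, the code indexes grid_i from 0, and under Python 2 the integer division N/2 evaluates to 1, written N//2 below):

```python
import numpy as np

N, delta, gx = 3, 1.0, 0.0

# Paper's eq. 19 uses a 1-based index i = 1..N:
i_paper = np.arange(1, N + 1, dtype=np.float32)
mu_paper = gx + (i_paper - N / 2 - 0.5) * delta   # [-1., 0., 1.]

# The code applies the same formula to a 0-based grid_i = 0..N-1;
# Python 2 integer division (N//2 == 1) reproduces the values in question:
grid_i = np.arange(N, dtype=np.float32)
mu_code = gx + (grid_i - N // 2 - 0.5) * delta    # [-1.5, -0.5, 0.5]

print(mu_paper, mu_code)  # the 0-based grid is shifted by half a stride
```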

Why is the Loss (L^x + L^z) so small?

It is about 70, lower than most reported results.

I also found one bug. I think you were misled by Eq. (12): that equation computes the KL for each element of the vector z.

kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma,1)-T*.5 # each kl term is (1xminibatch)

should be

kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma-1,1) # each kl term is (1xminibatch)
# kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma,1)-T*z_size*.5 # alternatively

Otherwise the KL term will blow up for large z_size.

Another issue is that the MNIST data in your code is not binarized, but that shouldn't make much difference.
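A small NumPy check (illustrative shapes, not the repo's tensors) confirms that the two proposed forms agree once the per-timestep terms are summed over T:

```python
import numpy as np

rng = np.random.default_rng(1)
T, batch, z_size = 10, 5, 32
mu = 0.1 * rng.standard_normal((T, batch, z_size))
logsigma = 0.1 * rng.standard_normal((T, batch, z_size))
sigma2, mu2 = np.exp(2 * logsigma), mu ** 2

# Proposed fix: the -1 lives inside the per-dimension sum ...
kl_terms = 0.5 * np.sum(mu2 + sigma2 - 2 * logsigma - 1, axis=-1)  # (T, batch)

# ... which, summed over timesteps, equals subtracting T*z_size*0.5 once:
alt = 0.5 * np.sum(mu2 + sigma2 - 2 * logsigma, axis=(0, -1)) - T * z_size * 0.5

print(np.allclose(kl_terms.sum(axis=0), alt))  # True
```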

Filterbank has minor error

Hi Eric,

Thanks for this implementation! It's been hugely useful in replicating the paper. One thing I noticed is that in the filterbank function you're squaring the entire exponent rather than just the numerator, which is what you want.

i.e. your filters are slightly off from equations 24 and 25 in the paper.

Thanks!

Raza
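For concreteness, a one-dimensional filterbank following the paper's eq. 24 squares only the numerator (a - mu) before dividing by 2*sigma^2 (a NumPy sketch with hypothetical shapes, not the repo's exact function):

```python
import numpy as np

def filterbank_1d(mu, sigma2, A):
    """Gaussian filterbank per eq. 24: F[i, a] ~ exp(-(a - mu_i)^2 / (2*sigma^2)).

    mu is an (N,) vector of grid centers, sigma2 a scalar variance,
    A the image width in pixels. Rows are normalized (the Z_X constant).
    """
    a = np.arange(A, dtype=np.float64)  # pixel coordinates
    # Square only the numerator, then divide by 2*sigma^2:
    F = np.exp(-np.square(a[None, :] - mu[:, None]) / (2 * sigma2))
    F /= np.maximum(F.sum(axis=1, keepdims=True), 1e-8)
    return F

F = filterbank_1d(np.array([3.0, 14.0, 25.0]), 4.0, 28)
print(F.shape, np.allclose(F.sum(axis=1), 1.0))  # (3, 28) True
```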

Should x_hat also be used for prediction period?

It may not be a problem, but I am just curious why x_hat (which involves the true data) is also used during the generation phase. After training, I think the model should produce data independently, not by means of the true data.

Details as follows:
Read x as well as x_hat:

x = filter_img(x, Fx, Fy, gamma, read_n) # batch x (read_n*read_n)
x_hat = filter_img(x_hat, Fx, Fy, gamma, read_n)

After the training:

canvases = sess.run(cs, feed_dict) # generate some examples
canvases = np.array(canvases) # T x batch x img_size

It seems that x_hat is also fed into the model, but x_hat contains the true data.

Thanks!
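For comparison, a pure-generation loop per the paper's generative model samples z from the prior and never touches x or x_hat (a NumPy sketch; decoder_step and write_step are hypothetical stand-ins for the trained decoder RNN and write operation):

```python
import numpy as np

def generate(decoder_step, write_step, T, batch, z_size, img_size, rng):
    """Generation per the paper's eqs. 13-15: z_t ~ N(0, I), no read of x."""
    c = np.zeros((batch, img_size))   # canvas accumulated over T steps
    h_dec = np.zeros((batch, 256))    # decoder state (hypothetical size)
    for _ in range(T):
        z = rng.standard_normal((batch, z_size))  # prior sample, no x_hat
        h_dec = decoder_step(z, h_dec)
        c = c + write_step(h_dec)
    return 1.0 / (1.0 + np.exp(-c))   # sigmoid of the final canvas

# Toy stand-ins so the sketch runs end to end:
rng = np.random.default_rng(0)
Wd = 0.05 * rng.standard_normal((10 + 256, 256))
Ww = 0.05 * rng.standard_normal((256, 784))
dec = lambda z, h: np.tanh(np.concatenate([z, h], axis=1) @ Wd)
wr = lambda h: h @ Ww
imgs = generate(dec, wr, T=10, batch=4, z_size=10, img_size=784, rng=rng)
print(imgs.shape)  # (4, 784), values in (0, 1)
```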

Command to run program

FYI, README.md has a missing word:

You can visualize the results by running the script python plot_data.py <prefix> <output_data>

For example,

python plot_data.py myattn /tmp/draw/draw_data.npy
