VAE/GAN

A TensorFlow implementation of VAE-GAN, following the paper: VAE/GAN (forked from prateekmunjal/autoencoding-beyond-pixels-using-a-learned-similarity-metric). VAE-GAN was one of the first approaches to use the GAN discriminator as a learned similarity metric: reconstruction error is measured in the discriminator's feature space rather than in pixel space. The encoder and decoder functions are implemented using strided convolutional layers and transposed convolutional layers, respectively. The discriminator network has the same architecture as the encoder, with one additional output layer. As suggested by the authors, I have implemented a Gaussian decoder and a Gaussian prior.
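
As a rough illustration of the layout described above (a minimal sketch in the TF 1.x API; the function names, filter counts, and kernel sizes are illustrative placeholders, not the values from this repository's config.py):

    import tensorflow as tf  # TF 1.x API, matching the TensorFlow 1.9 requirement below

    def encoder(x, z_dim=128, reuse=False):
        # Strided convolutions (no pooling) downsample the image into the
        # parameters of a Gaussian posterior q(z|x).
        with tf.variable_scope("encoder", reuse=reuse):
            h = tf.layers.conv2d(x, 64, 5, strides=2, padding="same",
                                 activation=tf.nn.relu)
            h = tf.layers.conv2d(h, 128, 5, strides=2, padding="same",
                                 activation=tf.nn.relu)
            h = tf.layers.flatten(h)
            z_mean = tf.layers.dense(h, z_dim)
            z_log_var = tf.layers.dense(h, z_dim)
            return z_mean, z_log_var

    def discriminator(x, reuse=False):
        # Same convolutional trunk as the encoder, plus one extra output layer.
        with tf.variable_scope("discriminator", reuse=reuse):
            h = tf.layers.conv2d(x, 64, 5, strides=2, padding="same",
                                 activation=tf.nn.relu)
            h = tf.layers.conv2d(h, 128, 5, strides=2, padding="same",
                                 activation=tf.nn.relu)
            h = tf.layers.flatten(h)
            return tf.layers.dense(h, 1)  # real/fake logit

The decoder mirrors the encoder, upsampling with tf.layers.conv2d_transpose.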

Setup

  • Python 3.5+
  • TensorFlow 1.9

Relevant Code Files

File config.py contains the hyper-parameters used for the reported VAE/GAN results.

File vae-gan.py contains the code to train the VAE/GAN model.

As the name suggests, file vae-gan_inference.py contains the code to test a trained VAE/GAN model.

Usage

Training a model

NOTE: For CelebA, make sure you have downloaded the dataset from here and placed it in the project's root directory.

python vae-gan.py

Test a trained model

First place the model weights in model_directory (as specified in vae-gan_inference.py), then:

python vae-gan_inference.py 

Empirical Observations

  • I observed that the presence of the KL-divergence term in the encoder's loss can sometimes make training cumbersome.

  • The only hyper-parameter I tweaked to alleviate this is the weight multiplied with the KL term. A KL-term weight equal to 1/batch_size almost always works.

  • An alternative I tried was to make the KL weight a function of the epoch, i.e. sigmoid(epoch); see the sketch after this list.

  • Intuitively, the dynamic KL weight makes more sense: the weight grows with the epoch, so the model pays little attention to the KL-divergence term in the initial iterations. But why would we want the model not to focus on this term early on?

  • The reason is that leaving the latent variables unconstrained in the initial iterations lets them learn meaningful representations for reconstructing the input; as the epochs increase, the growing KL weight then pulls the latent distribution towards the prior.

  • But why not use some other monotonic function, such as exp(epoch)?

  • While increasing the weight of the KL term, we need an upper limit, otherwise the model may end up focusing entirely on this term. We therefore choose a function that saturates for large inputs, which sigmoid does and exp does not.
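
A minimal sketch of the two schedules discussed above, together with the standard closed-form KL divergence between the Gaussian posterior and the unit Gaussian prior (the function names and the batch size of 64 are my illustrative choices, not necessarily what config.py uses):

    import numpy as np
    import tensorflow as tf

    def kl_weight(epoch, batch_size=64, schedule="sigmoid"):
        # Weight multiplied with the KL term in the encoder loss.
        if schedule == "constant":
            return 1.0 / batch_size  # the setting that almost always works
        # sigmoid(epoch): monotonically increasing but saturates at 1, so the
        # KL term can never completely dominate the loss (unlike exp(epoch)).
        return float(1.0 / (1.0 + np.exp(-epoch)))

    def kl_divergence(z_mean, z_log_var):
        # KL( N(mu, sigma^2) || N(0, I) ) in closed form, summed over latent dims.
        return -0.5 * tf.reduce_sum(
            1.0 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)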

Model weights

The weights for the results presented in this repository are linked below; they are shared on Google Drive.

Generations

[Generated samples: MNIST | Celeb-A]

Reconstructions

  • For MNIST dataset

    • At epoch 1: [MNIST original | MNIST reconstruction]
    • At epoch 50: [MNIST original | MNIST reconstruction]

  • For CelebA dataset

    • At epoch 1: [Celeb-A original | Celeb-A reconstruction]
    • At epoch 15: [Celeb-A original | Celeb-A reconstruction]
