Conditional Variational Auto-encoder

Introduction

This tutorial uses the Pyro PPL to implement the paper Learning Structured Output Representation using Deep Conditional Generative Models [1], which introduced Conditional Variational Auto-encoders in 2015.

Supervised deep learning has been successfully applied to many recognition problems in machine learning and computer vision. Although it can approximate a complex many-to-one function very well when a large amount of training data is provided, the lack of probabilistic inference in current supervised deep learning methods makes it difficult to model complex structured output representations. In this work, Kihyuk Sohn, Honglak Lee and Xinchen Yan develop a scalable deep conditional generative model for structured output variables using Gaussian latent variables. The model is trained efficiently in the framework of stochastic gradient variational Bayes, and allows fast prediction using stochastic feed-forward inference. They called the model the Conditional Variational Auto-encoder (CVAE).

The CVAE is a conditional directed graphical model whose input observations modulate the prior on Gaussian latent variables that generate the outputs. It is trained to maximize the conditional marginal log-likelihood. The authors formulate the variational learning objective of the CVAE in the framework of stochastic gradient variational Bayes (SGVB). In experiments, they demonstrate the effectiveness of the CVAE in comparison to its deterministic neural network counterparts in generating diverse but realistic output predictions using stochastic inference. Here, we will implement their proof of concept: an artificial experimental setting for structured output prediction using the MNIST database.
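For reference, the variational objective described above can be written as a lower bound on the conditional log-likelihood. The notation here is an assumption, not taken from the text: x is the input, y the output, z the Gaussian latent variable, p_θ(z | x) the input-conditioned prior, q_φ(z | x, y) the recognition (approximate posterior) network, and p_θ(y | x, z) the generation network.

```latex
\log p_\theta(y \mid x) \;\ge\;
  -\,\mathrm{KL}\!\left( q_\phi(z \mid x, y) \,\|\, p_\theta(z \mid x) \right)
  + \mathbb{E}_{q_\phi(z \mid x, y)}\!\left[ \log p_\theta(y \mid x, z) \right]
```

The KL term keeps the recognition network close to the prior, which is what makes sampling z from p_θ(z | x) usable at prediction time, when y is not available.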

The problem

Let's divide each digit image into four quadrants, and take one, two, or three quadrant(s) as an input and the remaining quadrants as an output to be predicted. The image below shows the case where one quadrant is the input:

[Figure: a digit image with one quadrant given as the input]
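The quadrant split can be sketched in a few lines of dependency-free Python. This is a minimal illustration rather than the tutorial's actual code: the helper names, the quadrant ordering (top-left, top-right, bottom-left, bottom-right) and the `-1` value used to mark hidden pixels are all assumptions.

```python
MASK = -1  # assumed sentinel for "pixel not part of this view"


def quadrant_of(row, col, height, width):
    """Return the quadrant index (0..3) of pixel (row, col).

    Assumed ordering: 0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right.
    """
    return 2 * int(row >= height // 2) + int(col >= width // 2)


def split_input_output(image, num_input_quadrants):
    """Split an image (list of equal-length rows) into input and output views.

    The first `num_input_quadrants` quadrants are visible in `inputs` and
    masked in `outputs`; the remaining quadrants are the other way round.
    """
    if num_input_quadrants not in (1, 2, 3):
        raise ValueError("expected 1, 2 or 3 input quadrants")
    height, width = len(image), len(image[0])
    inputs = [[MASK] * width for _ in range(height)]
    outputs = [[MASK] * width for _ in range(height)]
    for r in range(height):
        for c in range(width):
            if quadrant_of(r, c, height, width) < num_input_quadrants:
                inputs[r][c] = image[r][c]
            else:
                outputs[r][c] = image[r][c]
    return inputs, outputs


# A tiny 4x4 stand-in for a 28x28 MNIST digit.
toy = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
]
inputs, outputs = split_input_output(toy, num_input_quadrants=1)
```

With one input quadrant, `inputs` keeps only the top-left 2x2 block of `toy`, and `outputs` holds the other three quadrants that the model has to predict.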

Our objective is to learn a model that can perform probabilistic inference and make diverse predictions from a single input. We are not simply modeling a many-to-one function, as in classification tasks; we may need to model a mapping from a single input to many possible outputs. One of the limitations of deterministic neural networks is that they generate only a single prediction. In the example above, the input shows a small part of a digit that might be a three or a five.
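A toy calculation shows why a single deterministic prediction is a poor fit for this one-to-many setting. The sketch below assumes a squared-error loss and two made-up, equally likely completions of the same input; neither is taken from the tutorial's code. Under those assumptions, the single prediction with the lowest expected loss is the pixel-wise average of the two completions, which matches neither of them.

```python
def squared_error(prediction, target):
    """Sum of squared pixel differences between two equal-length strips."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target))


# Two equally plausible completions of the same input (hypothetical pixels).
completion_a = [1.0, 1.0, 0.0, 0.0]
completion_b = [0.0, 0.0, 1.0, 1.0]
targets = [completion_a, completion_b]


def expected_loss(prediction):
    """Average squared error of one fixed prediction over all completions."""
    return sum(squared_error(prediction, t) for t in targets) / len(targets)


# The pixel-wise mean is a grey blend of both completions.
blurry_mean = [(a + b) / 2 for a, b in zip(completion_a, completion_b)]

loss_of_mean = expected_loss(blurry_mean)
loss_of_sharp = expected_loss(completion_a)
```

Here `loss_of_mean` is lower than `loss_of_sharp`, so a network trained this way is pushed towards the grey blend instead of committing to one sharp digit.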

Evaluating the results

For qualitative analysis, we visualize the generated output samples in the next figure. As we can see, the baseline NNs can only make a single deterministic prediction, and as a result the output looks blurry and doesn’t look realistic in many cases. In contrast, the samples generated by the CVAE models are more realistic and diverse in shape; sometimes they can even change their identity (digit labels), such as from 3 to 5 or from 4 to 9, and vice versa.

[Figure: output samples generated by the baseline NN and by the CVAE]

We also provide quantitative evidence by estimating the marginal conditional log-likelihoods (CLLs), shown in the following table. The values are negative CLLs, so lower is better.

| | 1 quadrant | 2 quadrants | 3 quadrants |
|---|---|---|---|
| NN (baseline) | 100.4 | 61.9 | 25.4 |
| CVAE (Monte Carlo) | 71.8 | 51.0 | 24.2 |
| Performance gap | 28.6 | 10.9 | 1.2 |
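The "Monte Carlo" label refers to estimating the CLL by sampling. A minimal sketch of such an estimator, assuming latent samples z_s are drawn from the prior p(z | x) and that log p(y | x, z_s) has already been computed for each of them, averages the likelihoods in log space. The function names and the example numbers are hypothetical.

```python
import math


def log_mean_exp(log_values):
    """Compute log(mean(exp(v))) stably by factoring out the maximum."""
    m = max(log_values)
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total / len(log_values))


def monte_carlo_cll(sample_log_likelihoods):
    """Estimate log p(y | x) from per-sample values of log p(y | x, z_s).

    Assumes each z_s was sampled from the prior p(z | x), so that
    p(y | x) is approximated by the mean of p(y | x, z_s) over samples.
    """
    return log_mean_exp(sample_log_likelihoods)


# Hypothetical log-likelihoods for three prior samples of one test image.
sample_log_likelihoods = [-80.0, -75.0, -90.0]
estimate = monte_carlo_cll(sample_log_likelihoods)
negative_cll = -estimate
```

Working in log space matters because likelihoods of whole images are far too small to average directly in floating point.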

Our results are similar to those reported by the authors in the paper. We trained for at most 50 epochs with an early stopping patience of 3 epochs; training for longer could improve the results. Nevertheless, we observe the same effect shown in the paper: the CVAE significantly outperforms the baseline NN in estimated CLL.
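The stopping rule mentioned above can be sketched as follows. This is an illustration under assumptions, not the tutorial's training loop: the function name is hypothetical, the validation losses are passed in as a precomputed list in place of real training, and "patience" is taken to mean the number of consecutive epochs without a new best validation loss.

```python
def run_early_stopping(validation_losses, max_epochs=50, patience=3):
    """Return (best_epoch, best_loss, epochs_run) for a sequence of losses.

    Stops once `patience` consecutive epochs fail to improve on the best
    validation loss seen so far, or after `max_epochs` epochs.
    """
    best_loss = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    epochs_run = 0
    for epoch, loss in enumerate(validation_losses[:max_epochs]):
        epochs_run = epoch + 1
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_epoch, best_loss, epochs_run


# Made-up validation losses: the improvement at the end is never reached.
losses = [100.0, 90.0, 85.0, 86.0, 87.0, 88.0, 70.0]
best_epoch, best_loss, epochs_run = run_early_stopping(losses)
```

In this made-up run, training stops after the third epoch without improvement and never sees the later, lower loss, which illustrates how a short patience can end training early.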

See the full code on GitHub.

References

[1] Learning Structured Output Representation using Deep Conditional Generative Models,
     Kihyuk Sohn, Xinchen Yan, Honglak Lee
