jovana-gentic / vae_celeba

A simple VAE implementation in TensorFlow, JAX and PyTorch, where both the encoder and the decoder model Gaussian distributions.

VAE_celeba

Jovana Gentić 🦆


In this notebook, we implement a VAE in which both the encoder and the decoder model Gaussian distributions. The model is trained on 64x64 CelebA_10 images. The main model is trained in TensorFlow and supports multi-GPU training; we also created JAX and PyTorch versions of the code for learning purposes.

Images before and after cropping and resizing for model training

About the model

The encoder is a stack of convolutions that downsample the image down to a certain resolution. After that, the activation is flattened and passed through a stack of dense layers that output the parameters of the posterior distribution q(z|x).
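
As a hedged illustration of how the last dense layer could parameterize q(z|x), the sketch below assumes that layer outputs 2 * latent_dim values, split into a mean and a softplus-activated stddev. The split, the softplus, and the 1e-5 floor are assumptions, not taken from the notebook. It also shows the standard reparameterization trick used to draw z while keeping the sample differentiable.

```python
import math
import random


def softplus(x):
    # Numerically stable softplus: log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def posterior_params(dense_out):
    """Split a (hypothetical) final dense output of size 2 * latent_dim into
    the mean and stddev of a diagonal Gaussian q(z|x)."""
    if len(dense_out) % 2 != 0:
        raise ValueError("expected an even number of outputs")
    latent_dim = len(dense_out) // 2
    mean = dense_out[:latent_dim]
    # Softplus keeps the stddev positive; the 1e-5 floor is an assumption.
    std = [softplus(v) + 1e-5 for v in dense_out[latent_dim:]]
    return mean, std


def sample_posterior(mean, std, rng):
    # Reparameterization trick: z = mean + std * eps with eps ~ N(0, 1),
    # so in a framework, gradients flow through mean and std.
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mean, std)]


mean, std = posterior_params([0.5, -1.0, 0.0, 2.0])
z = sample_posterior(mean, std, random.Random(0))
print(len(z))  # 2
```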

The decoder starts with dense layers that process the sample z, followed by an unflatten (reshape) operation into an activation of shape (B, h, w, C). This activation is then upsampled back to the original image size by a stack of resize-conv blocks. A resize-conv block is simply nearest-neighbor upsampling followed by convolutions, used instead of deconvolution (transposed convolution) layers to avoid checkerboard artifacts: https://distill.pub/2016/deconv-checkerboard/
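
A minimal single-channel sketch of one resize-conv block in plain Python, assuming an upsampling factor of 2 and one "same"-padded convolution. The real blocks operate on (B, h, w, C) tensors with framework layers and may stack several convolutions; the function names here are hypothetical.

```python
def nearest_upsample(image, factor=2):
    """Nearest-neighbor upsampling of a single-channel H x W image:
    every pixel is repeated `factor` times along both axes."""
    out = []
    for row in image:
        wide_row = [v for v in row for _ in range(factor)]
        for _ in range(factor):
            out.append(list(wide_row))
    return out


def conv2d_same(image, kernel):
    """Naive stride-1 2D cross-correlation (what deep-learning 'conv' layers
    compute) with zero padding so the output keeps the input size.
    Assumes a square kernel with odd size."""
    h, w = len(image), len(image[0])
    k = len(kernel)
    pad = k // 2
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for di in range(k):
                for dj in range(k):
                    y, x = i + di - pad, j + dj - pad
                    if 0 <= y < h and 0 <= x < w:
                        acc += kernel[di][dj] * image[y][x]
            out[i][j] = acc
    return out


def resize_conv(image, kernel, factor=2):
    # Upsample first, then convolve at stride 1 on the larger grid, so the
    # kernel overlaps the output uniformly (unlike strided transposed
    # convolutions, whose uneven overlap produces checkerboard patterns).
    return conv2d_same(nearest_upsample(image, factor), kernel)


print(nearest_upsample([[1, 2], [3, 4]]))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```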

For the loss, we minimize the negative ELBO, -ELBO = -likelihood + KL_div, where:

  • likelihood = decoder_dist.log_pdf(targets)
  • KL_div = KL(posterior_dist || prior_dist)
  • The posterior_dist is the encoder distribution.
  • For simplicity, the prior distribution is a standard Gaussian N(0, 1).
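
The loss terms can be written out per example, assuming diagonal Gaussians described by mean and stddev lists. This is a plain-Python sketch of the formulas, not the notebook's code, which relies on framework distribution objects (decoder_dist.log_pdf). The kl_weight argument is a hypothetical hook for the KL warmup.

```python
import math


def gaussian_log_pdf(x, mean, std):
    """log N(x; mean, std^2) of a diagonal Gaussian, summed over dimensions."""
    return sum(
        -0.5 * math.log(2 * math.pi) - math.log(s) - 0.5 * ((xi - m) / s) ** 2
        for xi, m, s in zip(x, mean, std)
    )


def kl_to_standard_normal(mean, std):
    """Closed-form KL(N(mean, std^2) || N(0, 1)), summed over latent dims:
    0.5 * (std^2 + mean^2 - 1) - log(std) per dimension."""
    return sum(0.5 * (s ** 2 + m ** 2 - 1.0) - math.log(s) for m, s in zip(mean, std))


def negative_elbo(targets, dec_mean, dec_std, post_mean, post_std, kl_weight=1.0):
    # -ELBO = -likelihood + KL_div (optionally scaled during warmup).
    likelihood = gaussian_log_pdf(targets, dec_mean, dec_std)
    kl_div = kl_to_standard_normal(post_mean, post_std)
    return -likelihood + kl_weight * kl_div
```

The KL term is zero exactly when the posterior equals the prior (mean 0, stddev 1), which is what posterior collapse drives it toward.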

To help the model avoid posterior collapse, we warm up the KL_div term by linearly scaling its weight up over the first 10000 training steps.
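
A linear warmup schedule could look like the following sketch. It assumes the weight starts at 0 and is capped at 1 once warmup ends; the starting value is an assumption, since only the linear ramp over 10000 steps is stated.

```python
def kl_weight(step, warmup_steps=10_000):
    """Linear KL warmup: 0 at step 0, 1 at warmup_steps, then constant.
    Starting from 0 is an assumption."""
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, step / warmup_steps)


# Usage inside a training step:
# loss = -likelihood + kl_weight(step) * kl_div
```

While the weight is small, the model is mostly trained as an autoencoder and learns to put information into z before the KL term pulls the posterior toward the prior.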

Generate

To generate new images from the prior distribution, pick a prior temperature (z_temp) and a decoder temperature (x_temp), for example: pictures = model.generate(z_temp=1., x_temp=0.3)

z_temp: float, the temperature multiplier applied to the prior stddev when sampling z. A smaller z_temp makes the generated samples less diverse and more generic.

x_temp: float, the temperature multiplier applied to the decoder stddev. A smaller x_temp makes the generated samples smoother, at the cost of losing some fine detail.
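
To make the two temperatures concrete, here is a standalone sketch, assuming z_temp scales the stddev of the N(0, 1) prior and x_temp scales the decoder's stddev. The decode callable and toy_decode are hypothetical stand-ins for the trained decoder; the repository's actual entry point is model.generate.

```python
import random


def generate(decode, latent_dim, z_temp=1.0, x_temp=0.3, rng=None):
    """Sample one image from the prior with temperature scaling.

    decode: hypothetical callable mapping a latent vector z to the
    (mean, std) of the decoder's diagonal Gaussian over pixels.
    """
    if rng is None:
        rng = random.Random()
    # Prior N(0, 1) with its stddev scaled by z_temp.
    z = [z_temp * rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    mean, std = decode(z)
    # Decoder Gaussian with its stddev scaled by x_temp.
    return [m + x_temp * s * rng.gauss(0.0, 1.0) for m, s in zip(mean, std)]


def toy_decode(z):
    # Hypothetical stand-in for the decoder network: 4 "pixels".
    return [sum(z)] * 4, [0.1] * 4


pixels = generate(toy_decode, latent_dim=8, z_temp=1.0, x_temp=0.3, rng=random.Random(0))
```

With both temperatures at 0, the sample collapses to the decoder mean at z = 0, the most "average" image; raising z_temp varies the content, and raising x_temp adds pixel-level noise.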
