vae_mnist_cifar's Introduction

Variational Autoencoder trained on the CIFAR and MNIST datasets

Trained a Variational Autoencoder (VAE) with PyTorch Lightning on the CIFAR and MNIST datasets.

Achievements

  • Label appended to the encoder input
  • Label one-hot encoded and resized so that it can be appended to the image
  • Created a batch of images paired with different labels (every label except the original) to send to the trained model for inference
  • Combined the inputs and outputs into a grid and saved it as an output image with labels
  • Padded the inputs to reach the required size, as needed for MNIST
  • Worked out what the ready-made model expects and increased the channel count to match
  • Handled conversion between 1-channel and 3-channel images for display
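
One way to append a one-hot label to the encoder input can be sketched as follows. This is a minimal illustration, not the code from this repository: it assumes each of the 10 one-hot entries is expanded into a constant plane of the image size and concatenated along the channel axis, which is only one possible reading of "resized to append". The function name `append_label_planes` is hypothetical.

```python
import torch
import torch.nn.functional as F


def append_label_planes(images: torch.Tensor, labels: torch.Tensor,
                        num_classes: int = 10) -> torch.Tensor:
    """Concatenate one-hot label planes to a batch of images (assumed scheme)."""
    batch, _, height, width = images.shape
    # (B,) integer labels -> (B, num_classes) one-hot vectors
    one_hot = F.one_hot(labels, num_classes=num_classes).to(images.dtype)
    # Broadcast every one-hot entry to a full (height, width) plane
    planes = one_hot[:, :, None, None].expand(batch, num_classes, height, width)
    # Channel count grows from C to C + num_classes
    return torch.cat([images, planes], dim=1)


images = torch.rand(4, 3, 32, 32)        # CIFAR-sized dummy batch
labels = torch.tensor([0, 3, 7, 9])
conditioned = append_label_planes(images, labels)
print(conditioned.shape)
```

With this scheme the first encoder layer has to accept 13 input channels instead of 3, which is consistent with the note above about increasing the channel count.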

Outputs

VAE CIFAR: 25 output images, with an incorrect label passed to the encoder along with each image

image

VAE CIFAR: outputs shown alongside the input image, with the correct label

image image image image image image

VAE CIFAR: output after less training

image

VAE MNIST: 25 output images, with an incorrect label passed to the encoder along with each image

image

Files

  • VAE_CIFAR.py
    • Training file for the CIFAR VAE, with the label appended to the image
    • The label is one-hot encoded, resized to the image dimensions, and appended
  • VAE_CIFAR_inference.py
    • Inference for the CIFAR VAE trained with VAE_CIFAR.py
    • Passes a different label for each image to see what the model produces
    • Handles the details of passing incorrect labels and displaying the results in a readable format with labels
  • VAE_MNIST.py
    • Training file for the MNIST VAE, with the label appended to the image, along the same lines as VAE_CIFAR.py
    • The label is one-hot encoded, resized to the image dimensions, and appended
    • Inference and generation of the output images are included in the same file
    • Additional complexity: MNIST images are converted to 3 channels and padded to 32 x 32 because of the ResNet model used by the pl_bolts implementation
    • Also compares the output with the input for the loss calculation
    • Finally, shows the outputs in a readable format with the changed dimensions
  • *.png
    • Various outputs
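
The "different label" inference described for VAE_CIFAR_inference.py could be set up roughly as below. This is a sketch under assumptions, not the repository's code: it repeats one image once per incorrect class so that each copy can be paired with a different wrong label. The helper name `wrong_label_batch` and the choice of 10 classes are illustrative.

```python
import torch


def wrong_label_batch(image: torch.Tensor, true_label: int,
                      num_classes: int = 10):
    """Repeat one (C, H, W) image once for every label except the true one."""
    wrong = torch.tensor([c for c in range(num_classes) if c != true_label])
    # (C, H, W) -> (num_classes - 1, C, H, W), one copy per incorrect label
    images = image.unsqueeze(0).repeat(len(wrong), 1, 1, 1)
    return images, wrong


image = torch.rand(3, 32, 32)            # dummy CIFAR-sized image
batch, wrong_labels = wrong_label_batch(image, true_label=4)
print(batch.shape, wrong_labels.tolist())
```

The resulting batch and label tensor can then be conditioned and passed through the trained model in a single forward pass, and the inputs and outputs arranged on a grid for display.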

Hyperparameters used

  • CIFAR
    • Epochs=30
    • enc_out_dim=512
    • latent_dim=256
    • input_height=32
    • optimizer=torch.optim.Adam(self.parameters(), lr=1e-4)
    • batch_size=16
    • num_workers=16
  • MNIST
    • Epochs=50
    • enc_out_dim=512
    • latent_dim=256
    • input_height=32 (Adjusted size of the image)
    • optimizer=torch.optim.Adam(self.parameters(), lr=1e-4)
    • batch_size=16
    • num_workers=16
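
For reference, the settings above can be collected in one place as in the sketch below. The dictionary names and the stand-in `heads` module are hypothetical and only serve to show the optimizer call; in the real training files the optimizer is built from `self.parameters()` of the Lightning module.

```python
import torch
from torch import nn

# Settings shared by both runs, copied from the lists above
COMMON = {
    "enc_out_dim": 512,
    "latent_dim": 256,
    "input_height": 32,
    "lr": 1e-4,
    "batch_size": 16,
    "num_workers": 16,
}
CIFAR_CONFIG = {**COMMON, "epochs": 30}
MNIST_CONFIG = {**COMMON, "epochs": 50}

# Stand-in module (assumption): two linear heads mapping encoder features
# to the latent mean and log-variance, used here only to have parameters.
heads = nn.ModuleDict({
    "fc_mu": nn.Linear(COMMON["enc_out_dim"], COMMON["latent_dim"]),
    "fc_var": nn.Linear(COMMON["enc_out_dim"], COMMON["latent_dim"]),
})
optimizer = torch.optim.Adam(heads.parameters(), lr=CIFAR_CONFIG["lr"])
print(optimizer.defaults["lr"])
```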

Hardware used

  • Google Colab T4
  • Some experimentation on CPU as well

Detailed training details

  • Important excerpts from the TensorBoard logs

ELBO CIFAR training

image

KL Loss CIFAR training

image

Recon loss CIFAR training

image

Reconstruction CIFAR training

image

ELBO MNIST training

image

KL Loss MNIST training

image

Recon loss MNIST training

image

Reconstruction MNIST training

image
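
The logged quantities (ELBO, KL loss, reconstruction loss) relate to each other roughly as in the sketch below. It assumes a diagonal Gaussian posterior, a standard normal prior with the closed-form KL term, and a mean squared error reconstruction term; the training code in this repository may compute these differently (for example with a sampled KL estimate), so treat the function `vae_loss_terms` and its `kl_coeff` parameter as illustrative.

```python
import torch
import torch.nn.functional as F


def vae_loss_terms(x, x_hat, mu, log_var, kl_coeff: float = 1.0):
    """Return reconstruction, KL, and total loss for one batch (assumed form)."""
    # Reconstruction term: how closely the decoder output matches the input
    recon = F.mse_loss(x_hat, x, reduction="mean")
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims,
    # averaged over the batch
    kl = (-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)).mean()
    # Total loss being minimized (a negative-ELBO-style objective)
    return {"recon_loss": recon, "kl": kl, "loss": recon + kl_coeff * kl}


x = torch.rand(16, 3, 32, 32)
x_hat = torch.rand(16, 3, 32, 32)
mu = torch.randn(16, 256)
log_var = torch.randn(16, 256)
terms = vae_loss_terms(x, x_hat, mu, log_var)
print({name: float(value) for name, value in terms.items()})
```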

Observations

  • Outputs generated with different labels do not differ
  • With 28 x 28 input images the model output was 24 x 24, which broke the comparison with the input, so the inputs were padded to 32 x 32
  • The single channel in MNIST also caused issues, because the model was defined for 3-channel input
    • Instead of changing the model, the channel was copied 3 times to create a 3-channel input
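
The padding and channel-copying workaround can be sketched like this. It is a minimal example, assuming zero padding of 2 pixels on each side and a channel average when going back to a single channel for display; the function names are hypothetical and the repository may handle the reverse conversion differently.

```python
import torch
import torch.nn.functional as F


def mnist_to_three_channel_32(x: torch.Tensor) -> torch.Tensor:
    """Convert (B, 1, 28, 28) MNIST batches to (B, 3, 32, 32)."""
    # Pad 2 pixels left, right, top, and bottom: 28 + 2 + 2 = 32
    padded = F.pad(x, (2, 2, 2, 2), mode="constant", value=0.0)
    # Copy the single channel 3 times along the channel axis
    return padded.repeat(1, 3, 1, 1)


def to_single_channel(x: torch.Tensor) -> torch.Tensor:
    """Collapse (B, 3, H, W) back to (B, 1, H, W) for grayscale display."""
    return x.mean(dim=1, keepdim=True)


digits = torch.rand(8, 1, 28, 28)        # dummy MNIST-sized batch
model_input = mnist_to_three_channel_32(digits)
display = to_single_channel(model_input)
print(model_input.shape, display.shape)
```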

Contributors

  • chintanshahds
