
pytorch-vae-tutorial's Introduction

Minsu Jackson Kang

AI researcher in NCSOFT

Research interests

Neural speech synthesis, deep music recognition/generation, style transfer, representation learning, deep generative models, deep learning, machine learning

Contact

Please feel free to email me. I am always open to discussing the AI industry, AI research, and other topics.


pytorch-vae-tutorial's People

Contributors

jackson-kang


pytorch-vae-tutorial's Issues

show image: Dead kernel

Hey Jackson, thanks for this.
Very nice.

I'm wondering why my kernel dies ("Dead kernel") on this call:

show_image(x, idx=0)

It was working just fine until that cell killed it. The show_image call comes after this cell:

model.eval()

with torch.no_grad():
    for batch_idx, (x, _) in enumerate(tqdm(test_loader)):
        x = x.view(batch_size, x_dim)
        x = x.to(DEVICE)

        x_hat, _, _ = model(x)
        break

So I tested the cell line by line, adding more and more of it, and there was no trouble; eventually the entire cell ran just fine. However, two cells down, this kills it:

show_image(x, idx=0)

I'm really confused about this.

Setup: PC, CPU only (no GPU, so cuda=False), Anaconda, Chrome, Python 3.9.7, Windows 10 Enterprise 64-bit, 32 GB RAM.
I'm running the Jupyter notebook through the organisation's servers on this PC.
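In case it helps narrow things down, here is a minimal sketch that displays the same reconstruction without going through the notebook's show_image helper (this assumes x_dim = 784, i.e. the flattened 28x28 MNIST images the tutorial uses, and that x_hat comes from the cell above):

import matplotlib.pyplot as plt

# Take the first reconstruction from the batch, move it to the CPU,
# and reshape the flat 784-dimensional vector back into a 28x28 image.
img = x_hat[0].detach().cpu().view(28, 28)

plt.imshow(img, cmap="gray")
plt.axis("off")
plt.show()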

Thanks.

Wrong generative process

Firstly, thank you for the tutorial!

For your generative process, you are "Generating image from noise vector" even though you have already learnt the latent space. This is actually the problem with vanilla autoencoders: their latent space is not learnt as a distribution, so when you just randomly pick a latent vector, chances are it falls outside the learned latent space, i.e. it is one the decoder doesn't know how to work with.

To do it right, you should 1) sample an image, 2) make a forward pass with that image to get a mean and covariance, and 3) use the aforementioned mean and covariance to sample a latent vector.

Mathematically speaking (hopefully the notation is universal), you sample $x$ and use the probabilistic encoder $p(z|x)$, where the given $x$ is the one you just sampled, to sample a latent variable. That latent variable then lies within the latent space the decoder knows how to work with. And of course, the generated image won't be a copy from the training dataset.
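Concretely, with the reparameterization trick (standard VAE notation, assuming the diagonal-Gaussian encoder that outputs a mean and a log-variance, as in the notebook), step 3 is $z = \mu(x) + \sigma(x) \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$ and $\sigma(x) = \exp\big(\tfrac{1}{2}\log\sigma^2(x)\big)$, and the generated image is the decoder output for this $z$.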

I figured you may have already suspected something was wrong given your generated results, so I hope this clears things up.

with torch.no_grad():
    # 1) sample an image from the data
    x, _ = next(iter(train_loader))
    x = x[0]
    sampled_x = x.view(1, x_dim)

    # 2) forward pass to get the encoder's mean and log-variance
    _, mean, log_var = model(sampled_x)

    # 3) reparameterize: z = mean + std * epsilon, with epsilon ~ N(0, I)
    std = torch.exp(0.5 * log_var)       # log_var is log(sigma^2), so std = exp(log_var / 2)
    epsilon = torch.randn_like(std)      # randn_like (standard normal), not rand_like (uniform)
    sampled_z = mean + std * epsilon

    generated_images = decoder(sampled_z)

This is some code I used with your notebook to do it the proper way. Hope it makes sense. Thanks again for the tutorial.

Question about the distance calculation inside VQEmbeddingEMA

First of all, thank you for making this tutorial. I learned a lot from it. There is one thing I would like to double check.

In cell 5 of your Jupyter notebook, you try to calculate all possible distance pairs between the flattened continuous latents x_flat and the codebook embedding as

torch.addmm(torch.sum(embedding ** 2, dim=1) +
            torch.sum(x_flat ** 2, dim=1),
            x_flat, embedding.t(),
            alpha=-2.0, beta=1)
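For reference, the quantity this is meant to compute is the full matrix of pairwise squared Euclidean distances, $\lVert z_e(x)_i - e_j \rVert^2 = \lVert z_e(x)_i \rVert^2 - 2\, z_e(x)_i \cdot e_j + \lVert e_j \rVert^2$, for every pair $(i, j)$.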

However, I don't think it is the correct way to calculate the distance between $z_e(x)_i$ and $e_j$ for $i \neq j$.
Let's look at the following example.

For simplicity, let's assume there are only two entries $e_1$ and $e_2$ in the codebook.
Let's set $e_1=(-0.8567, 1.1006, -1.0712)$ and $e_2=(0.1227, -0.5663, 0.3731)$.
Similarly, let's set $z_e(x)_1 = (0.4033, 0.8380, -0.7193)$ and $z_e(x)_2 = (-0.4033, -0.5966, 0.1820)$.

import torch

embedding = torch.tensor(
    [[-0.8567,  1.1006, -1.0712],
     [ 0.1227, -0.5663,  0.3731]]
    )
x_flat = torch.tensor(
    [[ 0.4033,  0.8380, -0.7193],
     [-0.4033, -0.5966,  0.1820]]
    )

dist_kang = torch.addmm(torch.sum(embedding ** 2, dim=1) +
                        torch.sum(x_flat ** 2, dim=1),
                        x_flat, embedding.t(),
                        alpha=-2.0, beta=1)

dist_lucidrain = (-torch.cdist(x_flat, embedding, p=2)) ** 2

Your implementation (dist_kang) returns

tensor([[1.7804, 2.4136],
        [5.4872, 0.3141]])

While the correct implementation (dist_lucidrain from lucidrain) returns

tensor([[1.7804, 3.2441],
        [4.6566, 0.3141]])

As you can see, even though your implementation returns the correct results for the $z_e(x)_i$ and $e_j$ pairs when $i = j$ (the diagonal elements), the results for the $i \neq j$ cases are all off. The reason is that torch.sum(embedding ** 2, dim=1) + torch.sum(x_flat ** 2, dim=1) is a 1-D vector whose $j$-th entry is $\lVert e_j \rVert^2 + \lVert z_e(x)_j \rVert^2$; torch.addmm broadcasts it across the rows of the result, so entry $(i, j)$ picks up $\lVert z_e(x)_j \rVert^2$ where it should have $\lVert z_e(x)_i \rVert^2$.
(attached screenshot: vqvae)

Here is an example of how to verify the correct values for the two cases:

# e1-z1 and e2-z2 distance pairs
torch.sum(embedding**2 - 2*embedding*x_flat + x_flat**2, dim=1)
>>> tensor([1.7804, 0.3141])

# e2-z1 and e1-z2 distance pairs
torch.sum(embedding.flip(0)**2 - 2*embedding.flip(0)*x_flat + x_flat**2, dim=1)
>>> tensor([3.2441, 4.6566])

So the correct value for the $e_2$, $z_e(x)_1$ pair should be 3.2441, and for $e_1$, $z_e(x)_2$ it should be 4.6566.
Therefore, torch.cdist should be used to calculate the distances instead of torch.addmm.
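For what it's worth, the broadcast-based formulation can also be kept if the squared norms are broadcast over the correct axes; here is a minimal sketch using the names from the example above (keepdim=True turns the x_flat term into a column vector, so it varies with $i$ rather than $j$):

# Pairwise squared distances ||z_e(x)_i - e_j||^2 for every (i, j):
# ||x_i||^2 broadcasts down the columns, ||e_j||^2 across the rows.
distances = (torch.sum(x_flat ** 2, dim=1, keepdim=True)
             - 2.0 * torch.matmul(x_flat, embedding.t())
             + torch.sum(embedding ** 2, dim=1))

With the example tensors above, this reproduces the dist_lucidrain values, including the off-diagonal 3.2441 and 4.6566.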
