First of all, thank you for making this tutorial. I learned a lot from it. There is one thing I would like to double-check.
In cell 5 of your Jupyter notebook, you compute the pairwise squared distances between the continuous latent vectors `x_flat` and the codebook `embedding` as:
```python
torch.addmm(torch.sum(embedding ** 2, dim=1) +
            torch.sum(x_flat ** 2, dim=1),
            x_flat, embedding.t(),
            alpha=-2.0, beta=1)
```
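For reference, let $N$ be the number of latent vectors and $K$ the number of codebook entries (notation introduced here). The call is meant to produce an $N \times K$ matrix whose entry $(i, j)$ is the expanded squared distance:

$$
\|z_e(x)_i - e_j\|_2^2 = \|z_e(x)_i\|_2^2 + \|e_j\|_2^2 - 2\, z_e(x)_i^\top e_j
$$

`torch.addmm(input, mat1, mat2, beta=1, alpha=-2.0)` returns `beta * input + alpha * (mat1 @ mat2)`. So `x_flat @ embedding.t()` supplies the cross term $-2\, z_e(x)_i^\top e_j$, and `input` has to supply $\|z_e(x)_i\|_2^2 + \|e_j\|_2^2$, which depends on both $i$ and $j$.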
However, I don't think this correctly computes the distance between $z_e(x)_i$ and $e_j$ when $i \neq j$.
Let's look at the following example.
For simplicity, let's assume there are only two entries $e_1$ and $e_2$ in the codebook.
Let's set $e_1=(-0.8567, 1.1006, -1.0712)$ and $e_2=(0.1227, -0.5663, 0.3731)$.
Similarly, let's set $z_e(x)_1 = (0.4033, 0.8380, -0.7193)$ and $z_e(x)_2 = (-0.4033, -0.5966, 0.1820)$.
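```python
import torch

# Codebook with two entries: row 0 is e_1, row 1 is e_2
embedding = torch.tensor(
    [[-0.8567,  1.1006, -1.0712],
     [ 0.1227, -0.5663,  0.3731]]
)
# Two continuous latent vectors: row 0 is z_e(x)_1, row 1 is z_e(x)_2
x_flat = torch.tensor(
    [[ 0.4033,  0.8380, -0.7193],
     [-0.4033, -0.5966,  0.1820]]
)

# Distance as computed in the notebook
dist_kang = torch.addmm(torch.sum(embedding ** 2, dim=1) +
                        torch.sum(x_flat ** 2, dim=1),
                        x_flat, embedding.t(),
                        alpha=-2.0, beta=1)
# Squared Euclidean distance from torch.cdist (squaring removes the minus sign)
dist_lucidrain = (-torch.cdist(x_flat, embedding, p=2)) ** 2
```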
Your implementation (`dist_kang`) returns

```
tensor([[1.7804, 2.4136],
        [5.4872, 0.3141]])
```
The correct implementation (`dist_lucidrain`, from lucidrain) returns

```
tensor([[1.7804, 3.2441],
        [4.6566, 0.3141]])
```
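As you can see, your implementation returns the correct results for the $z_e(x)_i$, $e_j$ pairs with $i = j$ (the diagonal elements), but the results for the $i \neq j$ cases are all off.

The cause is the first argument to `torch.addmm`. `torch.sum(embedding ** 2, dim=1) + torch.sum(x_flat ** 2, dim=1)` adds two 1-D tensors element by element, which gives $\|e_j\|^2 + \|z_e(x)_j\|^2$. `addmm` then broadcasts this vector across every row. Entry $(i, j)$ therefore uses $\|z_e(x)_j\|^2$ where it should use $\|z_e(x)_i\|^2$, and the two agree only on the diagonal.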
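A dependency-free sketch can make the index mix-up visible. It uses plain Python lists instead of torch tensors to emulate what the `torch.addmm` call computes. The helper names `dot`, `addmm_as_written`, and `squared_distances` are hypothetical. The emulation also assumes the number of latent vectors equals the number of codebook entries, as in this example, because the element-by-element sum only lines up in that case.

```python
# Plain-Python emulation of the notebook's torch.addmm distance computation.
# Rows of embedding are e_1, e_2; rows of x_flat are z_e(x)_1, z_e(x)_2.
embedding = [[-0.8567, 1.1006, -1.0712],
             [0.1227, -0.5663, 0.3731]]
x_flat = [[0.4033, 0.8380, -0.7193],
          [-0.4033, -0.5966, 0.1820]]


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(ai * bi for ai, bi in zip(a, b))


def addmm_as_written(x_flat, embedding):
    """Mimic addmm(sum(e**2) + sum(z**2), x_flat, embedding.t(), alpha=-2, beta=1)."""
    # Assumes N == K, as in this example (the elementwise sum needs matching lengths).
    assert len(x_flat) == len(embedding)
    # Elementwise sum of two 1-D tensors: bias[j] = ||e_j||^2 + ||z_j||^2.
    bias = [dot(e, e) + dot(z, z) for e, z in zip(embedding, x_flat)]
    # addmm broadcasts bias across every row i, so entry (i, j) sees ||z_j||^2.
    return [[bias[j] - 2.0 * dot(z, e) for j, e in enumerate(embedding)]
            for z in x_flat]


def squared_distances(x_flat, embedding):
    """Intended result: entry (i, j) = ||z_i||^2 + ||e_j||^2 - 2 <z_i, e_j>."""
    return [[dot(z, z) + dot(e, e) - 2.0 * dot(z, e) for e in embedding]
            for z in x_flat]


buggy = addmm_as_written(x_flat, embedding)
correct = squared_distances(x_flat, embedding)
print("as written:", [[round(v, 4) for v in row] for row in buggy])
print("intended:  ", [[round(v, 4) for v in row] for row in correct])
```

Rounded to four decimals, the two printed matrices match `dist_kang` and `dist_lucidrain`. Entry by entry, the difference between them is $\|z_e(x)_j\|^2 - \|z_e(x)_i\|^2$, which is zero exactly when $i = j$.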
Here is how to verify the correct values for both cases:

```python
# e1-z1 and e2-z2 distance pairs (diagonal)
torch.sum(embedding**2 - 2*embedding*x_flat + x_flat**2, dim=1)
# tensor([1.7804, 0.3141])

# e2-z1 and e1-z2 distance pairs (off-diagonal); flip(0) swaps the two codebook rows
torch.sum(embedding.flip(0)**2 - 2*embedding.flip(0)*x_flat + x_flat**2, dim=1)
# tensor([3.2441, 4.6566])
```
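The `flip(0)` trick only covers the two-entry codebook used here. A plain-Python double loop (no torch needed) can compute $\|z_e(x)_i - e_j\|^2$ directly from the coordinate differences, for any number of latent vectors and codebook entries. It is meant as a small reference check, not a vectorized replacement, and `pairwise_squared_distances` is a hypothetical helper name.

```python
# Brute-force squared Euclidean distance for every (z_i, e_j) pair.
# Unlike the flip(0) trick, this works for any number of codebook entries.
embedding = [[-0.8567, 1.1006, -1.0712],
             [0.1227, -0.5663, 0.3731]]
x_flat = [[0.4033, 0.8380, -0.7193],
          [-0.4033, -0.5966, 0.1820]]


def pairwise_squared_distances(x_flat, embedding):
    """Return an N x K nested list whose entry (i, j) is ||z_i - e_j||^2."""
    return [[sum((zk - ek) ** 2 for zk, ek in zip(z, e)) for e in embedding]
            for z in x_flat]


dist = pairwise_squared_distances(x_flat, embedding)
for row in dist:
    print([round(v, 4) for v in row])
```

It prints `[1.7804, 3.2441]` and `[4.6566, 0.3141]`, which matches `dist_lucidrain`.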
So the correct value for the $e_2$, $z_e(x)_1$ pair should be 3.2441, and for the $e_1$, $z_e(x)_2$ pair it should be 4.6566.
Therefore, `torch.cdist` should be used to calculate the distance instead of `torch.addmm`.