beta-vae's Issues

How to prevent KL loss collapse

Hello! I noticed that your KL divergence curve is flat after about 100k iterations. When I train other VAE tasks, I find that the KL divergence converges much more easily than the reconstruction loss, and training usually suffers from KL collapse: the KL term converges to a value close to zero. How do you keep the KL divergence curve flat, rather than collapsing, with beta > 1?
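
One commonly tried mitigation for this kind of KL collapse is to warm up the KL weight: start it at zero and ramp it linearly to the target beta, so the decoder can learn to use the latent code before the KL penalty reaches full strength. The sketch below only illustrates that idea. It is not taken from this repository, and `kl_weight`, `warmup_steps`, and `beta_max` are hypothetical names.

```python
def kl_weight(step, warmup_steps=10000, beta_max=4.0):
    """Linearly ramp the KL weight from 0 to beta_max over warmup_steps.

    All names and default values here are illustrative assumptions.
    """
    if warmup_steps <= 0:
        return beta_max
    # Fraction of the warm-up completed, capped at 1.0 once it is over.
    return beta_max * min(1.0, step / warmup_steps)
```

In a training loop the loss would then be something like recon_loss + kl_weight(step) * kld, where step is the global iteration count.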

Disentanglement metric


I haven't found it in your code. Is the disentanglement metric implemented somewhere?

Thanks in advance

Cannot run: TypeError

/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/ UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
unorderable types: float() > NoneType()
unorderable types: float() > NoneType()
unorderable types: float() > NoneType()
Visdom python client failed to establish socket to get messages from the server. This feature is optional and can be disabled by initializing Visdom with use_incoming_socket=False, which will prevent waiting for this request to timeout.
=> no checkpoint found at 'checkpoints/main/last'
Traceback (most recent call last):
File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/", line 69, in <module>
File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/", line 21, in main
net = Solver(args)
File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/", line 140, in __init__
self.data_loader = return_data(args)
File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/", line 80, in return_data
train_data = dset(**train_kwargs)
File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/", line 18, in __init__
super(CustomImageFolder, self).__init__(root, transform)
File "/usr/local/lib/python3.5/dist-packages/torchvision/datasets/", line 99, in __init__
classes, class_to_idx = find_classes(root)
File "/usr/local/lib/python3.5/dist-packages/torchvision/datasets/", line 24, in find_classes
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
TypeError: argument should be string, bytes or integer, not PosixPath
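
The last frame suggests that `os.listdir` received a `pathlib.Path`. On Python 3.5, which this traceback shows, `os.listdir` accepts only str, bytes, or an integer file descriptor; support for path objects arrived in Python 3.6. A possible workaround, sketched here with a hypothetical helper `list_class_dirs` that mirrors the failing line, is to convert the root to `str` before it reaches `os.listdir` (upgrading Python would be another option).

```python
import os
from pathlib import Path


def list_class_dirs(root):
    """Return the sorted sub-directory names of root.

    Hypothetical helper mirroring the failing find_classes line.
    """
    # Converting to str keeps os.listdir happy on Python 3.5,
    # whether root arrives as a str or as a pathlib.Path.
    root = str(root)
    return sorted(
        d for d in os.listdir(root)
        if os.path.isdir(os.path.join(root, d))
    )
```

The same conversion could instead be applied where the dataset root is built, for example by wrapping it in str() before it is handed to the dataset class.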

Training on dsprites fails with 0-dim tensor

I'm running ./ with a reduced number of iterations.
Error is below:

=> no checkpoint found at 'checkpoints/dsprites_B_gamma100_z10/last'
 67%|██████████████████████████████████████████████████                         | 10000/15000.0 [34:22<16:59,  4.91it/s]Traceback (most recent call last):
  File "", line 69, in <module>
  File "", line 24, in main
  File "/Users/rlee18/git/Beta-VAE/", line 182, in train
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

Tensor dimension bug in visualization code

I just started using your code to play around with beta-VAE and it's great. Unfortunately, there is one bug when I try to turn on visualization:

Traceback (most recent call last):
File "", line 65, in <module>
File "", line 24, in main
File "/home/tony/Github/Beta-VAE/", line 178, in train
File "/home/tony/Github/Beta-VAE/", line 230, in viz_lines
klds = torch.cat([dim_wise_klds, mean_klds, total_klds], 1).cpu()
RuntimeError: invalid argument 0: Tensors must have same number of dimensions: got 2 and 1 at /pytorch/aten/src/TH/generic/THTensorMath.c:3577

Steps to reproduce:
python -m visdom.server

I guess this is quite easy to fix and it doesn't appear in an earlier version (commit 66fcd41), but I'm very new to Pytorch, so it's hard for me to find a fix.

Thanks again for the code and I hope this can be fixed easily,

P.S.: Would it be possible for you to add a license so that we can use your code to benchmark other models in our research (which might get published at some point)? That would be great!

Encoder output is 2 * z_dim

Could you please explain why the encoder's output is z_dim * 2 and not just z_dim?

Thanks for the very clean repo.
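
A likely reading, based on the `forward` method quoted in another issue on this page, is that the encoder predicts two values per latent dimension: the first `z_dim` outputs are the mean and the last `z_dim` are the log-variance of the approximate posterior. The pure-Python sketch below illustrates that split and the sampling step. `split_and_sample` is a hypothetical name, and plain lists stand in for tensors.

```python
import math
import random


def split_and_sample(encoder_out, z_dim, rng=random):
    """Split a 2*z_dim encoder output into (mu, logvar) and sample z.

    Illustrative sketch only; lists of floats stand in for tensors.
    """
    mu = encoder_out[:z_dim]        # first half: posterior means
    logvar = encoder_out[z_dim:]    # second half: log-variances
    # Reparameterization: z = mu + std * eps, with eps ~ N(0, 1)
    # and std = exp(logvar / 2).
    z = [m + math.exp(lv / 2) * rng.gauss(0.0, 1.0)
         for m, lv in zip(mu, logvar)]
    return mu, logvar, z
```

The decoder then only needs `z_dim` inputs, because it consumes the sampled z rather than the two distribution parameters.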

TypeError: save_image() got an unexpected keyword argument 'filename'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/visdom/", line 708, in _send
return self._handle_post(
File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/visdom/", line 677, in _handle_post
r =, data=data)
File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/", line 590, in post
return self.request('POST', url, data=data, json=json, **kwargs)
File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/", line 542, in request
resp = self.send(prep, **send_kwargs)
File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/", line 655, in send
r = adapter.send(request, **kwargs)
File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/", line 516, in send
raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /events (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f23720d29d0>: Failed to establish a new connection: [Errno 111] Connection refused'))
Traceback (most recent call last):
File "", line 69, in <module>
File "", line 24, in main
File "/home/ye/python_scripts/Beta-VAE-master/", line 201, in train
File "/home/ye/python_scripts/Beta-VAE-master/", line 417, in viz_traverse
TypeError: save_image() got an unexpected keyword argument 'filename'
1%|▏ | 10000/1500000.0 [03:54<9:41:25, 42.71it/s]

Loss curve

I am training on a custom dataset with model H and z_dim set to 256. After 100,000 training steps, the loss is still as high as 400. Is this normal? Could the high loss be caused by z_dim being too large?

Here is my loss curve:
[Screenshot 2023-03-03 15-46-08]

And these are my parameters:
[Screenshot 2023-03-03 15-58-32]

Supporting image size larger than 64*64

Hello, I would like to know whether there is a specific reason the image size is fixed. If not, do you know what modifications would be required to support, say, 512*512 images?
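
One way to see where the 64*64 assumption comes from is to trace the spatial size through the encoder layers as they are quoted in another issue on this page: four convolutions with kernel 4, stride 2, padding 1, then one with kernel 4, stride 1, no padding, followed by a reshape that expects a 1*1 feature map. The sketch below applies the standard convolution output-size formula to that layer list. The layer list is an assumption taken from the quoted code, and `conv_out` and `encoder_spatial_size` are hypothetical helper names.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Standard output-size formula for a 2D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_spatial_size(size):
    """Spatial size after the encoder convolutions quoted on this page."""
    for _ in range(4):
        size = conv_out(size, kernel=4, stride=2, padding=1)  # halves the size
    return conv_out(size, kernel=4, stride=1, padding=0)      # 4 -> 1 for 64 input


print(encoder_spatial_size(64))   # 1
print(encoder_spatial_size(512))  # 29
```

Under these assumptions a 512*512 input leaves a 29*29 map where the reshape expects 1*1, so the architecture itself would need changing. One option might be three additional stride-2 layers in the encoder, since 512 / 2^7 = 4, with matching transposed convolutions in the decoder.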

About disentanglement

Why is disentanglement so poor when beta is set to 4? The MIG score is only 0.06. When beta is set to 8, the result is about the same as what others report with beta set to 4.

Need Help on Visdom

I reran your implementation on my machine, but it does not seem to work for me. Could you suggest how I can solve this? Here is what I did:

  1. I started with the CelebA dataset.
  2. I extracted the dataset to the folder D:...
  3. I deleted most of the photos to keep the size manageable for a test run (about 40 photos are left).
  4. I installed visdom, torch, and torchvision.
  5. I started Visdom (python -m visdom.server).
  6. I opened http://localhost:8097/ in my browser.
  7. I ran: python --dataset celeba --seed 1 --lr 1e-4 --beta1 0.9 --beta2 0.999 --objective H --model H --batch_size 64 --z_dim 10 --max_iter 1.5e3 --beta 10 --viz_name celeba_H_beta10_z10

No error is shown, but there is no output either at http://localhost:8097/ or in the folder D:...\Beta-VAE-master\outputs\celeba_H_beta10_z10.

About the quality of recon_img

Hello, I extracted your BetaVAE_H model and loss function and trained the model on CIFAR-10. After 10,000 epochs of training, the quality of recon_img is still very poor. Is there anything I have overlooked? Please help me. The code I use is listed below:
`from torch import nn
from torch.nn import init
from torch.autograd import Variable


def reparametrize(mu, logvar):
    # Reparameterization trick: z = mu + std * eps, with eps ~ N(0, I)
    std = logvar.div(2).exp()
    eps = Variable(std.data.new(std.size()).normal_())
    return mu + std * eps


def kaiming_init(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)


class View(nn.Module):
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size

    def forward(self, tensor):
        return tensor.view(self.size)


class BetaVAE_H(nn.Module):
    """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017)."""

    def __init__(self, z_dim=10, nc=3):
        super(BetaVAE_H, self).__init__()
        self.z_dim = z_dim
        self.nc = nc
        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),  # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),  # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),  # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),  # B,  64,  4,  4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),  # B, 256,  1,  1
            nn.ReLU(True),
            View((-1, 256 * 1 * 1)),  # B, 256
            nn.Linear(256, z_dim * 2),  # B, z_dim*2
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),  # B, 256
            View((-1, 256, 1, 1)),  # B, 256,  1,  1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),  # B,  64,  4,  4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1),  # B,  64,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1),  # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1),  # B, nc, 64, 64
        )

        self.weight_init()

    def weight_init(self):
        for block in self._modules:
            for m in self._modules[block]:
                kaiming_init(m)

    def forward(self, x):
        # The encoder output holds mu in the first half and logvar in the second.
        distributions = self._encode(x)
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z)

        return x_recon, mu, logvar

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)


import torch
from torch import optim
from torch.utils.data import DataLoader

from beta_vae import BetaVAE_H
import torch.nn.functional as F
from torchvision import datasets, transforms


def recon_loss(x, x_recon):
    x_recon = F.sigmoid(x_recon)
    rec_loss = F.mse_loss(x_recon, x)
    return rec_loss


def kld_loss(mu, logvar):
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    # KL divergence per latent dimension, summed over dimensions, averaged over the batch
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    return total_kld


def train(epochs=1000, batch_size=128, z_dim=32, device='cuda:2', lr=1e-4, beta=10):
    dataset = datasets.CIFAR10(root='../dataset/cifar10', train=True, transform=transforms.Compose([
    ]), download=True)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    model = BetaVAE_H(z_dim=z_dim, nc=3).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        loss, r_loss, k_loss = (0, 0, 0)
        for idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            x_recon, mu, logvar = model(images)
            rec_loss = recon_loss(images, x_recon)
            kl_loss = kld_loss(mu, logvar)

            beta_vae_loss = rec_loss + beta * kl_loss
            loss += beta_vae_loss.item()
            r_loss += rec_loss.item()
            k_loss += kl_loss.item()
`

Is is a valid function?

While reading this code, I found that the two functions called on line 431 and line 433 appear to be undefined: I could find them defined neither in the code here nor in the PyTorch API. However, no error came up when I ran the training. Can anyone enlighten me?
Thanks a bunch

def net_mode(self, train):
        if not isinstance(train, bool):
            raise('Only bool type is supported. True or False')

        if train:

How to tune hyper-parameters

Hi WonKwang,

First, thanks for the great implementation of beta-VAE.

Is there any method or intuition to choose these beta, gamma, C_max hyper-parameters?
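
For context on what these hyper-parameters control, the two objectives usually associated with them can be written in a few lines. The sketch below assumes objective H is the reconstruction loss plus beta times the KL term, and objective B is the reconstruction loss plus gamma times the distance between the KL term and a capacity C that grows linearly to C_max. The function names and the `c_stop_iter` parameter (the step at which C reaches C_max) are assumptions made for illustration, not a confirmed description of this repository.

```python
def objective_h(recon, kld, beta):
    """beta-VAE objective: a larger beta puts more pressure on the KL term."""
    return recon + beta * kld


def objective_b(recon, kld, gamma, c_max, step, c_stop_iter):
    """Capacity-controlled objective (sketch under the assumptions above).

    The target capacity c rises linearly from 0 to c_max over
    c_stop_iter steps; gamma penalizes any deviation of the KL from c.
    """
    c = min(c_max, c_max * step / c_stop_iter)
    return recon + gamma * abs(kld - c)
```

Read this way, beta trades reconstruction quality against KL pressure, C_max sets how many nats the latent code may eventually carry, and gamma sets how strictly the KL term is held at the current capacity.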


