1Konny / Beta-VAE
PyTorch implementation of β-VAE
License: MIT License
Hello! I noticed that your KL divergence curve stays flat after about 100k iterations. When I train other VAEs, the KL divergence converges much faster than the reconstruction loss, and training usually suffers from KL collapse: the KL term converges to a value close to zero. How do you keep the KL divergence curve flat (and non-zero) with beta > 1?
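One thing that may help, sketched here under assumptions rather than copied from the repository: the capacity-controlled objective of Burgess et al., which appears to be what the `B` scripts with `gamma` and `C_max` refer to. Instead of penalising the KL term directly, it penalises the distance between the KL term and a target capacity C that grows linearly during training, so the KL term is pulled towards a non-zero value. The function names, the `c_stop_iter` parameter, and the default values below are illustrative assumptions.

```python
def capacity(global_iter, c_max=25.0, c_stop_iter=100_000):
    """Linearly anneal the target capacity C from 0 to c_max, then hold it."""
    frac = min(max(global_iter / c_stop_iter, 0.0), 1.0)
    return c_max * frac


def capacity_loss(recon_loss, total_kld, global_iter,
                  gamma=100.0, c_max=25.0, c_stop_iter=100_000):
    """recon + gamma * |KL - C|: the penalty is zero when KL equals the target C."""
    c = capacity(global_iter, c_max, c_stop_iter)
    return recon_loss + gamma * abs(total_kld - c)


print(capacity(50_000))                    # 12.5
print(capacity_loss(1.0, 25.0, 100_000))   # 1.0
```

Because the penalty is an absolute difference, a KL value of zero is penalised once C has grown, which is the opposite of what a plain beta > 1 weight does.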
Hello.
I haven't found it in your code. Is d.metric implemented somewhere in it?
Thanks in advance
/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/model.py:150: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
init.kaiming_normal(m.weight)
unorderable types: float() > NoneType()
unorderable types: float() > NoneType()
unorderable types: float() > NoneType()
Visdom python client failed to establish socket to get messages from the server. This feature is optional and can be disabled by initializing Visdom with use_incoming_socket=False, which will prevent waiting for this request to timeout.
=> no checkpoint found at 'checkpoints/main/last'
Traceback (most recent call last):
  File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/main.py", line 69, in <module>
    main(args)
  File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/main.py", line 21, in main
    net = Solver(args)
  File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/solver.py", line 140, in __init__
    self.data_loader = return_data(args)
  File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/dataset.py", line 80, in return_data
    train_data = dset(**train_kwargs)
  File "/home/hanzy/PyPro/GANandVAE/Beta-VAE-master/dataset.py", line 18, in __init__
    super(CustomImageFolder, self).__init__(root, transform)
  File "/usr/local/lib/python3.5/dist-packages/torchvision/datasets/folder.py", line 99, in __init__
    classes, class_to_idx = find_classes(root)
  File "/usr/local/lib/python3.5/dist-packages/torchvision/datasets/folder.py", line 24, in find_classes
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
TypeError: argument should be string, bytes or integer, not PosixPath
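The last frame suggests that this torchvision version hands `root` straight to `os.listdir`, which on Python 3.5 does not accept `pathlib.Path` objects (path-like arguments are accepted from Python 3.6 onwards). A minimal workaround, assuming the dataset root is built with `pathlib` before it reaches the dataset class, is to convert it to `str` first. The helper name `as_str_path` and the example path are hypothetical.

```python
from pathlib import Path


def as_str_path(root):
    """Return root as a plain str so older libraries can pass it to os.listdir."""
    return str(root) if isinstance(root, Path) else root


root = Path("data") / "CelebA"  # illustrative path, not taken from the repository
print(type(as_str_path(root)).__name__)  # str
```

Upgrading to Python 3.6 or later should also make the conversion unnecessary.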
https://github.com/1Konny/Beta-VAE/blob/master/main.py#L26
traverse() is not a member of net.
I'm running ./run_dsprites_B_gamma100_z10.sh with a reduced number of iterations.
Error is below:
=> no checkpoint found at 'checkpoints/dsprites_B_gamma100_z10/last'
67%|██████████████████████████████████████████████████ | 10000/15000.0 [34:22<16:59, 4.91it/s]
Traceback (most recent call last):
  File "main.py", line 69, in <module>
    main(args)
  File "main.py", line 24, in main
    net.train()
  File "/Users/rlee18/git/Beta-VAE/solver.py", line 182, in train
    self.global_iter, recon_loss.data[0], total_kld.data[0], mean_kld.data[0]))
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
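The message names the fix itself: since PyTorch 0.4 a scalar loss is a 0-dim tensor, so `recon_loss.data[0]` has to become `recon_loss.item()`. As a sketch, a small duck-typed helper (the name `to_scalar` is hypothetical) that accepts either a 0-dim tensor or a plain number:

```python
def to_scalar(value):
    """Return a Python number from a 0-dim tensor (anything with .item()) or a plain number."""
    item = getattr(value, "item", None)
    return item() if callable(item) else float(value)


# In the logging call from the traceback, the arguments would then read, e.g.:
#   to_scalar(recon_loss), to_scalar(total_kld), to_scalar(mean_kld)
print(to_scalar(3))  # 3.0
```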
Hi
I just started using your code to play around with beta-VAE and it's great. Unfortunately, there is one bug when I try to turn on visualization:
Traceback (most recent call last):
  File "main.py", line 65, in <module>
    main(args)
  File "main.py", line 24, in main
    net.train()
  File "/home/tony/Github/Beta-VAE/solver.py", line 178, in train
    self.viz_lines()
  File "/home/tony/Github/Beta-VAE/solver.py", line 230, in viz_lines
    klds = torch.cat([dim_wise_klds, mean_klds, total_klds], 1).cpu()
RuntimeError: invalid argument 0: Tensors must have same number of dimensions: got 2 and 1 at /pytorch/aten/src/TH/generic/THTensorMath.c:3577
Steps to reproduce:
python -m visdom.server
sh run_dsprites_B_gamma100_z10.sh
I guess this is quite easy to fix and it doesn't appear in an earlier version (commit 66fcd41), but I'm very new to Pytorch, so it's hard for me to find a fix.
Thanks again for the code and I hope this can be fixed easily,
Tony
P.S.: Would it be possible for you to add a license so that we can use your code to benchmark other models in our research (which might get published at some point)? That would be great!
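Regarding the `torch.cat` error in the traceback above: it says that one of the tensors is 2-D and another is 1-D. A possible fix, assuming `dim_wise_klds` has shape (N, z_dim) while `mean_klds` and `total_klds` have shape (N,), is to turn the 1-D tensors into columns before concatenating. The shapes and the helper name `as_column` are assumptions for illustration.

```python
import torch


def as_column(t):
    """Make a 1-D tensor of length N into an (N, 1) column; leave 2-D tensors alone."""
    return t.unsqueeze(1) if t.dim() == 1 else t


# Assumed shapes: N = 5 logged steps, z_dim = 10.
dim_wise_klds = torch.zeros(5, 10)
mean_klds = torch.zeros(5)
total_klds = torch.zeros(5)

klds = torch.cat([as_column(dim_wise_klds),
                  as_column(mean_klds),
                  as_column(total_klds)], 1)
print(klds.shape)  # torch.Size([5, 12])
```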
Hi,
Could you please explain why the encoder's output has size z_dim * 2 and not just z_dim?
Thanks for the very clean repo.
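Judging from the `forward` method in the model code quoted further down this page, the encoder output is sliced into two halves: the first z_dim entries are used as the mean and the remaining z_dim entries as the log-variance of the approximate posterior, and z is then sampled with the reparameterization trick. A dependency-free sketch of that split for a single sample (the function name `split_and_sample` is hypothetical):

```python
import math
import random


def split_and_sample(encoder_out, z_dim, rng=random):
    """Split a length-2*z_dim vector into (mu, logvar) and draw z = mu + std * eps."""
    mu, logvar = encoder_out[:z_dim], encoder_out[z_dim:]
    std = [math.exp(lv / 2) for lv in logvar]            # std = exp(logvar / 2)
    z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, std)]
    return mu, logvar, z


mu, logvar, z = split_and_sample([0.5, -1.0, 0.0, 0.0], z_dim=2)
print(mu, logvar)  # [0.5, -1.0] [0.0, 0.0]
```

So each latent dimension needs two numbers from the encoder, which is where the factor of 2 comes from.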
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/visdom/__init__.py", line 708, in _send
    return self._handle_post(
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/visdom/__init__.py", line 677, in _handle_post
    r = self.session.post(url, data=data)
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/sessions.py", line 590, in post
    return self.request('POST', url, data=data, json=json, **kwargs)
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/sessions.py", line 542, in request
    resp = self.send(prep, **send_kwargs)
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/sessions.py", line 655, in send
    r = adapter.send(request, **kwargs)
  File "/home/ye/anaconda3/envs/Pytorch/lib/python3.8/site-packages/requests/adapters.py", line 516, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='localhost', port=8097): Max retries exceeded with url: /events (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f23720d29d0>: Failed to establish a new connection: [Errno 111] Connection refused'))
Traceback (most recent call last):
  File "main.py", line 69, in <module>
    main(args)
  File "main.py", line 24, in main
    net.train()
  File "/home/ye/python_scripts/Beta-VAE-master/solver.py", line 201, in train
    self.viz_traverse()
  File "/home/ye/python_scripts/Beta-VAE-master/solver.py", line 417, in viz_traverse
    save_image(tensor=gifs[i][j].cpu(),
TypeError: save_image() got an unexpected keyword argument 'filename'
1%|▏ | 10000/1500000.0 [03:54<9:41:25, 42.71it/s]
Hello, I would like to know whether there is a specific reason the image size is fixed. If not, do you know what modifications would be required to support, say, 512×512 images?
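For what it's worth, the BetaVAE_H encoder quoted further down this page ends with a 4×4 convolution without padding followed by `View((-1, 256))`, which only works if the feature map is exactly 4×4 at that point. The sketch below traces the spatial size through those five convolutions using the standard Conv2d output-size formula; the function names are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of the five encoder convolutions in BetaVAE_H.
ENCODER = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]


def trace(size):
    """Return the spatial size after each encoder convolution."""
    sizes = [size]
    for kernel, stride, padding in ENCODER:
        sizes.append(conv_out(sizes[-1], kernel, stride, padding))
    return sizes


print(trace(64))   # [64, 32, 16, 8, 4, 1]
print(trace(512))  # [512, 256, 128, 64, 32, 29]
```

With 512×512 inputs the last feature map is 29×29 rather than 1×1, so the flattening step no longer matches. One way around it, as an untested suggestion, would be three extra stride-2 convolutions in the encoder (512 / 2^7 = 4) and three mirrored transposed convolutions in the decoder.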
Hi, I think line 309 in solver.py should be Y=mus.
Why is the disentanglement so poor when beta is set to 4? The MIG score is only 0.06. When beta is set to 8, the result is roughly what others report with beta set to 4.
Hi,
I reran your implementation on my machine, but it does not seem to work for me. Can you please suggest how I can solve this?
No error is shown, but there is no output either at http://localhost:8097/ or in the folder D:...\Beta-VAE-master\outputs\celeba_H_beta10_z10
Hello, I tried to extract your BetaVAE_H model and loss function and then trained the model on CIFAR-10. But after 10000 epochs of training, the quality of recon_img is still very poor. Is there anything I have not considered? Please help me. The code I use is as follows:
`from torch import nn
from torch.nn import init
from torch.autograd import Variable


def reparametrize(mu, logvar):
    # z = mu + std * eps with eps ~ N(0, I)
    std = logvar.div(2).exp()
    eps = Variable(std.data.new(std.size()).normal_())
    return mu + std * eps


def kaiming_init(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)  # kaiming_normal is deprecated
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)


class View(nn.Module):
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size

    def forward(self, tensor):
        return tensor.view(self.size)


class BetaVAE_H(nn.Module):
    """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017)."""

    def __init__(self, z_dim=10, nc=3):
        super(BetaVAE_H, self).__init__()
        self.z_dim = z_dim
        self.nc = nc
        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),          # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),          # B,  64,  4,  4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),            # B, 256,  1,  1
            nn.ReLU(True),
            View((-1, 256 * 1 * 1)),             # B, 256
            nn.Linear(256, z_dim * 2),           # B, z_dim*2
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),               # B, 256
            View((-1, 256, 1, 1)),               # B, 256,  1,  1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),      # B,  64,  4,  4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1), # B,  64,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1), # B,  nc, 64, 64
        )
        self.weight_init()

    def weight_init(self):
        for block in self._modules:
            for m in self._modules[block]:
                kaiming_init(m)

    def forward(self, x):
        distributions = self._encode(x)
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z)
        return x_recon, mu, logvar

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)


# --- training script; the model above is imported from beta_vae.py ---
import torch
from torch import optim
from torch.utils.data import DataLoader
from beta_vae import BetaVAE_H
import torch.nn.functional as F
from torchvision import datasets, transforms


def recon_loss(x, x_recon):
    x_recon = torch.sigmoid(x_recon)  # F.sigmoid is deprecated
    rec_loss = F.mse_loss(x_recon, x)  # default reduction: mean over all elements
    return rec_loss


def kld_loss(mu, logvar):
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)  # sum over latent dims, mean over batch
    return total_kld


def train(epochs=1000, batch_size=128, z_dim=32, device='cuda:2', lr=1e-4, beta=10):
    dataset = datasets.CIFAR10(root='../dataset/cifar10', train=True, transform=transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor()
    ]), download=True)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    model = BetaVAE_H(z_dim=z_dim, nc=3).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        loss, r_loss, k_loss = (0, 0, 0)
        for idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            x_recon, mu, logvar = model(images)
            rec_loss = recon_loss(images, x_recon)
            kl_loss = kld_loss(mu, logvar)
            beta_vae_loss = rec_loss + beta * kl_loss
            optimizer.zero_grad()
            beta_vae_loss.backward()
            loss += beta_vae_loss.item()
            r_loss += rec_loss.item()
            k_loss += kl_loss.item()
            optimizer.step()
`
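One thing that may matter here, offered as a guess rather than a diagnosis: `F.mse_loss` averages over every element by default, while `kld_loss` sums over the latent dimensions. Compared with the usual per-image summed reconstruction error, averaging over the C×H×W pixels shrinks the reconstruction term by that factor, which is the same as multiplying beta by it. The arithmetic for the settings in the posted code (the function name `effective_beta` is hypothetical):

```python
def effective_beta(beta, channels, height, width):
    """beta that a sum-over-pixels reconstruction loss would need in order to match
    a mean-over-pixels reconstruction loss paired with the given beta."""
    return beta * channels * height * width


# beta=10 on 3x64x64 images, as in the posted train() defaults
print(effective_beta(10, 3, 64, 64))  # 122880
```

With a KL weight that large, near-collapsed latents and blurry reconstructions would not be surprising. Things to try: `F.mse_loss(x_recon, x, reduction='sum')` divided by the batch size, or a much smaller beta.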
Hi, a minor mistake: in some of the .sh scripts, "beta2" (Adam) appears as "beta". This might cause some confusion.
While reading solver.py, I found that lines 431 and 433 call self.net.train() and self.net.eval(), but I could not find either function defined in model.py or in the PyTorch API. However, no error came up when I ran the training. Can anyone enlighten me?
Thanks a bunch
def net_mode(self, train):
    if not isinstance(train, bool):
        raise('Only bool type is supported. True or False')
    if train:
        self.net.train()  # <- the call in question
    else:
        self.net.eval()   # <- the call in question
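Both methods are inherited: any model that subclasses `torch.nn.Module` gets `train()` and `eval()` from the base class, where they toggle the `training` flag on the module and its children. A minimal sketch with a hypothetical `TinyNet` that defines neither method:

```python
from torch import nn


class TinyNet(nn.Module):  # hypothetical model; defines neither train() nor eval()
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


net = TinyNet()
net.eval()
print(net.training)  # False
net.train()
print(net.training)  # True
```

The flag matters for layers such as dropout and batch normalization, which behave differently in the two modes.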
Hi WonKwang,
First, thanks for the great implementation for beta-vae.
Is there any method or intuition for choosing the beta, gamma, and C_max hyperparameters?
Thanks