Comments (2)
Hi! I've been working with β-VAE and the controlled-capacity β-VAE (the one you're referencing) quite a bit lately, and I found myself asking this exact question.
I found this repository especially helpful in figuring out the parameters. Long story short: the reconstruction loss should be summed over each image, and that sum then averaged across the batch. The implementation in this repository computes the pixel-wise average instead, which doesn't work well with the intended parameters. Modifying the function gives a loss scale that aligns with the paper Understanding disentangling in β-VAE (Burgess et al., arXiv:1804.03599, 2018), and the default values work much better out of the box.
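To make the scale difference concrete, here is a small pure-Python sketch (no PyTorch) comparing the two reductions on a toy batch. The helper names pixel_mean_mse and per_image_sum_mse are hypothetical; they are meant to mirror F.mse_loss(recons, input) and F.mse_loss(recons, input, reduction="sum").div(input.shape[0]) respectively, with each image flattened into a list of pixel values.

```python
from typing import List

Batch = List[List[float]]  # each inner list is one flattened image (C*H*W values)


def _total_squared_error(recons: Batch, inputs: Batch) -> float:
    """Sum of squared differences over every pixel of every image in the batch."""
    return sum(
        (r - x) ** 2
        for r_img, x_img in zip(recons, inputs)
        for r, x in zip(r_img, x_img)
    )


def pixel_mean_mse(recons: Batch, inputs: Batch) -> float:
    """Average over all pixels in the batch (like the default reduction="mean")."""
    n_values = sum(len(img) for img in inputs)
    return _total_squared_error(recons, inputs) / n_values


def per_image_sum_mse(recons: Batch, inputs: Batch) -> float:
    """Sum within each image, then average over the batch (the modified loss)."""
    return _total_squared_error(recons, inputs) / len(inputs)


if __name__ == "__main__":
    # Toy batch: 2 "images" of 4 pixels each.
    inputs = [[0.0, 0.5, 1.0, 0.25], [0.2, 0.4, 0.6, 0.8]]
    recons = [[0.1, 0.5, 0.9, 0.25], [0.2, 0.2, 0.6, 0.8]]
    mean_loss = pixel_mean_mse(recons, inputs)
    summed_loss = per_image_sum_mse(recons, inputs)
    # With equal-sized images, the ratio is the number of values per image (4 here).
    print(mean_loss, summed_loss, summed_loss / mean_loss)
```

For a 3-channel 64×64 image, for example, the factor would be 3 · 64 · 64 = 12,288, which is why hyperparameters tuned against the summed loss behave very differently with the pixel-wise mean.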
Modified function:
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    # Assumes the module-level imports: import torch / from torch.nn import functional as F
    self.num_iter += 1
    recons = args[0]
    input = args[1]
    mu = args[2]
    log_var = args[3]
    kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

    # original: recons_loss = F.mse_loss(recons, input)
    # modified: sum the squared error over each image, then average over the batch
    recons_loss = F.mse_loss(recons, input, reduction="sum").div(input.shape[0])

    # KL divergence to N(0, I): summed over latent dimensions, averaged over the batch
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

    if self.loss_type == 'H':  # https://openreview.net/forum?id=Sy2fzU9gl
        loss = recons_loss + self.beta * kld_weight * kld_loss
    elif self.loss_type == 'B':  # https://arxiv.org/pdf/1804.03599.pdf
        self.C_max = self.C_max.to(input.device)
        # Capacity C grows linearly with the iteration count until it reaches C_max
        C = torch.clamp(self.C_max / self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
        loss = recons_loss + self.gamma * kld_weight * (kld_loss - C).abs()
    else:
        raise ValueError('Undefined loss type.')

    return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': kld_loss}
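For reference, the 'B' branch of this function can be written out as the formulas below. The notation is my own transcription of the code, not taken from the repository: w stands for kwargs['M_N'], t for self.num_iter, T_stop for self.C_stop_iter, B for the batch size and d for the latent size, with log σ² = log_var. The objective in Burgess et al. has the same |KL − C| form without the extra w factor, and the lower clamp at 0 never activates since t ≥ 1 and C_max ≥ 0.

$$
\mathcal{L} = \mathcal{L}_{\text{recon}} + \gamma\, w\,\bigl|\,D_{\mathrm{KL}} - C(t)\,\bigr|,
\qquad
C(t) = \min\!\Bigl(C_{\max},\ \frac{C_{\max}}{T_{\text{stop}}}\, t\Bigr)
$$

$$
\mathcal{L}_{\text{recon}} = \frac{1}{B}\sum_{b=1}^{B}\bigl\lVert x_b - \hat{x}_b \bigr\rVert_2^2,
\qquad
D_{\mathrm{KL}} = \frac{1}{B}\sum_{b=1}^{B} -\frac{1}{2}\sum_{j=1}^{d}\bigl(1 + \log\sigma_{bj}^{2} - \mu_{bj}^{2} - \sigma_{bj}^{2}\bigr)
$$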
Finally, as for good gamma and max_capacity values: depending on your application, you can get better reconstructions by increasing the capacity available to the model, but you'll also end up with a less regularized latent space, which may be undesirable for what you're aiming for. The KL divergence is summed across the latent dimensions, so with a latent space of size 64 and a capacity of 32 (nats), each latent dimension could diverge from the prior by 0.5 on average. Alternatively, a few dimensions may diverge strongly and take up most of the available capacity. I'd recommend starting with max_capacity values between 1/8 and 2x your latent size, then visualizing how your losses differ and what effect that has on your reconstructions.
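A minimal pure-Python sketch of the capacity arithmetic above, assuming a diagonal Gaussian posterior and a standard normal prior (the same KL term as kld_loss, before batch averaging). The function names kl_per_dim and avg_budget_per_dim are hypothetical, chosen for illustration only.

```python
import math
from typing import List


def kl_per_dim(mu: List[float], log_var: List[float]) -> List[float]:
    """KL(N(mu_j, exp(log_var_j)) || N(0, 1)) for each latent dimension, in nats."""
    return [-0.5 * (1 + lv - m ** 2 - math.exp(lv)) for m, lv in zip(mu, log_var)]


def avg_budget_per_dim(capacity: float, latent_dim: int) -> float:
    """Average KL each dimension could use if the capacity were spread evenly."""
    return capacity / latent_dim


if __name__ == "__main__":
    latent_dim, capacity = 64, 32.0
    print(avg_budget_per_dim(capacity, latent_dim))  # 0.5 nats per dimension on average

    # A single dimension with mu = 2.0 (and unit variance) costs mu^2 / 2 = 2.0 nats,
    # i.e. four of the 0.5-nat average shares; the other 63 dims match the prior.
    kl = kl_per_dim([2.0] + [0.0] * (latent_dim - 1), [0.0] * latent_dim)
    print(sum(kl))
```

Because the total is a plain sum over dimensions, the capacity says nothing about how the KL is distributed: one strongly used dimension and many evenly used ones can hit the same total.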
Gamma, on the other hand, is dataset-dependent. You can run an experiment with a value of e.g. 1000, work out what percentage of your loss comes from the KLD capacity term, and tune accordingly. If you're working with 3-channel images that have been scaled correctly to suit the model input, then 1000 to 10000 may be a good starting range.
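One way to do that bookkeeping is sketched below, using the 'B' loss as written in the function above. The helper capacity_term_share is hypothetical, and the example numbers are made up for illustration, not results from any run. If you keep the kld_weight factor from the code, pass your M_N value as kld_weight.

```python
def capacity_term_share(
    recons_loss: float,
    kld_loss: float,
    capacity: float,
    gamma: float,
    kld_weight: float = 1.0,
) -> float:
    """Fraction of the total 'B' loss contributed by gamma * kld_weight * |KLD - C|."""
    capacity_term = gamma * kld_weight * abs(kld_loss - capacity)
    total = recons_loss + capacity_term
    return capacity_term / total if total > 0 else 0.0


if __name__ == "__main__":
    # Hypothetical values from one training step (illustrative only).
    share = capacity_term_share(recons_loss=2000.0, kld_loss=30.0, capacity=25.0, gamma=1000.0)
    print(f"capacity term is {share:.0%} of the loss")
```

If the share stays close to 1 for most of training, gamma is probably dominating the reconstruction term; if it stays near 0 while the KLD drifts away from C, gamma is likely too small to enforce the capacity.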
Very verbose reply but I hope that helps.
from pytorch-vae.
@mjdall Could you explain the self.loss_type == 'B' loss, and what the clamp, C_max, and C_stop_iter parameters correspond to in the paper?
On another note: taking a random input and recon, the modified reconstruction loss you shared comes out at ~2200, whereas the KLD is only around ~4. Even if the
Any thoughts on that?
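A back-of-envelope estimate that may help put the ~2200 in context, under assumptions not stated in the comment: input and recon are independent tensors drawn uniformly from [0, 1] (as torch.rand would produce), and each image has shape 3×64×64. Since both have the same mean, the expected squared difference per pixel is just the sum of the variances:

$$
\mathbb{E}\bigl[(X - Y)^2\bigr] = \operatorname{Var}(X) + \operatorname{Var}(Y) = \tfrac{1}{12} + \tfrac{1}{12} = \tfrac{1}{6},
\qquad
\mathbb{E}\bigl[\mathcal{L}_{\text{recon}}\bigr] = 3 \cdot 64 \cdot 64 \cdot \tfrac{1}{6} = 2048
$$

Under those assumptions, a summed reconstruction loss in the low thousands is expected from noise alone, while the KLD depends only on mu and log_var and does not grow with the image size.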