
Comments (2)

mjdall commented on May 21, 2024

Hi! I've been working with the β-VAE and the controlled-capacity β-VAE (the one you're referencing) quite a bit lately, and I found myself asking this exact question.

I found this repository especially helpful in figuring out the parameters. Long story short: in the paper, reconstruction loss is summed over each image and then averaged across the batch. The implementation in this repository computes the pixel-wise average instead, which doesn't work well with the intended parameters. Modifying the function gives a scale that aligns properly with the Understanding disentangling in β-VAE paper (Burgess et al., arXiv:1804.03599, 2018), and the default values work much better out of the box.
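To see the scale difference concretely, here's a small numpy sketch (not the repository's torch code; shapes and values are made up for illustration): a pixel-wise mean averages over every element, while summing per image and dividing by the batch size yields a value C·H·W times larger for a batch of shape (N, C, H, W).

```python
import numpy as np

# Hypothetical batch of 4 RGB 64x64 "images" (random values, illustration only).
rng = np.random.default_rng(0)
recons = rng.random((4, 3, 64, 64))
target = rng.random((4, 3, 64, 64))

sq_err = (recons - target) ** 2

pixel_mean = sq_err.mean()                       # like F.mse_loss(..., reduction="mean")
sum_over_batch = sq_err.sum() / recons.shape[0]  # sum, then average over the batch

# The two differ by exactly the number of elements per image: 3 * 64 * 64 = 12288.
ratio = sum_over_batch / pixel_mean
print(ratio)
```

This is why parameters like `C_max` tuned against the summed loss behave badly when the code silently uses the pixel-wise mean: the reconstruction term shrinks by four orders of magnitude relative to the KL term.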

Modified function:

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        self.num_iter += 1
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        # original: recons_loss = F.mse_loss(recons, input)
        # modified:
        recons_loss = F.mse_loss(recons, input, reduction="sum").div(input.shape[0])

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        if self.loss_type == 'H': # https://openreview.net/forum?id=Sy2fzU9gl
            loss = recons_loss + self.beta * kld_weight * kld_loss
        elif self.loss_type == 'B': # https://arxiv.org/pdf/1804.03599.pdf
            self.C_max = self.C_max.to(input.device)
            C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
            loss = recons_loss + self.gamma * kld_weight * (kld_loss - C).abs()
        else:
            raise ValueError('Undefined loss type.')

        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}

Finally, as for good gamma and max_capacity values: depending on your application, you can get better reconstructions by increasing the capacity available to the model, but you'll also end up with a less regularized latent space, which may be undesirable for your goal. The KL capacity is summed across the latent dimensions, so if you have a latent space of size 64 and a capacity of 32, each of your latent distributions can diverge by 0.5 on average; alternatively, a few dimensions may diverge greatly to take up the available capacity. I'd recommend starting with values from 1/8x to 2x your latent size, visualizing how your losses differ, and checking what effect it has on your reconstructions.
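The per-dimension budget arithmetic above, together with the linear capacity ramp used in the 'B' branch of the loss, can be sketched in plain Python (all numbers are assumed for illustration; `C_stop_iter` here is just an example value):

```python
latent_dim = 64
C_max = 32.0          # total KL capacity in nats (max_capacity)
C_stop_iter = 10_000  # iteration at which the capacity stops growing

# Average divergence budget per latent dimension once capacity is saturated.
per_dim_budget = C_max / latent_dim  # 0.5 nats per dimension

def capacity(num_iter: int) -> float:
    """Linearly ramp the KL target C from 0 up to C_max over C_stop_iter steps,
    mirroring torch.clamp(C_max / C_stop_iter * num_iter, 0, C_max)."""
    return min(max(C_max / C_stop_iter * num_iter, 0.0), C_max)

print(per_dim_budget)    # 0.5
print(capacity(5_000))   # 16.0 (halfway through the ramp)
print(capacity(50_000))  # 32.0 (clamped at C_max)
```

The `(kld_loss - C).abs()` term then pressures the aggregate KL toward this slowly growing target rather than toward zero.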

Gamma, on the other hand, is dataset dependent. You can run an experiment with a value of, say, 1000, work out what percentage of your loss corresponds to the KL capacity component, and tune accordingly. If you're working with 3-channel images that have been scaled correctly for the model input, then 1000 to 10000 may be a good starting range.
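That tuning procedure is back-of-the-envelope arithmetic; a sketch with assumed pilot-run numbers (loosely based on the magnitudes discussed below):

```python
# Suppose a pilot run at gamma = 1000 gives these per-batch loss components.
recons_loss = 2200.0
kld_gap = 4.0        # |KLD - C| at this point in training
gamma = 1000.0

kld_term = gamma * kld_gap
total = recons_loss + kld_term
frac = kld_term / total
print(f"KL capacity term is {frac:.0%} of the loss")

# To target a desired fraction f of the total loss, solve
#   gamma * kld_gap = f / (1 - f) * recons_loss
f = 0.5
gamma_for_half = f / (1 - f) * recons_loss / kld_gap
print(gamma_for_half)  # 550.0
```

With these assumed numbers the KL term already dominates (~65% of the loss), so you'd lower gamma; with a much larger reconstruction loss the same arithmetic pushes gamma up.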

A very verbose reply, but I hope that helps.

from pytorch-vae.

ranabanik commented on May 21, 2024

@mjdall Could you explain the self.loss_type == 'B' loss, and what the clamp, C_max, and C_stop_iter parameters are about from the paper?
On another note, taking a random input and recon, the modified reconstruction loss you shared generates a loss of ~2200 whereas the KLD is around ~4. Even $\beta$ = 25 is not sufficient to balance the high-valued reconstruction loss against the low KLD.
Any thoughts on that?

