yanndubs / disentangling-vae

Experiments for understanding disentanglement in VAE latent representations

License: Other

Python 96.47% Shell 3.53%
beta-vae factor-vae vae variational-autoencoder unsupervised-learning celeba dsprites beta-tcvae disentangled-representations chairs-dataset

disentangling-vae's Introduction

Disentangled VAE

License: MIT · Python 3.6+

This repository contains code (training / metrics / plotting) to investigate disentangling in VAEs and to compare 5 different losses (summary of the differences) using a single architecture: Standard VAE, β-VAEH, β-VAEB, FactorVAE, and β-TCVAE.

Notes:

  • Tested for python >= 3.6
  • Tested for CPU and GPU

Table of Contents:

  1. Install
  2. Run
  3. Plot
  4. Data
  5. Our Contributions
  6. Losses Explanation
  7. Citing

Install

# clone repo
pip install -r requirements.txt

Run

Use python main.py <model-name> <param> to train and/or evaluate a model. For example:

python main.py btcvae_celeba_mini -d celeba -l btcvae --lr 0.001 -b 256 -e 5

You can run predefined experiments and hyper-parameters using -x <experiment>. Those hyperparameters are found in hyperparam.ini. Pretrained models for each experiment can be found in results/<experiment> (created using ./bin/train_all.sh).

Output

This will create a directory results/<saving-name>/ which will contain:

  • model.pt: The model at the end of training.
  • model-i.pt: Model checkpoint after i iterations. By default saves every 10.
  • specs.json: The parameters used to run the program (default and modified with CLI).
  • training.gif: GIF of latent traversals of the latent dimensions Z at every epoch of training.
  • train_losses.log: All (sub-)losses computed during training.
  • test_losses.log: All (sub-)losses computed at the end of training with the model in evaluate mode (no sampling).
  • metrics.log: Mutual Information Gap metric and Axis Alignment Metric. Only if --is-metrics (slow).

Help

usage: main.py ...

PyTorch implementation and evaluation of disentangled Variational AutoEncoders
and metrics.

optional arguments:
  -h, --help            show this help message and exit

General options:
  name                  Name of the model for storing or loading purposes.
  -L, --log-level {CRITICAL,ERROR,WARNING,INFO,DEBUG,NOTSET}
                        Logging levels. (default: info)
  --no-progress-bar     Disables progress bar. (default: False)
  --no-cuda             Disables CUDA training, even when have one. (default:
                        False)
  -s, --seed SEED       Random seed. Can be `None` for stochastic behavior.
                        (default: 1234)

Training specific options:
  --checkpoint-every CHECKPOINT_EVERY
                        Save a checkpoint of the trained model every n epoch.
                        (default: 30)
  -d, --dataset {mnist,fashion,dsprites,celeba,chairs}
                        Path to training data. (default: mnist)
  -x, --experiment {custom,debug,best_celeba,VAE_mnist,VAE_fashion,VAE_dsprites,VAE_celeba,VAE_chairs,betaH_mnist,betaH_fashion,betaH_dsprites,betaH_celeba,betaH_chairs,betaB_mnist,betaB_fashion,betaB_dsprites,betaB_celeba,betaB_chairs,factor_mnist,factor_fashion,factor_dsprites,factor_celeba,factor_chairs,btcvae_mnist,btcvae_fashion,btcvae_dsprites,btcvae_celeba,btcvae_chairs}
                        Predefined experiments to run. If not `custom` this
                        will overwrite some other arguments. (default: custom)
  -e, --epochs EPOCHS   Maximum number of epochs to run for. (default: 100)
  -b, --batch-size BATCH_SIZE
                        Batch size for training. (default: 64)
  --lr LR               Learning rate. (default: 0.0005)

Model specfic options:
  -m, --model-type {Burgess}
                        Type of encoder and decoder to use. (default: Burgess)
  -z, --latent-dim LATENT_DIM
                        Dimension of the latent variable. (default: 10)
  -l, --loss {VAE,betaH,betaB,factor,btcvae}
                        Type of VAE loss function to use. (default: betaB)
  -r, --rec-dist {bernoulli,laplace,gaussian}
                        Form of the likelihood ot use for each pixel.
                        (default: bernoulli)
  -a, --reg-anneal REG_ANNEAL
                        Number of annealing steps where gradually adding the
                        regularisation. What is annealed is specific to each
                        loss. (default: 0)

BetaH specific parameters:
  --betaH-B BETAH_B     Weight of the KL (beta in the paper). (default: 4)

BetaB specific parameters:
  --betaB-initC BETAB_INITC
                        Starting annealed capacity. (default: 0)
  --betaB-finC BETAB_FINC
                        Final annealed capacity. (default: 25)
  --betaB-G BETAB_G     Weight of the KL divergence term (gamma in the paper).
                        (default: 1000)

factor VAE specific parameters:
  --factor-G FACTOR_G   Weight of the TC term (gamma in the paper). (default:
                        6)
  --lr-disc LR_DISC     Learning rate of the discriminator. (default: 5e-05)

beta-tcvae specific parameters:
  --btcvae-A BTCVAE_A   Weight of the MI term (alpha in the paper). (default:
                        1)
  --btcvae-G BTCVAE_G   Weight of the dim-wise KL term (gamma in the paper).
                        (default: 1)
  --btcvae-B BTCVAE_B   Weight of the TC term (beta in the paper). (default:
                        6)

Evaluation specific options:
  --is-eval-only        Whether to only evaluate using precomputed model
                        `name`. (default: False)
  --is-metrics          Whether to compute the disentangled metrcics.
                        Currently only possible with `dsprites` as it is the
                        only dataset with known true factors of variations.
                        (default: False)
  --no-test             Whether not to compute the test losses.` (default:
                        False)
  --eval-batchsize EVAL_BATCHSIZE
                        Batch size for evaluation. (default: 1000)

Plot

Use python main_viz.py <model-name> <plot_types> <param> to plot using pretrained models. For example:

python main_viz.py btcvae_celeba_mini gif-traversals reconstruct-traverse -c 7 -r 6 -t 2 --is-posterior

This will save the plots in the model directory results/<model-name>/. Generated plots for all experiments are found in their respective directories (created using ./bin/plot_all.sh).

Help

usage: main_viz.py ...

CLI for plotting using pretrained models of `disvae`

positional arguments:
  name                  Name of the model for storing and loading purposes.
  {generate-samples,data-samples,reconstruct,traversals,reconstruct-traverse,gif-traversals,all}
                        List of all plots to generate. `generate-samples`:
                        random decoded samples. `data-samples` samples from
                        the dataset. `reconstruct` first rnows//2 will be the
                        original and rest will be the corresponding
                        reconstructions. `traversals` traverses the most
                        important rnows dimensions with ncols different
                        samples from the prior or posterior. `reconstruct-
                        traverse` first row for original, second are
                        reconstructions, rest are traversals. `gif-traversals`
                        grid of gifs where rows are latent dimensions, columns
                        are examples, each gif shows posterior traversals.
                        `all` runs every plot.

optional arguments:
  -h, --help            show this help message and exit
  -s, --seed SEED       Random seed. Can be `None` for stochastic behavior.
                        (default: None)
  -r, --n-rows N_ROWS   The number of rows to visualize (if applicable).
                        (default: 6)
  -c, --n-cols N_COLS   The number of columns to visualize (if applicable).
                        (default: 7)
  -t, --max-traversal MAX_TRAVERSAL
                        The maximum displacement induced by a latent
                        traversal. Symmetrical traversals are assumed. If
                        `m>=0.5` then uses absolute value traversal, if
                        `m<0.5` uses a percentage of the distribution
                        (quantile). E.g. for the prior the distribution is a
                        standard normal so `m=0.45` corresponds to an absolute
                        value of `1.645` because `2m=90%` of a standard normal
                        is between `-1.645` and `1.645`. Note in the case of
                        the posterior, the distribution is not standard normal
                        anymore. (default: 2)
  -i, --idcs IDCS [IDCS ...]
                        List of indices to of images to put at the begining of
                        the samples. (default: [])
  -u, --upsample-factor UPSAMPLE_FACTOR
                        The scale factor with which to upsample the image (if
                        applicable). (default: 1)
  --is-show-loss        Displays the loss on the figures (if applicable).
                        (default: False)
  --is-posterior        Traverses the posterior instead of the prior.
                        (default: False)
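
The quantile convention for -t described in the help text above can be sketched as follows. This is a minimal illustration for the prior case only (a standard normal); the function name max_traversal_to_abs is hypothetical and is not part of the repository.

```python
from statistics import NormalDist


def max_traversal_to_abs(m, mean=0.0, std=1.0):
    """Convert a --max-traversal value into an absolute displacement.

    Follows the rule in the help text: values >= 0.5 are already absolute,
    values < 0.5 are read as a quantile, so that 2*m of the probability
    mass lies inside the symmetric traversal range.
    """
    if m >= 0.5:
        return float(m)
    # Upper bound of the central interval holding a fraction 2*m of the mass.
    upper = NormalDist(mean, std).inv_cdf(0.5 + m)
    return upper - mean


print(round(max_traversal_to_abs(0.45), 3))  # 1.645, matching the help text
print(max_traversal_to_abs(2))               # 2.0 (the default, used as-is)
```

For the posterior the distribution is no longer a standard normal, so the mean and std arguments would have to come from the encoder output.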

Examples

Here are examples of plots you can generate:

  • python main_viz.py <model> reconstruct-traverse --is-show-loss --is-posterior: the first row shows originals, the second row reconstructions, and the remaining rows traversals. Shown for btcvae_dsprites:

    btcvae_dsprites reconstruct-traverse

  • python main_viz.py <model> gif-traversals: a grid of GIFs where rows are latent dimensions, columns are examples, and each GIF shows posterior traversals. Shown for btcvae_celeba:

    btcvae_celeba gif-traversals

  • Grid of gifs generated using code in bin/plot_all.sh. The columns of the grid correspond to the datasets (besides FashionMNIST), the rows correspond to the models (in order: Standard VAE, β-VAEH, β-VAEB, FactorVAE, β-TCVAE):

    grid_posteriors

For more examples, all of the plots for the predefined experiments are found in their respective directories (created using ./bin/plot_all.sh).

Data

Current datasets that can be used (values accepted by -d): mnist, fashion (FashionMNIST), dsprites, celeba, chairs.

The dataset will be downloaded the first time you run it and will be stored in data for future uses. The download will take time and might not work anymore if the download links change. In this case either:

  1. Open an issue
  2. Change the URLs (urls["train"]) for the dataset you want in utils/datasets.py (please open a PR in this case :) )
  3. Download by hand the data and save it with the same names (not recommended)

Our Contributions

In addition to replicating the aforementioned papers, we also propose and investigate the following:

Axis Alignment Metric

Qualitative inspections are unsuitable for comparing models reliably because they are subjective and time-consuming. Recent papers use quantitative measures of disentanglement based on the ground truth factors of variation v and the latent dimensions z. The Mutual Information Gap (MIG) is an appealing information-theoretic metric because it does not use any classifier. To get a MIG of 1 in the dSprites case, where we have 10 latent dimensions and 5 generative factors, 5 of the latent dimensions should exactly encode the true factors of variation, and the rest should be independent of these 5.
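
As a rough illustration of the MIG idea, the sketch below computes the gap from a table of mutual informations using the commonly cited definition (gap between the two most informative latents, normalized by the factor entropy). The function name, the input layout, and the assumption of uniformly distributed factors are illustrative choices, not the repository's implementation.

```python
import math


def mutual_information_gap(mut_info, factor_sizes):
    """MIG from a table of mutual informations.

    mut_info[k][j] is I(z_j; v_k) in nats, one row per generative factor.
    factor_sizes[k] is the number of values factor v_k takes; the factors
    are assumed uniform, so H(v_k) = log(factor_sizes[k]).
    """
    gaps = []
    for row, size in zip(mut_info, factor_sizes):
        top, second = sorted(row, reverse=True)[:2]
        gaps.append((top - second) / math.log(size))
    return sum(gaps) / len(gaps)


h = math.log(4)  # entropy of a uniform factor with 4 values
perfect = [[h, 0.0, 0.0], [0.0, h, 0.0]]             # one latent per factor
shared = [[h / 2, h / 2, 0.0], [0.0, h / 2, h / 2]]  # information split in two
print(mutual_information_gap(perfect, [4, 4]))  # 1.0
print(mutual_information_gap(shared, [4, 4]))   # 0.0
```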

Although a metric like MIG is what we would like to use in the long term, current models do not get good scores and it is hard to understand what they should improve. We thus propose an axis alignment metric (AAM), which does not focus on how much information of v is encoded by z, but rather on whether each v_k is only encoded in a single z_j. For example, in the dSprites dataset it is possible to get an AAM of 1 if z encodes only 90% of the variance in the x position of the shapes, as long as this 90% is only encoded by a single latent dimension z_j. This metric gives a better understanding of what each model is good and bad at. Formally:

Axis Alignment Metric

where the subscript (d) denotes the d-th order statistic and I_x is estimated using empirical distributions and stratified sampling (as with MIG):

Mutual Information for AAM

Single Model Comparison

The model is decoupled from all the losses and it should thus be very easy to modify the encoder / decoder without modifying the losses. We only used a single model in order to have more objective comparisons of the different losses. The model used is the one from Understanding disentangling in β-VAE, which is summarized below:

Model Architecture

Losses Explanation

All the previous losses are special cases of the following loss:

Loss Overview

  1. Index-code mutual information: the mutual information between the latent variables z and the data variable x. There is contention in the literature regarding the correct way to treat this term. From the information bottleneck perspective it should be penalized. InfoGAN gets good results by increasing the mutual information (negative α). Finally, Wasserstein Auto-Encoders drop this term.

  2. Total Correlation (TC): the KL divergence between the joint and the product of the marginals of the latent variable, i.e. a measure of dependence between the latent dimensions. Increasing β forces the model to find statistically independent factors of variation in the data distribution.

  3. Dimension-wise KL divergence: the KL divergence between each dimension of the marginal posterior and the prior. This term ensures the learning of a compact space close to the prior which enables sampling of novel examples.
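
Based on the three term descriptions above, the general loss can be written roughly as follows. This is a reconstruction from the text: the reconstruction term and the notation (q(z|x) for the encoder, q(z) for the aggregate posterior, p(z) for the prior, j indexing latent dimensions) are assumed rather than taken from the original figure.

```latex
\mathcal{L} =
  - \mathbb{E}_{q(z|x)}\left[ \log p(x|z) \right]
  + \alpha \, I_q(z; x)
  + \beta \, \mathrm{KL}\Big( q(z) \,\Big\|\, \prod_j q(z_j) \Big)
  + \gamma \sum_j \mathrm{KL}\big( q(z_j) \,\|\, p(z_j) \big)
```

With α=β=ɣ=1 the three weighted terms add up to the usual KL term between posterior and prior, averaged over the data.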

The losses differ in their estimates of each of these terms and the hyperparameters they use:

  • Standard VAE Loss: α=β=ɣ=1. Each term is computed exactly by a closed form solution (KL between the prior and the posterior). Tightest lower bound.
  • β-VAEH: α=β=ɣ>1. Each term is computed exactly by a closed form solution. Simply adds a hyper-parameter (β in the paper) before the KL.
  • β-VAEB: α=β=ɣ>1. Same as β-VAEH but only penalizes the 3 terms once they deviate from a capacity C which increases during training.
  • FactorVAE: α=ɣ=1, β>1. Adds a weighted Total Correlation term to the standard VAE loss. The total correlation is estimated using a classifier and the density-ratio trick. Note that ɣ in their paper corresponds to β−1 in our framework, because the standard KL term already contains the total correlation once.
  • β-TCVAE: α=ɣ=1 (although can be modified), β>1. Conceptually equivalent to FactorVAE, but each term is estimated separately using minibatch stratified sampling.
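
The way the hyperparameter settings above select a loss can be sketched with plain numbers. The function name and the per-batch values below are hypothetical, and the sketch ignores how each term is actually estimated.

```python
def decomposed_loss(rec_loss, mi, tc, dw_kl, alpha=1.0, beta=1.0, gamma=1.0):
    """Weighted sum of the reconstruction loss and the three KL components."""
    return rec_loss + alpha * mi + beta * tc + gamma * dw_kl


# Hypothetical per-batch values, for illustration only.
rec_loss, mi, tc, dw_kl = 100.0, 2.0, 0.5, 1.5
kl = mi + tc + dw_kl  # the usual KL term is the sum of the three components

vae = decomposed_loss(rec_loss, mi, tc, dw_kl)                    # alpha=beta=gamma=1
beta_h = decomposed_loss(rec_loss, mi, tc, dw_kl, 4.0, 4.0, 4.0)  # alpha=beta=gamma=4
btcvae = decomposed_loss(rec_loss, mi, tc, dw_kl, beta=6.0)       # only TC is up-weighted

print(vae == rec_loss + kl)         # True
print(beta_h == rec_loss + 4 * kl)  # True
print(btcvae - vae)                 # 2.5, i.e. (6 - 1) * tc
```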

Cite

When using one of the models implemented in this repo in academic work please cite the corresponding paper (linked at the top of the README). In case you want to cite this specific implementation then you can use:

@misc{dubois2019dvae,
  title        = {Disentangling VAE},
  author       = {Dubois, Yann and Kastanos, Alexandros and Lines, Dave and Melman, Bart},
  month        = {march},
  year         = {2019},
  howpublished = {\url{http://github.com/YannDubs/disentangling-vae/}}
}

disentangling-vae's People

Contributors

alecokas, gokceneraslan, linesd, yanndubs

disentangling-vae's Issues

Training not using GPU

Thanks for the excellent repo!

I cloned the repo and installed the dependencies in a virtual environment. When training using the sample command in the README:

python main.py btcvae_celeba_mini -d celeba -l btcvae --lr 0.001 -b 256 -e 5

I see that the training is running on CPU, even though I am running on a machine with multiple cuda GPUs available and which are not being utilized by any other running processes.

...
12:36:40 INFO - main: Train celeba with 202599 samples
12:36:40 INFO - main: Num parameters in model: 504055
12:36:40 INFO - __init__: Training Device: cpu
...

Do I need to do something to enable GPU training? It looks like GPU training should be happening by default...

Error in losses explanation?

Hi,

Looking at this part of the readme, this doesn't seem right:

Standard VAE Loss: α=β=ɣ=1. Each term is computed exactly by a closed form solution (KL between the prior and the posterior). Tightest lower bound.
β-VAEH: α=β=ɣ>1. Each term is computed exactly by a closed form solution. Simply adds a hyper-parameter (β in the paper) before the KL.
β-VAEB: α=β=ɣ>1. Same as β-VAEH but only penalizes the 3 terms once they deviate from a capacity C which increases during training.

The standard VAE is simply gamma=1 with no alpha or beta. For Beta-VAE it is simply gamma > 0 with again no alpha or beta. Did I miss something?

Thanks.

FactorVAE encoder graph detaching during discriminator loss optimization

In FactorVAE, for the discriminator optimization, we don't want the gradients incurred in the VAE to update the VAE parameters, and thus the detach operation here:

z_perm = _permute_dims(latent_sample2).detach()

However, wouldn't we need to detach the first set of latent vectors as well (i.e. latent_sample1), now that we have to move the optimizer.step() at the end (due to the in-place modification error already addressed as one closed issue in the repo)?

I ran a debugging session and indeed gradient changes are observed on the VAE encoder from d_tc_loss.backward().

Getting the error "num_samples=0"

18:49:06 INFO - main: Root directory for saving and loading experiments: results\test01
Traceback (most recent call last):
  File "main.py", line 252, in <module>
    main(args)
  File "main.py", line 199, in main
    logger=logger)
  File "R:\disentangling-vae-master\utils\datasets.py", line 71, in get_dataloaders
    **kwargs)
  File "B:\Program_Files\Anaconda3\envs\dis-VAE\lib\site-packages\torch\utils\data\dataloader.py", line 802, in __init__
    sampler = RandomSampler(dataset)
  File "B:\Program_Files\Anaconda3\envs\dis-VAE\lib\site-packages\torch\utils\data\sampler.py", line 64, in __init__
    "value, but got num_samples={}".format(self.num_samples))
ValueError: num_samples should be a positive integeral value, but got num_samples=0

Is there anywhere to set this value so that it is greater than 0?

Low MIG and AAM metrics

Hello,

Firstly, just wanted to state that this is a great repo with a very understandable code base!

I seem to be getting extremely low MIG / AAM scores (around 1e-3 to 1e-2) when training with any of the pretrained models, even using the recommended hyperparams in the .ini file in the main directory. Is this something you were noticing in your own tests?

Visual inspection of the traversals in dSprites seems to show that the network is learning quite disentangled representations (attached, with rows arranged in order of descending KL divergence from the Gaussian prior), so I am quite confused as to why the MIG score is so low.

Even when introducing supervision (matching latent factors to generative factors), the maximum MIG score I have been able to attain is around 0.01, but AAM is a lot higher, at around 0.6 for the model that produced the attached latent traversals.

Cheers,
Justin

traversals

Question on getting zero KL divergence

Hi @YannDubs,

I have a problem in my own code and do not know how to solve it, as I am new to VAE models.
My KL divergence loss becomes very small, close to zero. I added an annealing function, but the KL loss still goes close to zero, even when the annealing weight is zero.
In your opinion, what should I do to train correctly? :(

image

Thanks,

evaluate.py compute_losses?

Hi! Thank you for this wonderful work!
In evaluate.py, in compute_losses(self, dataloader), it seems that only one batch of data is used for evaluation.
But when it comes to the loss computation,

losses = {k: sum(v) / len(dataloader) for k, v in storer.items()}

it uses len(dataloader) to average the loss. Should that be the length of element v?
I wonder if I misunderstand the above computation.
Any help will be appreciated!
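
The two averaging choices in the question can be compared on toy numbers. The sketch below assumes the storer collects exactly one value per batch, which the quoted line suggests but which is not confirmed here; the loss names and values are made up. In PyTorch, len(dataloader) is the number of batches.

```python
from collections import defaultdict

# Hypothetical per-batch losses: one value is appended per batch.
batches = [{"loss": 3.0, "kl": 1.0}, {"loss": 5.0, "kl": 2.0}, {"loss": 4.0, "kl": 3.0}]

storer = defaultdict(list)
for batch_losses in batches:
    for name, value in batch_losses.items():
        storer[name].append(value)

n_batches = len(batches)  # plays the role of len(dataloader)
by_n_batches = {k: sum(v) / n_batches for k, v in storer.items()}
by_len_v = {k: sum(v) / len(v) for k, v in storer.items()}

print(by_n_batches)              # {'loss': 4.0, 'kl': 2.0}
print(by_n_batches == by_len_v)  # True
```

Under that assumption both divisors agree; they would only differ if some batches did not log a value.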

Negative total correlation loss for btc-vae

Nice work!
I tried the beta-TC VAE, but I found that tc_loss is negative. This term is a KL divergence, which should always be non-negative.
I am confused about it.
Thanks!
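
One general point that may be relevant: an exact KL divergence is non-negative, but an estimate of it computed from samples can be negative. The sketch below shows this for two univariate Gaussians. It is a generic illustration with hypothetical helper names, not the estimator used in this repository, and whether it explains the observed tc_loss is an assumption.

```python
import math


def log_normal_pdf(z, mean, std=1.0):
    """Log density of a univariate Gaussian."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (z - mean) ** 2 / (2 * std ** 2)


def kl_gaussians_same_std(mean_q, mean_p, std=1.0):
    """Exact KL(q || p) for two Gaussians sharing the same std."""
    return (mean_q - mean_p) ** 2 / (2 * std ** 2)


mean_q, mean_p = 0.0, 0.1
exact = kl_gaussians_same_std(mean_q, mean_p)

# A single-sample estimate log q(z) - log p(z) at z = 1.0 (hypothetical sample).
z = 1.0
estimate = log_normal_pdf(z, mean_q) - log_normal_pdf(z, mean_p)

print(exact > 0)     # True: the exact KL is positive
print(estimate < 0)  # True: this sample estimate is negative
```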

FashionMNIST background_color not set

Just a small issue with quick fix, seems the FashionMNIST class doesn't have the background_color property set.

class FashionMNIST(datasets.FashionMNIST):

I ran python main.py btcvae_fashion -d fashion and got the error

File "./disentangling-vae/utils/datasets.py", line 46, in get_background
    return get_dataset(dataset).background_color
AttributeError: type object 'FashionMNIST' has no attribute 'background_color'

TC-BetaVAE's MSS Question

Thank you for your code.
I have a question about TC-BetaVAE's MSS.

I don't understand why log_importance_weight_matrix sets every entry of column 0 to 1/N and every entry of column 1 to strat_weight.


def log_importance_weight_matrix(batch_size, dataset_size):
    N = dataset_size
    M = batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
    W.view(-1)[::M + 1] = 1 / N
    W.view(-1)[1::M + 1] = strat_weight
    W[M - 1, 0] = strat_weight
    return W.log()

Below is the formula from the paper.

[Screenshot 2020-08-12, 2:28:10 AM]

I understand that all diagonal entries should be 1/N (because in that case z is sampled from q(z|n*)) and some other entries should be strat_weight. What am I getting wrong?
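
To see which entries the strided assignments touch, here is a plain-Python copy of the indexing in log_importance_weight_matrix, without torch and without the final log(). The function name is hypothetical; the sketch only reproduces the index pattern and makes no claim about which pattern the estimator should use.

```python
def importance_weight_matrix(batch_size, dataset_size):
    """Plain-Python copy of the indexing above, without the final log()."""
    N = dataset_size
    M = batch_size - 1
    strat_weight = (N - M) / (N * M)
    flat = [1 / M] * (batch_size * batch_size)
    # Stride M + 1 equals batch_size, so these slices walk down a column
    # of the row-major matrix rather than along its diagonal.
    flat[::M + 1] = [1 / N] * batch_size
    flat[1::M + 1] = [strat_weight] * batch_size
    W = [flat[r * batch_size:(r + 1) * batch_size] for r in range(batch_size)]
    W[M - 1][0] = strat_weight
    return W


W = importance_weight_matrix(batch_size=4, dataset_size=100)
print([row[0] for row in W])  # column 0: 1/N, except row M-1 which is strat_weight
print([row[1] for row in W])  # column 1: strat_weight in every row
print(W[3][3] == 1 / 3)       # True: this diagonal entry stays at 1/M
```

A stride of batch_size + 1 (that is, M + 2) would be needed to address the diagonal of a batch_size x batch_size matrix.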

Duplicating hyperparameters when training a FactorVAE

Hi!

I've been playing a little bit with the code (congratulations on the work, by the way 😄), and I've seen that when training a FactorVAE model, both the batch size and the number of epochs are doubled:

disentangling-vae/main.py

Lines 191 to 194 in f045219

if args.loss == "factor":
    logger.info("FactorVae needs 2 batches per iteration. To replicate this behavior while being consistent, we double the batch size and the the number of epochs.")
    args.batch_size *= 2
    args.epochs *= 2

Does anybody know the reason behind this operation? I've reviewed the original paper but I couldn't find anything related to this.
Thanks for the help!
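
One plausible reading of the log message, which is not confirmed by the authors here, is that doubling the epochs keeps the number of optimizer iterations unchanged once the batch size has been doubled. The arithmetic can be sketched as follows; the function name and the settings are hypothetical.

```python
import math


def n_updates(n_samples, batch_size, epochs):
    """Number of optimizer iterations when the last partial batch is kept."""
    return epochs * math.ceil(n_samples / batch_size)


n_samples, batch_size, epochs = 737280, 64, 100  # hypothetical settings

# Doubling only the batch size halves the number of iterations.
print(n_updates(n_samples, 2 * batch_size, epochs))      # 576000
# Doubling the epochs as well restores the original count.
print(n_updates(n_samples, 2 * batch_size, 2 * epochs))  # 1152000
print(n_updates(n_samples, batch_size, epochs))          # 1152000
```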

Readme Losses Explanation

Hi. Thanks for this wonderful repo!

My question is, shouldn't beta be equal to 0 for the standard VAE loss (i.e. there's no correlation term)?

Also, wouldn't it be clearer to switch beta and gamma hyperparameter symbols, considering that we have literature like betaVAE which uses beta as the hyperparameter for the Dimension-wise KL Divergence?

Edit: I just found a similar question, #55. Closing this now 👍

imageio.mimsave error

Got an error while running:

python main.py btcvae_celeba_mini -d dsprites -l celeba --lr 0.001 -b 256 -e 5

Pointing to ./utils/visualize.py line 429:

imageio.mimsave(self.save_filename, self.images, fps=FPS_GIF)

Changing the argument from fps to duration, as in the following code (which converts frames per second to a per-frame duration), seems to resolve the issue:

imageio.mimsave(self.save_filename, self.images, duration=(1000 * 1/FPS_GIF))
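
The conversion used in that workaround is just the reciprocal of the frame rate, scaled to milliseconds. A small sketch with a hypothetical helper name and a made-up FPS_GIF value follows; the unit that duration expects may depend on the installed imageio version and backend, so it is worth checking the imageio documentation.

```python
def fps_to_duration_ms(fps):
    """Per-frame display time in milliseconds for a given frame rate."""
    if fps <= 0:
        raise ValueError("fps must be positive")
    return 1000.0 / fps


FPS_GIF = 10  # hypothetical value; the repository defines its own constant
print(fps_to_duration_ms(FPS_GIF))  # 100.0
```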

Doubts about the calculation of H_z

The problem is here log_q_zCx = log_density_gaussian(samples_zCx[..., idcs], mean[..., idcs], log_var[..., idcs]).

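q_zCx is a density function, which should be integrated.
Therefore, I wrote an implementation based on probabilities, in which q_zCx calculates P(a<=z<=b|x).

import math

import torch


# Polynomial approximation of the error function, applied elementwise.
def erf(x):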
    a1 = 0.278393
    a2 = 0.230389
    a3 = 0.000972
    a4 = 0.078108
    s = torch.sign(x)
    x = x.abs()
    e = 1-1/(1 + a1*x + a2*x**2 + a3*x**3 + a4*x**4)**4
    return s*e


def Gab(a,b,mu,sigma):
    '''
    the probability of z belonging to [a,b]
    :param a:
    :param b:
    :param mu:
    :param sigma:
    :return:
    '''
    inverse_sigma = 1/(math.sqrt(2)*sigma)
    return 0.5 * (erf((b-mu)*inverse_sigma)-
                     erf((a-mu)*inverse_sigma))

samples,params,recons,labels = evaluator.compute(test_loader)

N_x_samples = 1000
M_z_samples = 100
mu,logvar = params
mu,logvar = mu[:N_x_samples],logvar[:N_x_samples]

mu =mu.view(1,N_x_samples,dim).repeat(M_z_samples,1,1)
logvar =logvar.view(1,N_x_samples,dim).repeat(M_z_samples,1,1)

l = torch.linspace(-3,3,M_z_samples+1)
d = l[1]-l[0]
a = l[:-1].cuda()
b = l[1:].cuda()

a = a.reshape(-1,1,1).expand(M_z_samples,N_x_samples,dim)
b = b.reshape(-1,1,1).expand(M_z_samples,N_x_samples,dim)


q_zCx = Gab(a,b,mu,torch.exp(0.5*logvar))
q_z = q_zCx.mean(1)
H_z = (-q_z*(q_z/d).log()).sum(0)
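
As a quick sanity check on the polynomial erf used above, the scalar sketch below compares it with math.erf from the standard library. The function name erf_approx is hypothetical. PyTorch also ships a built-in torch.erf, which could replace the hand-written approximation.

```python
import math


def erf_approx(x):
    """Scalar version of the polynomial erf approximation quoted above."""
    a1, a2, a3, a4 = 0.278393, 0.230389, 0.000972, 0.078108
    sign = -1.0 if x < 0 else 1.0
    x = abs(x)
    return sign * (1 - 1 / (1 + a1 * x + a2 * x**2 + a3 * x**3 + a4 * x**4) ** 4)


# Largest deviation from math.erf on a grid over [-3, 3].
grid = [i / 100 for i in range(-300, 301)]
max_err = max(abs(erf_approx(x) - math.erf(x)) for x in grid)
print(max_err < 1e-3)  # True
```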

Dataset with incomplete combinations

Hi! I am using this code to work on a dataset of ~700 words. For each word I am varying several variables (size, font, position, etc.). This results in a dataset that is too big (6M+ instances) to use all the possible combinations during training, so I decided to use a sample of the full dataset. That is fine for training, but it creates an issue during the evaluation run.

In particular, in evaluate.compute_metrics() I found the first technical issue. To run this method the code tries to reshape samples_zCx and params_zCx tensors using the sizes of the dataset generation factors (lat_sizes) and the latent layer size (latent_dim). This is not a problem when using a dataset with all the possible combinations, but given that I now have a sample of all the possibilities, this is not the case. So, I cannot make the reshape.

I solved this by creating a tensor of np.nan and filling it with the available data in the corresponding cells (using metadata from the dataset that indicates how each instance was created). Technically, this works, but I now have doubts about how this solution impacts the following calculations. That is, I now have a tensor with NaNs that will be used to compute the conditional entropy H(z|v). Is this ok? Would it be better to use zeros?

Additionally, computing the conditional entropy with the _estimate_H_zCv() method is computationally expensive, given that I have a big tensor full of NaNs. Would it be ok to skip the cells with NaNs to speed up the process?
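
The difference between skipping missing cells, filling them with zeros, and leaving NaNs in place can be seen on a toy average. This is a generic illustration with a hypothetical helper and made-up values; it says nothing specific about what _estimate_H_zCv() does internally.

```python
import math


def nan_mean(values):
    """Mean over the non-NaN entries; NaN if every entry is missing."""
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else float("nan")


cells = [2.0, float("nan"), 4.0, float("nan")]  # hypothetical per-cell values

print(nan_mean(cells))                      # 3.0: missing cells are skipped
zero_filled = [0.0 if math.isnan(v) else v for v in cells]
print(sum(zero_filled) / len(zero_filled))  # 1.5: zeros pull the mean down
print(math.isnan(sum(cells) / len(cells)))  # True: NaN propagates
```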

Low MIG values bug found & solution

I trained a beta TCVAE with the code from https://github.com/rtqichen/beta-tcvae which gives MIG for beta TCVAE of ~0.50. When computing MIG with your code with the same model (based on MLP), I had values close to 0.0008.

Differences with Chen's code I found important:

  • MIG values are not computed on shape in Chen's code (not considered a factor of variation). I had to modify the dsprites dataset to remove shape from dSprites lat_names, and write a custom _estimate_H_zCv function. I can share if you want.

  • Chen uses samples, not the mean as you do here, since self.training is False

  • The most important change is that I replaced this line

    samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)
    with

      samples_zCx = samples_zCx.permute(1,0)
      samples_zCx = samples_zCx.index_select(1, samples_x).view(latent_dim, n_samples)
      samples_zCx = samples_zCx.view(1, latent_dim, n_samples).expand(len_dataset, latent_dim, n_samples)
      mean = params_zCX[0].view(len_dataset, latent_dim, 1).expand(len_dataset, latent_dim, n_samples)
      log_var = params_zCX[1].view(len_dataset, latent_dim, 1).expand(len_dataset, latent_dim, n_samples)
    

which are closer to Chen's code, and I get values of ~0.50 now too. I don't know exactly why the original lines were not expanding the correct way.

PlotNeuralNet Code

Hello! I was kindly wondering if you could share your PlotNeuralNet code (which is what I presume you used to generate this). I would really appreciate it! To my knowledge, there aren't any VAE visualization examples for this online, so I imagine it would be very helpful for others too.

A few snags

Apologies if the following are just me not configuring properly:

  1. I'm not seeing Gifs/Pngs being created in the model directory with e.g.
    python main.py factor_celeba_cc -x factor_celeba -d celeba -l factor. I'd love to have these created every epoch

  2. When I try to evaluate after stopping a training run I get
    ization.py", line 382, in load
        f = open(f, 'rb')
    FileNotFoundError: [Errno 2] No such file or directory: 'results/best_celeba/model.pt'
    To work around this I rename e.g. model-100.pt in the model directory to model.pt, and then I get
    21:50:09 INFO - main: Root directory for saving and loading experiments: results/ best_celeba
    Traceback (most recent call last):
      File "main.py", line 252, in <module>
        main(args)
      File "main.py", line 233, in main
        test_loader = get_dataloaders(metadata["dataset"],
    KeyError: 'dataset'

Incidentally, when I run e.g. python main_viz.py best_celeba gif-traversals reconstruct-traverse -c 7 -r 6 -t 2 --is-posterior I get the same KeyError: 'dataset' error.

Thanks for the fantastic repo though!

Computing MIG and AAM for other datasets

I am trying to compute MIG and AAM for another dataset which has a different structure from dsprites, in the sense that the number of samples does not match the product of the sizes of the latents. Thus, the following line fails:

samples_zCx = samples_zCx.view(*lat_sizes, latent_dim)

since the size of samples_zCx is (len(dataset), latent_dim) but len(dataset) is not equal to the product of lat_sizes. Is there any reason why you explicitly chose to use the product of the latent sizes, or should it be the length of the dataset?

Thanks !
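
The reshape in question only works when the dataset holds exactly one sample per combination of factor values. A quick check of that precondition could look like the sketch below. The helper name is hypothetical, the dSprites factor sizes are stated here as an assumption, and math.prod requires Python 3.8 or later.

```python
import math


def can_reshape_to_factors(n_samples, lat_sizes):
    """True if n_samples equals the number of factor combinations."""
    return n_samples == math.prod(lat_sizes)


# Factor sizes assumed here for dSprites: shape, scale, orientation, posX, posY.
dsprites_lat_sizes = [3, 6, 40, 32, 32]
print(math.prod(dsprites_lat_sizes))                       # 737280
print(can_reshape_to_factors(737280, dsprites_lat_sizes))  # True
print(can_reshape_to_factors(100000, dsprites_lat_sizes))  # False
```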

Minor bug in loss logging

Line 109 in disvae/models/losses.py:

if not is_train or self.n_train_steps % self.record_loss_every == 1:

This does not work when self.record_loss_every is 1 (recording all mini-batches).

Fix:

if not is_train or (self.n_train_steps - 1) % self.record_loss_every == 0:
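
The difference between the two conditions can be checked without the training loop. The sketch below assumes that n_train_steps counts from 1 and is incremented before the check; the helper name is hypothetical.

```python
def recorded_steps(n_steps, record_loss_every, fixed):
    """Training steps (1-based) at which the loss would be recorded."""
    steps = []
    for n_train_steps in range(1, n_steps + 1):
        if fixed:
            hit = (n_train_steps - 1) % record_loss_every == 0
        else:
            hit = n_train_steps % record_loss_every == 1
        if hit:
            steps.append(n_train_steps)
    return steps


print(recorded_steps(6, 1, fixed=False))  # [] because n % 1 is always 0
print(recorded_steps(6, 1, fixed=True))   # [1, 2, 3, 4, 5, 6]
print(recorded_steps(6, 3, fixed=False))  # [1, 4]
print(recorded_steps(6, 3, fixed=True))   # [1, 4]
```

For any interval larger than 1 the two conditions select the same steps, so the proposed fix only changes the behaviour when every mini-batch should be recorded.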

Inplace error when running FactorVAE

When running python main.py factor_coloredmnist -x factor_coloredmnist on Python 3.8.5 I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256, 20]], which is output 0 of TBackward, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Coming from d_tc_loss.backward()

I tried setting inplace=False in the leaky_relu of the discriminator, without success. The error comes from calling F.cross_entropy(d_z, zeros) in d_tc_loss (the term F.cross_entropy(d_z_perm, ones) poses no problem).

Any help would be appreciated :)
