This is a PyTorch implementation of the paper:
Lagging Inference Networks and Posterior Collapse in Variational Autoencoders
Junxian He, Daniel Spokoyny, Graham Neubig, Taylor Berg-Kirkpatrick
ICLR 2019
The code performs aggressive training of the inference network to mitigate the issue of posterior collapse in VAEs and obtain better generative modeling performance.
Please contact [email protected] if you have any questions.
- Python 3
- PyTorch 0.4
Datasets used in this paper can be downloaded with:

```
python prepare_data.py
```

The downloaded data is located in `./datasets/`.
Example script to train a VAE on text data:

```
python text.py --dataset yahoo --aggressive 1 --warm_up 10 --kl_start 0.1
```

On image data:

```
python image.py --dataset omniglot --aggressive 1 --warm_up 10 --kl_start 0.1
```
Here:

- `--dataset` specifies the dataset name; currently `synthetic`, `yahoo`, and `yelp` are supported for `text.py`, and `omniglot` for `image.py`
- `--aggressive` controls whether aggressive training is applied (1) or not (0)
- `--kl_start` is the starting KL weight (set to 1.0 to disable KL annealing)
- `--warm_up` is the number of annealing epochs (the KL weight increases linearly from `kl_start` to 1.0 over the first `warm_up` epochs)
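As a sketch, the linear KL annealing schedule implied by these flags can be written as follows (the function name and signature are illustrative here, mirroring the CLI flags rather than the repo's actual code):

```python
def kl_weight(epoch, kl_start=0.1, warm_up=10):
    """Linear KL annealing: the weight rises from kl_start to 1.0
    over the first warm_up epochs, then stays at 1.0.

    Illustrative sketch only; not taken from the repository's code.
    """
    if warm_up <= 0:
        return 1.0
    return min(1.0, kl_start + (1.0 - kl_start) * epoch / warm_up)
```

With the defaults above, the weight starts at 0.1 at epoch 0 and reaches 1.0 at epoch 10; setting `kl_start=1.0` makes the schedule constant, i.e. annealing is disabled.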
To run the code on your own text/image dataset, you need to create a new configuration file in the `./config/` folder to specify the network hyperparameters and data paths. If the new config file is `./config/config_abc.py`, then `--dataset` needs to be set to `abc` accordingly.
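A new config file might look roughly like the sketch below. All field names and values here are assumptions for illustration only; copy an existing config in `./config/` (e.g. the one for `yahoo`) to see the real schema expected by the code.

```python
# ./config/config_abc.py -- illustrative sketch only.
# Every field name below is an assumption; mirror an existing
# config file in ./config/ for the actual required keys.
params = {
    'enc_nh': 1024,   # encoder hidden size (assumed field)
    'dec_nh': 1024,   # decoder hidden size (assumed field)
    'nz': 32,         # latent dimension (assumed field)
    # assumed data paths under ./datasets/
    'train_data': 'datasets/abc/abc.train.txt',
    'val_data': 'datasets/abc/abc.val.txt',
    'test_data': 'datasets/abc/abc.test.txt',
}
```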
```
@inproceedings{he2018lagging,
  title={Lagging Inference Networks and Posterior Collapse in Variational Autoencoders},
  author={Junxian He and Daniel Spokoyny and Graham Neubig and Taylor Berg-Kirkpatrick},
  booktitle={International Conference on Learning Representations},
  year={2019},
  url={https://openreview.net/forum?id=rylDfnCqF7},
}
```