lifeitech / fce-2d

Flow Contrastive Estimation (FCE) PyTorch Implementation on 2D data

License: MIT License

Flow Contrastive Estimation (FCE)

This is a PyTorch implementation of Flow Contrastive Estimation on 2D datasets.

Introduction

Direct Estimation of Energy Model is Difficult

Our problem is to estimate an energy-based model (EBM)

$$p_\theta(x) = \frac{\exp[-f_\theta(x)]}{Z(\theta)}$$

where

$$Z(\theta) = \int\exp[-f_\theta(x)]dx$$

is the normalizing constant. The energy model specifies a probability distribution on the data space. The normalizing constant is generally intractable: it requires integrating $\exp[-f_\theta(x)]$ over the entire data space (or, for discrete data, summing over an exponential number of configurations).

The energy based model is implemented in file ebm.py.
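As a rough illustration of what such a model can look like (this is not the contents of ebm.py, which is not reproduced here), the sketch below uses a small MLP for the energy $f_\theta$ and keeps an estimate of $\log Z(\theta)$ as a trainable scalar. The class name `ToyEBM`, the layer sizes, and the parameter name `log_z` are all assumptions made for the example.

```python
import torch
import torch.nn as nn


class ToyEBM(nn.Module):
    """Hypothetical 2D energy-based model: log p(x) = -f(x) - log_z."""

    def __init__(self, hidden=64):
        super().__init__()
        # f_theta: maps a 2D point to a scalar energy.
        self.energy = nn.Sequential(
            nn.Linear(2, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        # Trainable stand-in for log Z(theta); it is only an estimate.
        self.log_z = nn.Parameter(torch.zeros(()))

    def log_prob(self, x):
        # x: (batch, 2) -> (batch,) approximate log density.
        return -self.energy(x).squeeze(-1) - self.log_z


ebm = ToyEBM()
x = torch.randn(5, 2)
log_p = ebm.log_prob(x)  # one approximate log density per input point
```

Because `log_z` is learned rather than computed, `log_prob` is only approximately normalized, and only once training has pushed `log_z` toward the true value.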

NCE: Teach EBM to Classify Data and Noise

One approach to estimating an EBM is Noise Contrastive Estimation (NCE). In NCE, the normalizing constant is treated as a trainable parameter, and the model parameters are estimated by training the EBM to classify data and noise. Let $p_{\mathrm{data}}(x)$ denote the data distribution and let $q(x)$ denote some noise distribution. Training amounts to maximizing the following posterior log-likelihood of the classification:

$$V(\theta) = \mathbb{E}_{x\sim p_{\text{data}}}\log\frac{p_\theta(x)}{p_\theta(x)+q(x)} + \mathbb{E}_{\tilde{x}\sim q}\log\frac{q(\tilde{x})}{p_\theta(\tilde{x}) + q(\tilde{x})}.$$
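A standard rewriting, added here for clarity and not part of the original text, shows that each term is the log-sigmoid of a log-density ratio. With $\sigma(t) = 1/(1+e^{-t})$,

$$\log\frac{p_\theta(x)}{p_\theta(x)+q(x)} = \log\sigma\big(\log p_\theta(x) - \log q(x)\big), \qquad \log\frac{q(\tilde{x})}{p_\theta(\tilde{x})+q(\tilde{x})} = \log\sigma\big(\log q(\tilde{x}) - \log p_\theta(\tilde{x})\big).$$

Maximizing $V(\theta)$ is therefore logistic regression with logit $\log p_\theta(x) - \log q(x)$ and equal numbers of data and noise samples.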

FCE: Replace Noise in NCE with Flow Model

In Flow Contrastive Estimation (FCE), we replace the noise $q(x)$ with a flow model $q_\alpha(x)$ and train the two models jointly: the energy model is updated to maximize the posterior log-likelihood of the classification over $\theta$, and the flow model is updated to minimize it over $\alpha$:

$$V(\alpha,\theta) = \mathbb{E}_{x\sim p_{\text{data}}}\log\frac{p_\theta(x)}{p_\theta(x)+q_\alpha(x)} + \mathbb{E}_{\tilde{x}\sim q_\alpha}\log\frac{q_\alpha(\tilde{x})}{p_\theta(\tilde{x}) + q_\alpha(\tilde{x})}.$$

This objective is implemented as the value function in file util.py.
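For readers who want to see the objective as code, here is a minimal sketch of a Monte Carlo estimate of $V(\alpha,\theta)$. It is not the function from util.py; the name `value` and its four-argument signature are assumptions. It relies on the identity $\log\frac{p}{p+q} = \log\sigma(\log p - \log q)$, which keeps the computation stable in log space.

```python
import math
import torch
import torch.nn.functional as F


def value(log_p_data, log_q_data, log_p_noise, log_q_noise):
    """Monte Carlo estimate of V from per-sample log densities.

    log_p_* come from the EBM and log_q_* from the flow; *_data are
    evaluated on real samples and *_noise on samples drawn from the flow.
    """
    # log p/(p+q) = logsigmoid(log p - log q), averaged over data samples.
    data_term = F.logsigmoid(log_p_data - log_q_data).mean()
    # log q/(p+q) = logsigmoid(log q - log p), averaged over flow samples.
    noise_term = F.logsigmoid(log_q_noise - log_p_noise).mean()
    return data_term + noise_term


# When both models assign identical densities, every ratio equals 1/2.
lp = torch.tensor([-1.0, -2.0, -3.0])
v = value(lp, lp, lp, lp)
```

In a training loop of this shape, the EBM step would minimize `-value(...)` and the flow step would minimize `value(...)`.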

When the classification accuracy is below a threshold, the energy model is trained. Otherwise, the flow model is trained.
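The switching rule can be sketched as follows. The accuracy definition below (a sample counts as correct when the model that generated it assigns the higher density) is an assumption about how the classifier is scored, and both function names are hypothetical.

```python
import torch


def classification_accuracy(log_p_data, log_q_data, log_p_noise, log_q_noise):
    """Fraction of samples the EBM-vs-flow classifier labels correctly."""
    # A real sample is classified correctly when p_theta > q_alpha there.
    correct_data = (log_p_data > log_q_data).float()
    # A flow sample is classified correctly when q_alpha > p_theta there.
    correct_noise = (log_q_noise > log_p_noise).float()
    return torch.cat([correct_data, correct_noise]).mean().item()


def model_to_train(accuracy, threshold=0.6):
    """Return which model takes the next gradient step."""
    return "ebm" if accuracy < threshold else "flow"


acc = classification_accuracy(
    torch.tensor([0.0, 0.0]), torch.tensor([-1.0, 1.0]),  # 1 of 2 data points correct
    torch.tensor([0.0, 0.0]), torch.tensor([1.0, 2.0]),   # both flow samples correct
)
next_model = model_to_train(acc)  # accuracy 0.75 >= 0.6, so the flow is trained
```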

In the paper, the authors choose Glow as the flow model. This repository implements both Glow and MAF (Masked Autoregressive Flow) as options for the flow model.

Training

To train the model, run

python train.py

| Argument | Meaning |
| --- | --- |
| `--seed=42` | random seed |
| `--epoch=100` | number of training epochs |
| `--flow=glow` | flow model to use: `glow` or `maf` |
| `--threshold=0.6` | accuracy threshold for alternating training |
| `--batch=1000` | batch size |
| `--dataset=8gaussians` | dataset; one of `8gaussians`, `spiral`, `2spirals`, `checkerboard`, `rings`, `pinwheel` |
| `--samples=500000` | training set size |
| `--lr_ebm=1e-3` | Adam learning rate for the EBM |
| `--lr_flow=7e-4` | Adam learning rate for the flow model |
| `--b1=0.9` | Adam β1 (decay rate for the first-moment estimate) |
| `--b2=0.999` | Adam β2 (decay rate for the second-moment estimate) |
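
To make the flags concrete, here is a hypothetical `argparse` parser that mirrors the arguments listed above. It assumes the values shown are the defaults. The actual parser in train.py is not shown in this README and may differ.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented flags and values."""
    p = argparse.ArgumentParser(description="Train FCE on 2D data")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epoch", type=int, default=100)
    p.add_argument("--flow", choices=["glow", "maf"], default="glow")
    p.add_argument("--threshold", type=float, default=0.6)
    p.add_argument("--batch", type=int, default=1000)
    p.add_argument("--dataset", default="8gaussians",
                   choices=["8gaussians", "spiral", "2spirals",
                            "checkerboard", "rings", "pinwheel"])
    p.add_argument("--samples", type=int, default=500000)
    p.add_argument("--lr_ebm", type=float, default=1e-3)
    p.add_argument("--lr_flow", type=float, default=7e-4)
    p.add_argument("--b1", type=float, default=0.9)
    p.add_argument("--b2", type=float, default=0.999)
    return p


# Equivalent to running: python train.py --flow=maf --dataset=2spirals
args = build_parser().parse_args(["--flow=maf", "--dataset=2spirals"])
```

Flags that are not passed keep the listed values, so `args.threshold` is still 0.6 here.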

Install wandb

To run the script you need to install Weights & Biases (wandb), an MLOps tool used to monitor metrics during training. I find it very easy and convenient to use, and I encourage you to install it and give it a try.

First, sign up on their website https://wandb.ai/site. You may use your GitHub account to sign up. Copy your API key.

Next, install the Python package with

pip install wandb

When you start running your experiment, you may be asked to log in. Simply paste your API key and hit enter. Once the experiment is running, you can go to https://app.wandb.ai/home and visualize it in real time.

If you do not wish to use wandb, you can comment out all the wandb-related code in train.py.
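As an alternative to commenting code out, logging can be routed through a small switch. This is only a sketch: `make_logger`, its `use_wandb` argument, and the project name `"fce-2d"` are hypothetical and do not appear in train.py. Only `wandb.init` and `wandb.log` are real wandb calls.

```python
def make_logger(use_wandb, project="fce-2d"):
    """Return a function that logs a dict of metrics, or does nothing."""
    if not use_wandb:
        return lambda metrics: None  # no-op: training runs without wandb
    import wandb  # imported lazily so the package stays optional
    wandb.init(project=project)
    return wandb.log


log = make_logger(use_wandb=False)
result = log({"value": -1.39, "accuracy": 0.6})  # silently ignored
```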

Visualizations

Density Plots

The figure below shows the result of Flow Contrastive Estimation using MAF as the flow model. The left column shows three data distributions. The middle column shows the densities learned by MAF; these are also the densities that the energy model is trying to distinguish from the true densities. The right column shows the densities learned by the EBM.

[Figure: fce-maf]

For reference, here is the result presented in the FCE paper, showing learned Glow and EBM densities on three data distributions.

[Figure: fce-glow]

MSE

For the 8gaussians dataset, we have an analytical formula for the true data distribution, so we can evaluate the MSE between the log density learned by the energy model and the true log density. The plot below shows this MSE on the 8gaussians training dataset.

[Figure: mse]
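
A minimal sketch of such an evaluation is given below. It assumes the eight components are equal-weight isotropic Gaussians placed evenly on a circle. The `radius` and `std` values are illustrative, not taken from the repository, and both function names are hypothetical.

```python
import math
import torch


def eight_gaussians_log_prob(x, radius=4.0, std=0.5):
    """Log density of an equal-weight mixture of 8 isotropic 2D Gaussians.

    radius and std are illustrative values, not taken from the repository.
    """
    angles = torch.arange(8) * (2 * math.pi / 8)
    centers = radius * torch.stack([angles.cos(), angles.sin()], dim=1)  # (8, 2)
    sq_dist = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(-1)       # (n, 8)
    # Log density of each 2D component N(center, std^2 I).
    comp = -sq_dist / (2 * std**2) - math.log(2 * math.pi * std**2)
    # Equal mixture weights of 1/8.
    return torch.logsumexp(comp, dim=1) - math.log(8)


def log_density_mse(model_log_prob, x):
    """Mean squared error between model and true log densities at x."""
    return ((model_log_prob(x) - eight_gaussians_log_prob(x)) ** 2).mean()


x = torch.randn(16, 2)
# A model that matches the true density exactly has zero error.
mse_perfect = log_density_mse(eight_gaussians_log_prob, x)
# A model whose log density is off by a constant 1 has an MSE of about 1.
mse_shifted = log_density_mse(lambda pts: eight_gaussians_log_prob(pts) + 1.0, x)
```

Since the EBM's normalizing constant is itself learned, a constant offset in its log density shows up directly in this MSE, as the shifted example illustrates.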

Value

The figure below shows the negative of the value function during training. If both EBM $p_\theta$ and Flow $q_\alpha$ are close to the data distribution $p_{\text{data}}$, then $p_\theta\approx q_\alpha\approx p_{\text{data}}$ and the value should be approximately

$$- V(\alpha,\theta)\approx -\left(\log\frac{1}{2}+\log\frac{1}{2}\right) = \log4 \approx 1.39.$$

[Figure: value]
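
The equilibrium value can be checked numerically with a few lines of plain Python. The helper name `neg_value_at` is hypothetical; it evaluates $-V$ for the idealized case in which the classifier assigns the same probability to the correct class on every sample.

```python
import math


def neg_value_at(prob_correct):
    """-V when every sample gets probability prob_correct for its true class."""
    # One log term for data samples and one for flow samples.
    return -(math.log(prob_correct) + math.log(prob_correct))


# At equilibrium the classifier cannot do better than chance.
neg_value = neg_value_at(0.5)
```

With `prob_correct = 0.5` this equals $\log 4$. Any classifier that does better than chance on every sample gives a smaller $-V$.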

Accuracy

We use 0.6 as the default threshold. Because training switches between the two models whenever the accuracy crosses the threshold, the classification accuracy of the EBM fluctuates around 0.6.

[Figure: acc]

Reference

Gao, Ruiqi, et al. "Flow contrastive estimation of energy-based models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020.
