ipl-uv / rbig_jax

Iterative and Parametric Gaussianization with JAX.

Home Page: https://ipl-uv.github.io/rbig_jax/

License: MIT License

Languages: Jupyter Notebook 99.55%, Python 0.45%, Makefile 0.01%

Topics: gaussianization, density-estimation, rbig, information-theory, jax, sampling, generative-model

rbig_jax's Introduction

Iterative and Parametric Gaussianization with Jax

This package implements the Rotation-Based Iterative Gaussianization (RBIG) algorithm using Jax. RBIG is a normalizing flow that can transform any multi-dimensional distribution into a Gaussian distribution using a sequence of simple marginal Gaussianization transforms (e.g. histogram) and rotations (e.g. PCA). Because every transform is invertible, you can calculate probabilities as well as sample from your distribution. See the examples below for details.
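To make the idea concrete, here is a minimal sketch of a single RBIG iteration (not the package's actual API; the function name `rbig_layer` is hypothetical): marginal Gaussianization via an empirical CDF followed by the inverse Gaussian CDF, then a PCA rotation.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def rbig_layer(X):
    """One RBIG iteration (sketch): marginal Gaussianization, then rotation.

    X has shape (n_samples, n_features).
    """
    n = X.shape[0]
    # 1. Marginal uniformization: empirical CDF of each feature.
    ranks = jnp.argsort(jnp.argsort(X, axis=0), axis=0)
    U = (ranks + 1) / (n + 1)  # offset keeps values strictly inside (0, 1)
    # 2. Marginal Gaussianization: inverse Gaussian CDF.
    G = norm.ppf(U)
    # 3. Rotation: eigenvectors of the covariance matrix (PCA).
    _, V = jnp.linalg.eigh(jnp.cov(G, rowvar=False))
    return G @ V
```

Iterating such a layer drives the data toward a multivariate Gaussian; the package's actual transforms additionally track log-determinants so that densities can be evaluated.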


Density Estimation Demo

Demo Colab Notebooks

  • Iterative Gaussianization - Open In Colab
  • Parametric Gaussianization - Open In Colab
Demo

Figures: original data, Gaussian transform, inverse transform, samples drawn, and probabilities.

Why Jax?

Mainly because I wanted to practice. RBIG is an iterative scheme, so perhaps Jax isn't the best fit for it, but I would like to improve my functional programming skills. In addition, Jax is much faster thanks to JIT compilation and auto-batching, which makes some difficult aspects of the implementation a lot easier. The same code also runs on CPU, GPU, and TPU with only minor changes. Overall, I didn't see any downside to having some free speed-ups.


Installation Instructions

This repo uses the most recent jax from GitHub, which is absolutely essential: for example, it relies on the latest np.interp function, which isn't in the pip distribution yet. The environment.yml file pins the most up-to-date distribution.

  1. Clone the repository.
git clone https://github.com/IPL-UV/rbig_jax
  2. Install the environment using conda.
conda env create -f environment.yml
  3. If you already have the environment installed, update it.
conda activate jaxrbig
conda env update --file environment.yml

Resources

  • Python Code - github
  • RBIG applied to Earth Observation Data - github
  • Original Webpage - ISP
  • Original MATLAB Code - webpage
  • Original Python Code - github
  • Paper - Iterative Gaussianization: from ICA to Random Rotations

Acknowledgements

This work was supported by the European Research Council (ERC) Synergy Grant “Understanding and Modelling the Earth System with Machine Learning (USMILE)” under Grant Agreement No 855187.

rbig_jax's People

Contributors

jejjohnson, miguelangelft


Forkers

alexhepburn

rbig_jax's Issues

[Investigate] "Squeezing Layer" for Marginal Histogram Transformation.

We still have issues with the boundaries of the marginal histogram transformation; the artifacts are most visible when sampling.

Proposal

It would be interesting to see what happens if we add a "squeezing layer" to constrain the domain. Instead of applying the histogram transformation directly on the unbounded input domain, we could first apply an invertible "squeezing" function, (-inf, inf) -> [0, 1], and then do the histogram transformation, [0, 1] -> [0, 1]. At the very least, we wouldn't have to worry about the bounds for the histogram function. A sketch of one candidate squeezing function follows.
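A minimal sketch of one possible squeezing bijection, using the logistic sigmoid (which maps onto the open interval (0, 1)); the function names are hypothetical, and a real layer would need to handle log-determinant bookkeeping exactly like the other transforms:

```python
import jax
import jax.numpy as jnp


def squeeze_forward(x):
    """Elementwise sigmoid squeeze, (-inf, inf) -> (0, 1)."""
    y = jax.nn.sigmoid(x)
    # log|dy/dx| = log sigmoid(x) + log sigmoid(-x), summed over features.
    log_det = jnp.sum(jax.nn.log_sigmoid(x) + jax.nn.log_sigmoid(-x), axis=-1)
    return y, log_det


def squeeze_inverse(y):
    """Inverse (logit), (0, 1) -> (-inf, inf)."""
    x = jnp.log(y) - jnp.log1p(-y)
    # log|dx/dy| = -log y - log(1 - y), summed over features.
    log_det = -jnp.sum(jnp.log(y) + jnp.log1p(-y), axis=-1)
    return x, log_det
```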

[Algorithm][Parametric] Emerging Convolutions

Probably the only approach that allows for invertible convolutions with spatial awareness. It is related to autoregressive methods, so it will be fast for density estimation but slow for sampling.


[Algorithm][Parametric] SVD Linear Layer

We currently have the Householder transformation, and we can reuse it for the SVD layer, X = U S Vᵀ. The rotation matrices (U, Vᵀ) are constrained to be orthogonal while the diagonal, S, is unconstrained. It's a bit more expensive, but it might help training. A sketch of the idea follows.
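A minimal sketch parameterizing U and Vᵀ as products of Householder reflections, assuming S is parameterized through its log for positivity; all names here are hypothetical:

```python
import jax
import jax.numpy as jnp


def householder(x, v):
    """Apply H = I - 2 v v^T / ||v||^2 to the rows of x (H is orthogonal)."""
    v = v / jnp.linalg.norm(v)
    return x - 2.0 * jnp.outer(x @ v, v)


def init_svd_params(rng, dim, n_reflections=2):
    k1, k2 = jax.random.split(rng)
    vs_u = jax.random.normal(k1, (n_reflections, dim))  # reflections for U
    vs_v = jax.random.normal(k2, (n_reflections, dim))  # reflections for V^T
    log_s = jnp.zeros(dim)                              # diagonal of S (log-space)
    return vs_u, log_s, vs_v


def svd_forward(params, x):
    """SVD-style factorization: orthogonal, then diagonal, then orthogonal."""
    vs_u, log_s, vs_v = params
    for v in vs_v:
        x = householder(x, v)
    x = x * jnp.exp(log_s)
    for v in vs_u:
        x = householder(x, v)
    # Orthogonal factors contribute nothing to the log-det; only S does.
    log_det = jnp.full(x.shape[:1], jnp.sum(log_s))
    return x, log_det


def svd_inverse(params, y):
    """Reflections are self-inverse; apply everything in reverse order."""
    vs_u, log_s, vs_v = params
    for v in vs_u[::-1]:
        y = householder(y, v)
    y = y * jnp.exp(-log_s)
    for v in vs_v[::-1]:
        y = householder(y, v)
    return y, jnp.full(y.shape[:1], -jnp.sum(log_s))
```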



[Parametric] Remove Dependency on objax

objax is basically a PyTorch-like layer on top of Jax. But it is a bit limiting when trying to mix and match components. So I think we should remove the dependency on objax and stick with pure Jax.


Example

This example was taken from jax-flows.

Demo Snippet

```python
import jax.numpy as np
from jax.nn.initializers import orthogonal
from jax.scipy import linalg


def FixedInvertibleLinear():
    """An implementation of an invertible linear layer from
    `Glow: Generative Flow with Invertible 1x1 Convolutions`
    (https://arxiv.org/abs/1807.03039).

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a
        ``(params, direct_fun, inverse_fun)`` triplet.
    """

    def init_fun(rng, input_dim, **kwargs):
        # Initialize with a random orthogonal matrix and precompute its
        # inverse and log-determinant once.
        W = orthogonal()(rng, (input_dim, input_dim))
        W_inv = linalg.inv(W)
        W_log_det = np.linalg.slogdet(W)[-1]

        def direct_fun(params, inputs, **kwargs):
            outputs = inputs @ W
            log_det_jacobian = np.full(inputs.shape[:1], W_log_det)
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs @ W_inv
            log_det_jacobian = np.full(inputs.shape[:1], -W_log_det)
            return outputs, log_det_jacobian

        return (), direct_fun, inverse_fun

    return init_fun
```
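For reference, a hypothetical usage of the ``(params, direct_fun, inverse_fun)`` triplet pattern (not taken from the repo):

```python
import jax

init_fun = FixedInvertibleLinear()
params, direct_fun, inverse_fun = init_fun(jax.random.PRNGKey(0), input_dim=2)

x = jax.random.normal(jax.random.PRNGKey(1), (5, 2))
y, log_det = direct_fun(params, x)   # forward pass + log-det Jacobian
x_rec, _ = inverse_fun(params, y)    # should recover x
```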

[Algorithm][Parametric] Implement the GDN layer

A basic implementation of the GDN (Generalized Divisive Normalization) algorithm. It is mainly used in compression. It features a normalization that can be coupled with a convolution/linear layer.

y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))


Potential Issues

Somehow we need to constrain a few parameters, e.g., gamma and beta, to be nonnegative. Probably a simple elementwise floor via np.maximum would work, as in the sketch below.
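A minimal sketch of the simplified GDN forward pass in Jax, folding in a naive nonnegativity constraint; the function name and the eps floor are assumptions:

```python
import jax.numpy as jnp


def gdn_forward(params, x, eps=1e-6):
    """Simplified GDN: y[i] = x[i] / sqrt(beta[i] + sum_j gamma[j, i] * x[j]^2).

    x: (batch, dim), beta: (dim,), gamma: (dim, dim).
    """
    beta, gamma = params
    # Naive constraint: floor the parameters so the denominator stays positive.
    beta = jnp.maximum(beta, eps)
    gamma = jnp.maximum(gamma, 0.0)
    # (x**2) @ gamma computes sum_j gamma[j, i] * x[j]^2 for each output i.
    denom = jnp.sqrt(beta + (x ** 2) @ gamma)
    return x / denom
```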

ITM Notebook

A notebook showcasing information theoretic metrics:

  • Probability & Information
  • Total Correlation
  • Entropy
  • Mutual Information
  • KL-Divergence

Note: we really want to focus on speed and simplicity, e.g. showing how one can compute these metrics from scratch as well as through the convenient wrapper functions, to explain the design decisions. A from-scratch univariate example follows.
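For instance, a from-scratch histogram estimate of univariate differential entropy might look like this (a sketch; the function name and binning choices are assumptions):

```python
import jax.numpy as jnp


def univariate_entropy(samples, n_bins=50):
    """Histogram estimate of differential entropy for 1-D samples:
    H(X) ~= -sum_k p_k * log(p_k / width_k)."""
    counts, edges = jnp.histogram(samples, bins=n_bins)
    widths = jnp.diff(edges)
    p = counts / counts.sum()
    # Empty bins contribute nothing (0 * log 0 -> 0).
    return -jnp.sum(jnp.where(p > 0, p * jnp.log(p / widths), 0.0))
```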

[Demo] Info Loss for GaussFlow versus RBIG

It would be nice to see the differences between the information loss/reduction for the GaussFlow algorithm and the IterGauss algorithm.

Outcome: A demo notebook showcasing the info loss between layers for both algorithms.

Deep Dive Demo Notebook

I have notebooks documenting the progression of building an RBIG model, but there is no all-in-one notebook showcasing everything together.

Example components:

  • Marginal Gaussianization
  • Univariate Entropy
  • Information Loss
  • ITMs -> TC, H
