High-dimensional Bayesian inference with Python and JAX.
jax-bayes is designed to accelerate research in high-dimensional Bayesian inference, specifically for deep neural networks. It is built on JAX.
NOTE: the `jax_bayes.mcmc` API was updated on 02/05/2022 to version 0.1.0 and is not backwards compatible with the previous version 0.0.1. The changes are minor, and they fix a significant bug. See this PR for more details.
jax-bayes supports two different methods for sampling from high-dimensional distributions:
- Markov Chain Monte Carlo (MCMC): iterates a Markov chain whose invariant distribution is (approximately) equal to the target distribution.
- Variational Inference (VI): finds the distribution in a parameterized family that is closest (in some sense) to the target distribution (both objectives are sketched in symbols below).
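In symbols (standard definitions, with $\pi$ the target distribution; nothing here is specific to jax-bayes): MCMC constructs a transition kernel $K$ that leaves $\pi$ invariant, while VI minimizes a divergence, typically the KL divergence, over a parameterized family $\{q_\phi\}$:

```math
\pi(\theta') = \int K(\theta' \mid \theta)\, \pi(\theta)\, d\theta
```

```math
\phi^{\star} = \operatorname*{arg\,min}_{\phi}\; \mathrm{KL}\big(q_{\phi}(\theta) \,\|\, \pi(\theta)\big)
```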
jax-bayes allows you to "bring your own JAX-based network to the Bayesian ML party" by providing samplers that operate on arbitrary data structures of JAX arrays using JAX transformations. You can also define your own sampler in terms of JAX arrays and lift it to a general-purpose sampler (using the same approach as jax.experimental.optimizers).
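The "lift a per-array rule to a pytree" idea is the same pattern used by `jax.experimental.optimizers`: write the update for a single array, then map it over the whole parameter tree. Below is a minimal hand-rolled sketch of that pattern; it is purely illustrative and is not the `jax_bayes.mcmc.sampler` decorator itself.

```python
import jax
import jax.numpy as jnp

def array_langevin_step(param, grad, key, lr=1e-3):
    # Langevin-style update for a single JAX array:
    # a gradient step on the log-probability plus Gaussian noise.
    noise = jax.random.normal(key, param.shape)
    return param + lr * grad + jnp.sqrt(2.0 * lr) * noise

def tree_langevin_step(params, grads, key, lr=1e-3):
    # Lift the per-array rule to an arbitrary pytree of parameters:
    # flatten the tree, split the PRNG key per leaf, map the update, unflatten.
    param_leaves, treedef = jax.tree_util.tree_flatten(params)
    grad_leaves = jax.tree_util.tree_leaves(grads)
    keys = jax.random.split(key, len(param_leaves))
    new_leaves = [array_langevin_step(p, g, k, lr)
                  for p, g, k in zip(param_leaves, grad_leaves, keys)]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)
```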
You can easily modify this Haiku quickstart example to support Bayesian inference:
```python
# ---- From the Haiku Quickstart ----
import jax
import jax.numpy as jnp
import haiku as hk

def softmax_cross_entropy(logits, labels):
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def logprob_fn(batch):
    mlp = hk.Sequential([
        hk.Linear(300), jax.nn.relu,
        hk.Linear(100), jax.nn.relu,
        hk.Linear(10),
    ])
    logits = mlp(batch['images'])
    return -jnp.mean(softmax_cross_entropy(logits, batch['labels']))

logprob = hk.transform(logprob_fn)
```
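In Bayesian terms, `logprob_fn` plays the role of an (unnormalized) log-density over the network weights: the negative softmax cross-entropy, averaged over the batch, acts as the log-likelihood term, and this quickstart adds no explicit log-prior. The sampler then (approximately) targets the usual posterior

```math
\log p(\theta \mid \mathcal{D}) = \log p(\mathcal{D} \mid \theta) + \log p(\theta) - \log p(\mathcal{D}),
```

where the evidence term $\log p(\mathcal{D})$ is a constant that samplers never need to evaluate.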
```python
# ---- With jax-bayes ----
from jax_bayes.mcmc import langevin_fns

# instantiate the sampler
key = jax.random.PRNGKey(0)
init, propose, accept, update, get_params = langevin_fns(key, lr=1e-3)

# define the mcmc step
@jax.jit
def mcmc_step(state, keys, batch):
    params = get_params(state)
    batch_logprob = lambda p: logprob.apply(p, None, batch)

    # use vmap + grad to compute per-sample gradients
    g = jax.vmap(jax.grad(batch_logprob))(params)

    # omitting some unused arguments for this example
    propose_state, new_keys = propose(g, state, keys, ...)
    accept_idxs, new_keys = accept(g, state, ..., propose_state, ...)  # not necessary for the Langevin algorithm
    next_state, new_keys = update(accept_idxs, state, propose_state, new_keys, ...)

    return next_state, new_keys

# initialize the sampler state
params = logprob.init(jax.random.PRNGKey(1), next(dataset))
sampler_state, keys = init(params)

# run the mcmc algorithm
for i in range(1000):
    sampler_state, keys = mcmc_step(sampler_state, keys, next(dataset))

# extract your samples
sampled_params = get_params(sampler_state)
```
Sometimes we want our neural networks to say "I don't know" (think self-driving cars, machine translation, etc.) but, as illustrated in this paper or in examples/deep/mnist, the logits of a neural network should not serve as a substitute for uncertainty. This library lets you model uncertainty in the weights, given the data, by sampling from the posterior rather than optimizing it. You can also take advantage of Occam's razor and other benefits of Bayesian statistics.
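For instance, once you have posterior samples of the weights (the `sampled_params` extracted above; assume each leaf carries a leading sample dimension, one entry per chain), a common recipe is to average the per-sample predictive distributions and use the entropy of the average as an "I don't know" score. The sketch below is illustrative only: `predict_fn` and `predictive_uncertainty` are hypothetical helpers, not part of the jax-bayes API.

```python
import jax
import jax.numpy as jnp
import haiku as hk

def predict_fn(batch):
    # Same architecture as logprob_fn above, but returning class probabilities.
    mlp = hk.Sequential([
        hk.Linear(300), jax.nn.relu,
        hk.Linear(100), jax.nn.relu,
        hk.Linear(10),
    ])
    return jax.nn.softmax(mlp(batch['images']))

predict = hk.transform(predict_fn)

def predictive_uncertainty(sampled_params, batch):
    # Evaluate the network under each posterior weight sample (vmap over the
    # leading sample dimension of every leaf in the parameter pytree)...
    probs = jax.vmap(lambda p: predict.apply(p, None, batch))(sampled_params)
    # ...average the predictive distributions across samples...
    mean_probs = probs.mean(axis=0)
    # ...and score uncertainty by the entropy of the averaged distribution.
    entropy = -jnp.sum(mean_probs * jnp.log(mean_probs + 1e-8), axis=-1)
    return mean_probs, entropy
```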
jax-bayes requires `jax>=0.1.74` and `jaxlib>=0.1.15` as separate dependencies, since the jaxlib build must match your accelerator (CPU / GPU / TPU).

Assuming you have jax + jaxlib installed, install via pip:

```bash
pip install git+https://github.com/jamesvuc/jax-bayes
```
`jax_bayes.mcmc` contains the MCMC functionality. It provides:

- `jax_bayes.mcmc.sampler`, the decorator that "tree-ifies" a sampler's methods. A sampler is defined as a callable returning a tuple of functions, where the returned functions have specific signatures:

  ```python
  def sampler(*args, **kwargs):
      ...
      return init, log_proposal, propose, update, get_params
  ```
- A bunch of samplers:
  - `jax_bayes.mcmc.langevin_fns` (Unadjusted Langevin Algorithm; its update rule is sketched after this list)
  - `jax_bayes.mcmc.mala_fns` (Metropolis-Adjusted Langevin Algorithm)
  - `jax_bayes.mcmc.rk_langevin_fns` (stochastic Runge-Kutta solver for the continuous-time Langevin dynamics)
  - `jax_bayes.mcmc.hmc_fns` (Hamiltonian Monte Carlo algorithm)
  - `jax_bayes.mcmc.rms_langevin_fns` (preconditioned Langevin algorithm using the smoothed root-mean-square estimate of the gradient as the preconditioner matrix, as in RMSProp)
  - `jax_bayes.mcmc.rwmh_fns` (Random Walk Metropolis-Hastings Algorithm)
- `jax_bayes.mcmc.bb_mcmc`, which wraps a given sampler into a "black-box" function suitable for sampling from simple densities (e.g. without sampling batches).
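For orientation, the unadjusted Langevin algorithm underlying `langevin_fns` is usually written as the following discretized Langevin update (textbook form with step size $\epsilon$; the library's exact step-size convention may differ):

```math
\theta_{t+1} = \theta_t + \epsilon \, \nabla_{\theta} \log p(\theta_t) + \sqrt{2\epsilon}\, \xi_t,
\qquad \xi_t \sim \mathcal{N}(0, I).
```

MALA (`mala_fns`) adds a Metropolis-Hastings accept/reject step on top of this proposal, which is why the quickstart above notes that the `accept` call is not necessary for the plain Langevin algorithm.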
`jax_bayes.variational` contains the variational inference functionality. It provides:

- `jax_bayes.variational.variational_family`, a decorator that tree-ifies the variational family's methods. A variational family is defined as a callable returning a tuple of functions, where the returned functions have specific signatures. The returned object is not, however, a tree-ified collection of functions but a class that contains these functions:

  ```python
  def variational_family(*args, **kwargs):
      ...
      return init, sample, evaluate, get_samples, next_key, entropy
  ```

- `jax_bayes.variational.diag_mvn_fns` (diagonal multivariate Gaussian family; see the note after this list)
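For reference, the diagonal multivariate Gaussian family is the standard mean-field Gaussian: each weight gets its own mean and variance, and samples can be drawn via the reparameterization trick (this is generic math, not a statement about the exact parameterization used by `diag_mvn_fns`):

```math
q_{\phi}(\theta) = \mathcal{N}\!\big(\theta \mid \mu, \operatorname{diag}(\sigma^{2})\big),
\qquad
\theta = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I).
```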
We have provided some diverse examples, some of which are under active development; see examples/ for more details. At a high level, we provide:

- Shallow examples for sampling from standard probability distributions using MCMC and VI.
- Deep examples for doing deep Bayesian ML (mostly in Colab):
  - Neural Network Regression
  - MNIST with a 300-100-10 MLP
  - CIFAR10 with a CNN
  - Attention-based RNN Neural Machine Translation

Note: if you are familiar with ML and are looking to learn how to use JAX, these examples also include regular (non-Bayesian) ML versions that are relatively self-contained.
*(Example outputs from the mcmc and nn regression examples.)*