Simulation-based inference in JAX
Sbijax is a Python library for neural simulation-based inference and approximate Bayesian computation using JAX. In addition, sbijax offers minimal functionality to compute model diagnostics and to visualize posterior distributions.
Concretely, sbijax implements

- Sequential Monte Carlo ABC (SMCABC),
- Neural Likelihood Estimation (SNL),
- Surjective Neural Likelihood Estimation (SSNL),
- Neural Posterior Estimation C (SNP),
- Contrastive Neural Ratio Estimation (SNR),
- Neural Approximate Sufficient Statistics (SNASS),
- Neural Approximate Slice Sufficient Statistics (SNASSS),
- Flow matching posterior estimation (SFMPE),
- Consistency model posterior estimation (SCMPE),

where the acronyms in parentheses denote the names of the classes in sbijax. It builds on the Python packages Surjectors, Haiku, Distrax and BlackJAX.
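To make the approximate Bayesian computation (ABC) idea concrete, here is a minimal sketch of plain rejection ABC in JAX. It assumes a toy model with a standard-normal prior on a two-dimensional parameter and Gaussian observation noise with scale 0.1. This is not sbijax's SMCABC implementation: SMC-ABC refines its tolerance over several rounds using importance weights, whereas this sketch accepts prior draws in a single pass. The helper name `rejection_abc`, the tolerance `eps` and the Euclidean distance are illustrative assumptions, not part of the sbijax API.

```python
import jax
from jax import numpy as jnp, random as jr


def simulate(key, theta):
    # Toy simulator (assumption): observation = parameter + Gaussian noise, scale 0.1
    return theta + 0.1 * jr.normal(key, theta.shape)


def rejection_abc(key, y_observed, n_draws=100_000, eps=0.1):
    """Hypothetical helper: single-pass rejection ABC (not part of sbijax)."""
    prior_key, sim_key = jr.split(key)
    # Draw parameters from a standard-normal prior over two dimensions
    thetas = jr.normal(prior_key, (n_draws, 2))
    # Simulate one data set per parameter draw, vectorized with vmap
    sim_keys = jr.split(sim_key, n_draws)
    ys = jax.vmap(simulate)(sim_keys, thetas)
    # Keep the draws whose simulated data lie within eps of the observation
    distances = jnp.linalg.norm(ys - y_observed, axis=-1)
    return thetas[distances < eps]


y_observed = jnp.array([-1.0, 1.0])
accepted = rejection_abc(jr.PRNGKey(0), y_observed)
print(accepted.shape[0], accepted.mean(axis=0))
```

Shrinking `eps` brings the accepted draws closer to the exact posterior but lowers the acceptance rate. Sequential schemes such as SMC-ABC are designed to manage that trade-off.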
Sbijax implements a slim object-oriented API with functional elements stemming from JAX. All a user needs to define is a prior model, a simulator function and an inferential algorithm.
For example, you can define a neural likelihood estimation method and generate posterior samples like this:
```python
from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd


def prior_fn():
    # Standard-normal prior over a two-dimensional parameter "theta"
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior


def simulator_fn(seed, theta):
    # Observations are the parameter plus Gaussian noise with scale 0.1
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y


fns = prior_fn, simulator_fn
# Neural likelihood estimation with a masked autoregressive flow as density estimator
model = NLE(fns, make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
# Simulate training data, fit the likelihood model, then sample from the posterior
data, _ = model.simulate_data(jr.PRNGKey(1))
params, _ = model.fit(jr.PRNGKey(2), data=data)
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
```
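Because this toy model is conjugate (prior theta ~ N(0, I), likelihood y | theta ~ N(theta, 0.1² I)), its exact posterior is available in closed form and gives a reference against which the NLE samples can be checked. Per dimension, the posterior precision is 1 + 1/0.1² = 101, the posterior mean is 100 y / 101 and the posterior standard deviation is 1/sqrt(101). The helper `gaussian_posterior` below is a hypothetical illustration, not part of sbijax.

```python
from jax import numpy as jnp


def gaussian_posterior(y_observed, prior_scale=1.0, noise_scale=0.1):
    """Hypothetical helper: exact posterior for theta ~ N(0, prior_scale^2 I)
    and y | theta ~ N(theta, noise_scale^2 I). Returns per-dimension mean and sd."""
    prior_precision = 1.0 / prior_scale**2
    noise_precision = 1.0 / noise_scale**2
    post_precision = prior_precision + noise_precision
    # The prior mean is zero, so only the data term contributes to the mean
    post_mean = noise_precision * y_observed / post_precision
    post_sd = jnp.sqrt(1.0 / post_precision) * jnp.ones_like(y_observed)
    return post_mean, post_sd


y_observed = jnp.array([-1.0, 1.0])
mean, sd = gaussian_posterior(y_observed)
print(mean, sd)  # mean ≈ [-0.990, 0.990], sd ≈ [0.0995, 0.0995]
```

If the NLE fit is adequate, the posterior draws for each coordinate should be centred near these means, with spreads close to these standard deviations.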
More self-contained examples can be found in the examples folder of the repository. Documentation is available online.
Make sure to have a working JAX installation. Depending on whether you want to use CPU, GPU or TPU, please follow the official JAX installation instructions.
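One quick way to confirm that JAX is installed, and to see which platform it picked up, is to query the default backend and the visible devices. Both `jax.default_backend()` and `jax.devices()` are part of JAX's public API. The exact backend string depends on your installation.

```python
import jax

# Platform JAX uses by default, e.g. "cpu", "gpu" or "tpu"
backend = jax.default_backend()
# All devices visible to JAX on that platform
devices = jax.devices()
print(backend, devices)
```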
To install from PyPI, just call the following on the command line:
pip install sbijax
To install the latest GitHub release, replace <RELEASE> with its tag and use:
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".

In order to contribute:

- clone sbijax and install hatch via pip install hatch,
- create a new branch locally via git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
- implement your contribution and ideally a test case,
- test it by calling make tests, make lints and make format on the (Unix) command line,
- submit a PR.
Note: The API of the package is heavily inspired by the excellent PyTorch-based sbi package.
Simon Dirmeier (sfyrbnd @ pm me)