dirmeier / sbijax


Simulation-based inference in JAX

Home Page: https://sbijax.rtfd.io

License: Apache License 2.0

Python 99.82% Makefile 0.18%
abc approximate-bayesian-computation python simulation-based-inference normalizing-flows smc-abc

sbijax

About

Sbijax is a Python library for neural simulation-based inference and approximate Bayesian computation using JAX. In addition, sbijax offers minimal functionality to compute model diagnostics and to visualize posterior distributions.

Concretely, sbijax implements

where the acronyms in parentheses denote the names of the classes in sbijax. It builds on the Python packages Surjectors, Haiku, Distrax and BlackJAX.
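As a rough illustration of the approximate Bayesian computation side, the sketch below implements plain rejection ABC directly in JAX rather than through sbijax's own classes. It assumes the same toy model as the example further down (a standard normal prior on a two-dimensional theta and additive Gaussian noise with scale 0.1) and uses a quantile-style tolerance: instead of a fixed distance threshold, it keeps the n_keep simulations closest to the observation. The function names and the defaults for n_sims and n_keep are hypothetical choices for this illustration.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def simulator(key, theta):
    # Toy model (assumed): y = theta + N(0, 0.1^2) noise, per dimension.
    return theta + 0.1 * jr.normal(key, theta.shape)


def rejection_abc(key, y_obs, n_sims=100_000, n_keep=500):
    """Keep the n_keep prior draws whose simulations lie closest to y_obs."""
    prior_key, sim_key = jr.split(key)
    # Draw parameters from a standard normal prior, one row per simulation.
    thetas = jr.normal(prior_key, (n_sims, y_obs.shape[0]))
    # Simulate one data set per parameter draw, vectorised with vmap.
    sim_keys = jr.split(sim_key, n_sims)
    ys = jax.vmap(simulator)(sim_keys, thetas)
    # Euclidean distance between each simulated data set and the observation.
    dists = jnp.linalg.norm(ys - y_obs, axis=-1)
    # The n_keep smallest distances define the accepted samples.
    idx = jnp.argsort(dists)[:n_keep]
    return thetas[idx]


y_observed = jnp.array([-1.0, 1.0])
samples = rejection_abc(jr.PRNGKey(0), y_observed)
print(samples.shape)         # (500, 2)
print(samples.mean(axis=0))  # should be close to y_observed
```

Keeping a fixed number of nearest simulations avoids having to hand-tune a distance threshold; sequential schemes such as SMC-ABC refine this idea by shrinking the tolerance over several rounds.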

Caution

โš ๏ธ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

Examples

Sbijax implements a slim object-oriented API with functional elements stemming from JAX. All a user needs to define is a prior model, a simulator function and an inferential algorithm. For example, you can define a neural likelihood estimation method and generate posterior samples like this:

from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd

# Prior: a two-dimensional standard normal over the parameter "theta"
def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior

# Simulator: adds Gaussian noise with scale 0.1 to theta
def simulator_fn(seed, theta):
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y

# Neural likelihood estimation with a masked autoregressive flow (MAF) for 2-dimensional data
fns = prior_fn, simulator_fn
model = NLE(fns, make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
# Simulate training data, fit the flow, then sample the posterior given y_observed
data, _ = model.simulate_data(jr.PRNGKey(1))
params, _ = model.fit(jr.PRNGKey(2), data=data)
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
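To show what the three calls at the end of the example do conceptually, here is a simplified, self-contained version of the neural likelihood estimation pipeline in plain JAX. It is not sbijax's implementation: as an assumption for brevity, the masked autoregressive flow is swapped for a linear-Gaussian conditional density fitted by least squares, and posterior sampling uses a hand-written random-walk Metropolis sampler. The overall shape is the same: simulate (theta, y) pairs, fit a surrogate likelihood q(y | theta), then run MCMC on log p(theta) + log q(y_observed | theta). All function names here are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm


def simulate(key, n):
    # Toy model from the example: theta ~ N(0, I_2), y = theta + N(0, 0.1^2 I_2).
    k_theta, k_noise = jr.split(key)
    theta = jr.normal(k_theta, (n, 2))
    y = theta + 0.1 * jr.normal(k_noise, (n, 2))
    return theta, y


def fit_linear_gaussian(theta, y):
    # Surrogate likelihood q(y | theta) = N(y; [theta, 1] @ W, diag(sigma^2)),
    # fitted by least squares. This stands in for the MAF used in the example.
    X = jnp.hstack([theta, jnp.ones((theta.shape[0], 1))])
    W, *_ = jnp.linalg.lstsq(X, y)
    sigma = (y - X @ W).std(axis=0)
    return W, sigma


def log_posterior(theta, W, sigma, y_obs):
    # Unnormalised log posterior: standard normal log prior + surrogate log likelihood.
    log_prior = norm.logpdf(theta).sum()
    mean = jnp.append(theta, 1.0) @ W
    log_lik = norm.logpdf(y_obs, mean, sigma).sum()
    return log_prior + log_lik


def random_walk_metropolis(key, logp, init, n_steps=5_000, step=0.05):
    # Gaussian random-walk proposals, accepted with probability min(1, p(prop) / p(theta)).
    def body(carry, step_key):
        theta, lp = carry
        k_prop, k_acc = jr.split(step_key)
        prop = theta + step * jr.normal(k_prop, theta.shape)
        lp_prop = logp(prop)
        accept = jnp.log(jr.uniform(k_acc)) < lp_prop - lp
        theta = jnp.where(accept, prop, theta)
        lp = jnp.where(accept, lp_prop, lp)
        return (theta, lp), theta

    keys = jr.split(key, n_steps)
    _, chain = jax.lax.scan(body, (init, logp(init)), keys)
    return chain


y_observed = jnp.array([-1.0, 1.0])
theta, y = simulate(jr.PRNGKey(1), 10_000)  # 1. simulate training data
W, sigma = fit_linear_gaussian(theta, y)     # 2. fit the surrogate likelihood
chain = random_walk_metropolis(              # 3. sample the posterior by MCMC
    jr.PRNGKey(3), lambda t: log_posterior(t, W, sigma, y_observed), jnp.zeros(2)
)
posterior = chain[1_000:]  # discard the first 1,000 steps as burn-in
print(posterior.mean(axis=0))
```

Because this toy model really is linear-Gaussian, the surrogate recovers it almost exactly and the posterior mean should land near 0.99 · y_observed. For non-linear simulators a flexible conditional density estimator such as the MAF is what makes the likelihood estimate usable.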

More self-contained examples can be found in the examples folder of the repository.

Documentation

Documentation can be found at https://sbijax.rtfd.io.

Installation

Make sure to have a working JAX installation. Depending on whether you want to use a CPU, GPU or TPU, please follow the corresponding instructions in the JAX installation guide.

To install from PyPI, just call the following on the command line:

pip install sbijax

To install the latest GitHub <RELEASE>, use:

pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
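After installing, a quick way to check that JAX itself works, and which backend it picked up, is a snippet like the following. It uses only core JAX functions; the function name f is an arbitrary placeholder.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can see; a CPU-only install shows a single CPU device.
print(jax.devices())
print(jax.default_backend())  # e.g. "cpu" on a CPU-only install


# A tiny jit-compiled computation to confirm that XLA compilation works.
@jax.jit
def f(x):
    return jnp.sum(x ** 2)


print(f(jnp.arange(3.0)))  # 5.0
```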

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".

In order to contribute:

  1. Clone sbijax and install hatch via pip install hatch,
  2. create a new branch locally git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
  3. implement your contribution and ideally a test case,
  4. test it by calling make tests, make lints and make format on the (Unix) command line,
  5. submit a PR 🙂

Acknowledgements

Note

๐Ÿ“ The API of the package is heavily inspired by the excellent Pytorch-based sbi package.

Author

Simon Dirmeier sfyrbnd @ pm me
