Wide BNN sampling

Code for the paper "Simplicity of the wide Bayesian neural network weight posterior: theory and accelerated sampling" by Jiri Hron, Roman Novak, Jeffrey Pennington, and Jascha Sohl-Dickstein.

The main contribution is a reparametrisation of Bayesian neural network (BNN) posteriors which, combined with Hamiltonian Monte Carlo, enables 10-200x faster mixing than the standard parametrisation. Intriguingly, sampling becomes faster as the BNN gets larger. The reparametrisation is derived using large-width BNN theory (e.g., Matthews et al., Lee et al., Hron et al.), and can be shown to transform the exact BNN weight-space posterior into a distribution whose KL divergence from the multivariate standard normal distribution vanishes in the large-width limit. This is the source of the speed-up at large width, but we have sometimes observed 10x faster mixing even very far from the wide regime (i.e., when the width is much smaller than the dataset size).

The code in this repository provides an efficient way of computing both the reparametrised density and the parameters at the same time. As detailed in the paper, the implementation is based on Cholesky decomposition, and a forward and backward solve akin to the usual implementation of the Cholesky solver.
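The role of the Cholesky factor and the triangular solves can be illustrated in the simplest possible setting. The sketch below is not the repository's implementation (see reparametrisation.py and the paper for that); it assumes a Bayesian linear model with a standard normal prior and Gaussian likelihood, and the names `fit_whitening`, `to_weights`, `log_posterior`, and `noise_var` are hypothetical. In this toy case the reparametrised posterior is exactly standard normal.

```python
import numpy as np


def fit_whitening(X, y, noise_var=0.1):
    """Posterior mean and Cholesky factor of the posterior precision.

    Assumed toy model: y = X @ w + eps, w ~ N(0, I), eps ~ N(0, noise_var * I).
    """
    d = X.shape[1]
    # Posterior precision P = I + X^T X / noise_var, factorised as P = C C^T.
    precision = np.eye(d) + X.T @ X / noise_var
    C = np.linalg.cholesky(precision)
    # Forward and backward solve for the mean: P @ mean = X^T y / noise_var.
    # (np.linalg.solve is a general solver; a triangular solver would be
    # the natural choice in practice.)
    rhs = X.T @ y / noise_var
    mean = np.linalg.solve(C.T, np.linalg.solve(C, rhs))
    return mean, C


def to_weights(phi, mean, C):
    """Map whitened parameters phi ~ N(0, I) to weights w = mean + C^{-T} phi."""
    return mean + np.linalg.solve(C.T, phi)


def log_posterior(w, X, y, noise_var=0.1):
    """Unnormalised log posterior density of the weights."""
    resid = y - X @ w
    return -0.5 * resid @ resid / noise_var - 0.5 * w @ w


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 5))
y = rng.normal(size=20)
mean, C = fit_whitening(X, y)

# In phi-space the log density should equal -0.5 * |phi|^2 up to a constant,
# so differences between two points must match exactly.
phi_a, phi_b = rng.normal(size=5), rng.normal(size=5)
observed = (log_posterior(to_weights(phi_a, mean, C), X, y)
            - log_posterior(to_weights(phi_b, mean, C), X, y))
expected = -0.5 * (phi_a @ phi_a - phi_b @ phi_b)
print(np.isclose(observed, expected))
```

For a BNN the posterior is not Gaussian, so a map of this kind can only bring it close to the standard normal; the paper describes in what sense this holds at large width.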

We rely on JAX, a high-performance machine learning library based on XLA with a simple NumPy/Autograd-like API, and Neural Tangents, a high-level neural network API enabling computation with finite as well as infinite neural networks. See setup.py for other dependencies.

Using the code

The code has several dependencies described in setup.py. To install them automatically, use

git clone https://github.com/google/wide_bnn_sampling
cd wide_bnn_sampling
pip install -e .

A dependency not included is jaxlib, whose installation differs based on the available hardware; please follow the relevant instructions in JAX's repository. To quickly try the code with the CPU backend, you can run pip install jax jaxlib --upgrade.

To set off an experiment, you can modify the provided config.py as needed, and invoke

python3 wide_bnn_sampling/main.py --config wide_bnn_sampling/config.py --store_dir <results-directory>

The high-level structure of main.py's dependencies is described below:

  • config.py: Configuration flags for the dataset, neural network architecture, the sampler, and auxiliary experiment run settings.
  • datasets.py: Loading and preprocessing of data.
  • measurements.py: Logging utilities.
  • models.py: Constructs neural network architectures with Neural Tangents.
  • reparametrisation.py: Efficient implementation of the reparametrisation under the assumption of Gaussian likelihood and prior (details in the paper).
  • samplers.py: Custom implementation of HMC/LMC, and Metropolis-Hastings with a simple Gaussian proposal.
  • utils.py: Auxiliary methods primarily used within main.py.
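For readers unfamiliar with HMC, the following is a minimal NumPy sketch of a single HMC transition with a leapfrog integrator, run on a standard normal target (the distribution the reparametrised posterior approaches at large width). It is a generic textbook version, not the code in samplers.py; `hmc_step` and its arguments are hypothetical names chosen for illustration.

```python
import numpy as np


def hmc_step(position, log_density, grad_log_density, step_size, n_leapfrog, rng):
    """One HMC transition with identity mass matrix; returns (state, accepted)."""
    momentum = rng.normal(size=position.shape)
    q, p = position.copy(), momentum.copy()

    # Leapfrog integration: half momentum step, alternating full steps,
    # and a final half momentum step.
    p = p + 0.5 * step_size * grad_log_density(q)
    for i in range(n_leapfrog):
        q = q + step_size * p
        if i < n_leapfrog - 1:
            p = p + step_size * grad_log_density(q)
    p = p + 0.5 * step_size * grad_log_density(q)

    # Metropolis correction based on the change in the Hamiltonian.
    h_old = -log_density(position) + 0.5 * momentum @ momentum
    h_new = -log_density(q) + 0.5 * p @ p
    if np.log(rng.uniform()) < h_old - h_new:
        return q, True
    return position, False


def log_density(q):
    # Standard normal target (unnormalised).
    return -0.5 * q @ q


def grad_log_density(q):
    return -q


rng = np.random.default_rng(0)
state = np.zeros(3)
samples, n_accepted = [], 0
for _ in range(2000):
    state, accepted = hmc_step(state, log_density, grad_log_density,
                               step_size=0.5, n_leapfrog=10, rng=rng)
    samples.append(state)
    n_accepted += accepted
samples = np.stack(samples)
acceptance_rate = n_accepted / len(samples)
print(acceptance_rate, samples.mean(axis=0))
```

The step size and number of leapfrog steps here are arbitrary choices for the toy target; config.py controls the actual sampler settings used by main.py.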

CAVEAT: Despite using several tricks for improved stability, we observed significant deterioration of acceptance probabilities when computational precision is low. We recommend using at least float32, and float64 where feasible. The relevant JAX flags are jax_enable_x64 (and jax_default_matmul_precision if on TPU).
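As a minimal sketch, the two flags can be set programmatically at the start of a script, before any JAX arrays are created. Whether main.py already sets them, or exposes them through config.py, is not covered here; the value "highest" for the matmul precision is one reasonable choice rather than a documented recommendation of this repository.

```python
import jax

# Enable 64-bit floating point; this should happen before any arrays are created.
jax.config.update("jax_enable_x64", True)
# Request the highest available matmul precision (mainly relevant on TPU).
jax.config.update("jax_default_matmul_precision", "highest")

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)  # float64 once jax_enable_x64 is active
```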

Contributing

See CONTRIBUTING.md for details.

License

Apache 2.0; see LICENSE for details.

Disclaimer

This project is not an official Google project. It is not supported by Google and Google specifically disclaims all warranties as to its quality, merchantability, or fitness for a particular purpose.
