nuts's Introduction

No-U-Turn Sampler (NUTS) for python

This package implements algorithm 6 (NUTS with dual averaging) from the NUTS paper (Hoffman & Gelman, 2011).

Content

The package mainly contains:

  • nuts.nuts6 returns samples drawn with the NUTS algorithm
  • nuts.numerical_grad returns a numerical estimate of the local gradient
  • emcee_nuts.NUTSSampler an emcee-compatible NUTS sampler, derived from emcee.Sampler
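A numerical gradient is handy when an analytic gradient is tedious to derive. Conceptually, such a helper performs a central-difference estimate; the sketch below is illustrative only (the name num_grad and its dx argument are hypothetical, not the package's exact signature):

```python
import numpy as np

def num_grad(theta, logp, dx=1e-5):
    """Central-difference estimate of the gradient of logp at theta."""
    theta = np.atleast_1d(np.asarray(theta, dtype=float))
    grad = np.empty_like(theta)
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = dx  # perturb one coordinate at a time
        grad[i] = (logp(theta + step) - logp(theta - step)) / (2 * dx)
    return grad
```

Each coordinate costs two likelihood evaluations, which is why supplying an analytic gradient is faster.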

A few words about NUTS

Hamiltonian Monte Carlo or Hybrid Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm that avoids the random walk behavior and the sensitivity to correlated parameters that are the biggest weaknesses of many MCMC methods. Instead, it takes a series of steps informed by first-order gradient information.

This allows it to converge much more quickly on high-dimensional target distributions than simpler methods such as Metropolis or Gibbs sampling (and their derivatives).

However, HMC's performance is highly sensitive to two user-specified parameters: a step size and a desired number of leapfrog steps. In particular, if the number of steps is too small, the algorithm exhibits random walk behavior; if it is too large, it wastes computation.
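Both parameters enter through the leapfrog integrator that HMC uses to simulate the Hamiltonian dynamics. A minimal sketch (the function name and interface are illustrative, not this package's API):

```python
import numpy as np

def leapfrog_path(theta, r, grad_fn, epsilon, n_steps):
    """Simulate Hamiltonian dynamics: n_steps leapfrog steps of size epsilon.

    grad_fn returns the gradient of the log-probability at theta.
    """
    theta, r = theta.copy(), r.copy()
    r += 0.5 * epsilon * grad_fn(theta)      # half step for momentum
    for _ in range(n_steps - 1):
        theta += epsilon * r                 # full step for position
        r += epsilon * grad_fn(theta)        # full step for momentum
    theta += epsilon * r                     # final position step
    r += 0.5 * epsilon * grad_fn(theta)      # final half step for momentum
    return theta, r
```

The product epsilon * n_steps sets how far each proposal travels, which is exactly the quantity NUTS tunes automatically.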

Hoffman & Gelman introduced NUTS, or the No-U-Turn Sampler, an extension to HMC that eliminates the need to set a number of steps. NUTS uses a recursive algorithm to build a set of likely candidate points and automatically stops when the trajectory starts to double back and retrace its steps. Empirically, NUTS performs at least as efficiently as, and sometimes more efficiently than, a well-tuned standard HMC method, without requiring user intervention or costly tuning runs.
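The stopping rule itself is simple: the trajectory ends when extending it further would start shrinking the distance between its two endpoints. With the leftmost/rightmost positions and momenta of the trajectory, the criterion from the paper can be sketched as:

```python
import numpy as np

def no_u_turn(theta_minus, theta_plus, r_minus, r_plus):
    """True while the trajectory endpoints are still moving apart (no U-turn yet)."""
    dtheta = theta_plus - theta_minus
    # stop as soon as either endpoint's momentum points back toward the other end
    return (np.dot(dtheta, r_minus) >= 0) and (np.dot(dtheta, r_plus) >= 0)
```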

Moreover, Hoffman & Gelman derived a method for adapting the step size parameter on the fly based on primal-dual averaging. NUTS can thus be used with no hand-tuning at all.
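The adaptation follows the dual-averaging scheme of the paper; the sketch below uses the constants recommended there (gamma = 0.05, t0 = 10, kappa = 0.75) and is a simplified illustration, not this package's internal code:

```python
import numpy as np

def dual_averaging_step(m, alpha, delta, Hbar, log_epsbar, mu,
                        gamma=0.05, t0=10.0, kappa=0.75):
    """One step-size adaptation step at iteration m.

    alpha is the observed acceptance statistic, delta the target acceptance
    level, and mu a fixed point the step size shrinks toward (log(10 * eps0)
    in the paper). Returns the updated (epsilon, Hbar, log_epsbar).
    """
    eta = 1.0 / (m + t0)
    Hbar = (1 - eta) * Hbar + eta * (delta - alpha)   # running error statistic
    log_eps = mu - np.sqrt(m) / gamma * Hbar          # aggressive current value
    w = m ** (-kappa)
    log_epsbar = w * log_eps + (1 - w) * log_epsbar   # smoothed final value
    return np.exp(log_eps), Hbar, log_epsbar
```

After the adaptation period, sampling continues with the smoothed step size exp(log_epsbar).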

In practice, this implementation still requires a number of steps, a burn-in (adaptation) period, and an initial step size. However, the step size is optimized during the burn-in period, so the user-supplied values only need to be rough starting points; the algorithm revises them as it runs.

Reference: Hoffman, M. D. & Gelman, A. (2011), "The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo", arXiv:1111.4246.


Example Usage

Sampling a highly correlated 2D Gaussian distribution (see nuts.test_nuts6):

  • define a log-likelihood and gradient function:
import numpy as np

def correlated_normal(theta):
    """
    Example of a target distribution that could be sampled from using NUTS.
    (Does not include the normalizing constant.)
    Note: A below is the inverse of
    cov = np.asarray([[1, 1.98],
                      [1.98, 4]])
    """
    # A = np.linalg.inv(cov)
    A = np.asarray([[50.251256, -24.874372],
                    [-24.874372, 12.562814]])

    grad = -np.dot(theta, A)
    logp = 0.5 * np.dot(grad, theta.T)
    return logp, grad
  • set your initial conditions: number of dimensions, number of steps, number of adaptation/burn-in steps, initial guess, and target acceptance level:
D = 2           # number of dimensions
M = 5000        # number of samples to draw
Madapt = 5000   # number of adaptation (burn-in) steps
theta0 = np.random.normal(0, 1, D)  # initial guess
delta = 0.2     # target acceptance level

mean = np.zeros(2)
cov = np.asarray([[1, 1.98], 
                  [1.98, 4]])
  • run the sampling (note that the tqdm module is required for full progress bar functionality):
samples, lnprob, epsilon = nuts6(correlated_normal, M, Madapt, theta0, delta, progress=True)
  • some statistics: expecting mean = (0, 0) and std = (1., 2.), the square roots of the diagonal of cov
samples = samples[1::10, :]  # thin the chain
print('Mean: {}'.format(np.mean(samples, axis=0)))
print('Stddev: {}'.format(np.std(samples, axis=0)))
  • a quick plot:
import matplotlib.pyplot as plt
temp = np.random.multivariate_normal(mean, cov, size=500)
plt.plot(temp[:, 0], temp[:, 1], '.')
plt.plot(samples[:, 0], samples[:, 1], 'r+')
plt.show()

Example usage as an emcee sampler

see emcee_nuts.test_sampler

  • define a log-likelihood function:
def lnprobfn(theta):
    return correlated_normal(theta)[0]
  • define a gradient function (optional; if none is given, numerical estimates are used instead, which is slower):
def gradfn(theta):
    return correlated_normal(theta)[1]
  • set your initial conditions: number of dimensions, number of steps, number of adaptation/burn-in steps, initial guess, and target acceptance level:
D = 2
M = 5000
Madapt = 5000
theta0 = np.random.normal(0, 1, D)
delta = 0.2

mean = np.zeros(2)
cov = np.asarray([[1, 1.98],
                  [1.98, 4]])
  • run the sampling:
sampler = NUTSSampler(D, lnprobfn, gradfn)
samples = sampler.run_mcmc(theta0, M, Madapt, delta)

nuts's People

Contributors

jeremysanders, mfkasim1, mfouesneau


nuts's Issues

Is it mandatory to define a gradient function?

I want to use this package as an emcee sampler, but if possible I'd like to avoid having to define the gradient for the log-likelihood.

I know this is possible with the sampyl package, which uses autograd to do it, but I'm not sure if this package does something similar. The README says about the grad function: "if not numerical estimates are made, but slower", but nothing else.

Can I not define a gradient function? If so, how does NUTS handle this?

Thank you!

Installation

Hi,

I'd like to try this package out and compare it to things like emcee. However, there are no installation instructions. Ideally, a setup.py or a pip-installable package would be great. Would you be able to add this at all?

Best,

Greg

Add license?

I'd really like to use this sampler in a project of mine (https://github.com/msmbuilder/msmbuilder). Would you consider attaching an open source license to this project? Without a license, it's a little ambiguous. Something like MIT would be ideal, but obviously it's your call.

Why do some chains get stuck more than others?

Thank you for your help previously with installation. We now have a viable model up and running thanks to this repo!

One follow up question I had: why might some NUTS chains get stuck (or just slow down sampler progress) more than others? Is there some reason for example, that the gradient function might become harder to evaluate?

A question about a line of code in the function find_reasonable_epsilon

Thank you for publishing this nice code; it is really helpful for beginners like me to learn NUTS in detail.

I got a question about a line of code when I was going through the source code.

k = 1.
while np.isinf(logpprime) or np.isinf(gradprime).any():
    k *= 0.5
    _, rprime, _, logpprime = leapfrog(theta0, r0, grad0, epsilon * k, f)

Apparently, gradprime is never updated inside the loop, which could cause an infinite loop. So I wonder: is this a typo, or have I misunderstood something about this part of the code?
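For what it's worth, the loop does terminate if gradprime is also captured from leapfrog's return value (assuming leapfrog returns (thetaprime, rprime, gradprime, logpprime), as the four-element unpacking above suggests). A self-contained sketch of such a fix, with hypothetical names:

```python
import numpy as np

def shrink_epsilon(theta0, r0, grad0, epsilon, leapfrog, f):
    """Halve the step size until the leapfrog step yields a finite logp AND gradient."""
    _, rprime, gradprime, logpprime = leapfrog(theta0, r0, grad0, epsilon, f)
    k = 1.0
    while np.isinf(logpprime) or np.isinf(gradprime).any():
        k *= 0.5
        # capture gradprime too, so the loop condition can actually change
        _, rprime, gradprime, logpprime = leapfrog(theta0, r0, grad0, epsilon * k, f)
    return epsilon * k
```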

Fix nuts6 docstring

epsilon is not an input anymore, and there is no nuts8. Also, the output is a tuple, not just an array.
