pyro-ppl / funsor

Functional tensors for probabilistic programming

Home Page: https://funsor.pyro.ai

License: Apache License 2.0

Languages: Python 97.15%, Jupyter Notebook 2.65%, Makefile 0.20%
Topics: jax, machine-learning, numpy, probabilistic-programming, pyro, pytorch, symbolic

funsor's People

Contributors

anukaal, eb8680, fehiepsi, fritzo, jpchen, martinjankowiak, mirca, neerajprad, ordabayevy, xhochy


funsor's Issues

Implement Align and .align() funsors

We should resurrect .align() and Align, which were deleted during the v3 refactoring #17.

These are convenience funsors so that e.g. we can define

import math

import funsor
from funsor.domains import reals

@funsor.of_shape(reals(), reals(), reals())
def normal(loc, scale, value):
    return (-((value - loc) ** 2) / (2 * scale ** 2) - scale.log()
            - math.log(math.sqrt(2 * math.pi)))

and ensure its inputs are ordered (loc, scale, value) rather than (value, loc, scale).

Tasks

First PR:

  • Implement a simple Align funsor that re-orders its inputs.
  • Align.eager_*() methods do not need to preserve alignment; they can simply defer to self.arg.eager_*().
  • Make Funsor.align() call Align
  • Use .align() in of_funsor
  • Add Funsor.align() tests to test_terms.py

Second PR:

  • Make Tensor.align() call torch.Tensor.permute() under the hood (see the sketch after this list)
  • Add Tensor.align() tests to test_torch.py
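
For the second PR, here is a minimal sketch of the permutation bookkeeping Tensor.align() might do under the hood; the align_data helper and plain-int sizes are illustrative assumptions, not the actual funsor API:

import torch
from collections import OrderedDict

def align_data(inputs, data, names):
    """Permute `data` so its dims follow `names` instead of the order in `inputs`."""
    old_order = list(inputs)
    perm = [old_order.index(name) for name in names]
    new_inputs = OrderedDict((name, inputs[name]) for name in names)
    return new_inputs, data.permute(*perm)

# Realign a tensor with named dims (value, loc, scale) to (loc, scale, value).
inputs = OrderedDict([("value", 2), ("loc", 3), ("scale", 4)])
new_inputs, new_data = align_data(inputs, torch.randn(2, 3, 4), ("loc", "scale", "value"))
assert new_data.shape == (3, 4, 2)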

Make it easy to write pyro programs that mix PyTorch and Funsor

We plan to start using Funsor for delayed sampling in Pyro. Pyro programs should work with and without delayed sampling (i.e. with any choice of inference). To achieve this mixture of PyTorch and Funsor, we will need some helpers:

  • funsor.ops that dispatch on type
  • to_funsor to convert things to funsors (sketched below)
  • #84 to_data to convert things to non-funsors
  • helpers to dispatch @funsor.torch.function to either PyTorch or Funsor
    (unit test added in #118 )

Triaged:

  • pyro.ground to eliminate delayed variables
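
A minimal sketch of what the to_funsor helper above could look like, using functools.singledispatch; the registrations shown are illustrative assumptions, not funsor's actual implementation:

from collections import OrderedDict
from functools import singledispatch

import torch
import funsor

@singledispatch
def to_funsor(x):
    raise ValueError("cannot convert {} to a Funsor".format(type(x)))

@to_funsor.register(funsor.terms.Funsor)
def _to_funsor_funsor(x):
    return x  # already a funsor

@to_funsor.register(torch.Tensor)
def _to_funsor_tensor(x):
    # a raw tensor with no named inputs becomes a constant funsor.Tensor
    return funsor.Tensor(x, OrderedDict())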

Implement recognition of affine transforms

To perform full exact inference in Gaussian models, we will need some ability to recognize affine transforms. This might rely on autograd, or simply pattern-match +, *, and einsum expressions (an autograd-based check is sketched after the task list below).

Tasks

  • Handle Binary and Variable patterns #116
  • Rewrite Normal(Affine, ...) to an appropriate Gaussian #119
  • Handle Contract and einsum #157
  • Recognize affine funsors including Contraction and Einsum #249
  • Recognize GetitemOp and ReshapeOp as linear transformations #249
  • Rewrite MultivariateNormal(Affine, ...) to an appropriate Gaussian #245
  • Fix funsor.affine.is_affine() to be sound but incomplete #282
  • Rewrite Gaussian(x=affine) as a Gaussian. In three steps: #284, #285, #286
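
As an illustration of the autograd route: a function is affine in an input exactly when its second derivative with respect to that input vanishes. A rough numerical check along these lines (a heuristic sketch only, not the pattern-matching approach funsor.affine ultimately takes):

import torch

def is_affine_in(fn, x):
    # fn is affine in x iff its gradient w.r.t. x is constant (zero Hessian).
    x = x.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(fn(x).sum(), x, create_graph=True)
    if not grad.requires_grad:
        return True   # gradient does not depend on x, so fn is affine in x
    v = torch.randn_like(x)
    (hvp,) = torch.autograd.grad((grad * v).sum(), x)
    return bool((hvp == 0).all())

assert is_affine_in(lambda x: 2 * x + 1, torch.randn(3))     # affine
assert not is_affine_in(lambda x: x ** 2, torch.randn(3))    # quadratic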

Nondeterministic test failure in test_einsum.py

I suspect this is due to nondeterminism somewhere, like use of a dict rather than an OrderedDict.

==== FAILURES ====
____ test_einsum[a,ab,bc,cd->] ____

equation = 'a,ab,bc,cd->'

    @pytest.mark.parametrize('equation', EINSUM_EXAMPLES + XFAIL_EINSUM_EXAMPLES)
    def test_einsum(equation):
        inputs, outputs, operands, sizes = make_example(equation)
        funsor_operands = [
            funsor.Tensor(operand, OrderedDict([(d, funsor.bint(sizes[d])) for d in inp]))
            for inp, operand in zip(inputs, operands)
        ]
        expected = torch.einsum(equation, operands)
        actual = naive_einsum(equation, *funsor_operands)
        assert expected.shape == actual.data.shape
>       assert torch.allclose(expected, actual.data)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x11953f470>(tensor(-0.0174), tensor(-0.0174))
E        +    where <built-in method allclose of type object at 0x11953f470> = torch.allclose
E        +    and   tensor(-0.0174) = Tensor(-0.01743006706237793, OrderedDict(), 'real').data

test/test_einsum.py:76: AssertionError
==== 1 failed, 785 passed, 9 skipped, 16 xfailed in 3.28 seconds ====
make: *** [test] Error 1

Opt_einsum-based optimizer cannot handle Normal variable elimination

One thing blocking the use of the opt_einsum-based optimizer for continuous variable elimination is that Normal currently implements only a limited contraction interface:

  • .reduce(ops.logaddexp, self.value) which just returns 0 = log(1) because it's normalized, and
  • .binary(ops.add, x) where x is ground wrt the current contraction.

As a workaround we're using a tree_engine that performs Birch-style variable elimination on trees, starting from observations (leaves) and moving backwards. While this will work to demo the Kalman filter example, it is not a long-term solution.

Possible fixes

  • try to support more general x in Normal.binary()
  • support direction hinting in the opt_einsum optimizer

Implement a Joint funsor as a normal form / state machine

~~Blocked by #51 ~~

It appears that the simplest way to implement an inference strategy is via a Joint funsor that keeps track of various Tensor, Gaussian, and Delta factors in a canonical, searchable data structure. This Joint object is really a state machine / interpreter disguised as a Funsor. Its denotational semantics are defined by its .eager_subs() method, and its operational semantics are defined by the transition operations we call on Joint objects during inference:

  • += a log prob term with either a lazy or eager value
  • .reduce(ops.logaddexp, -) out variables when they are no longer live
  • .reduce(ops.sample, -) variables when they need a ground value (e.g. for control flow)
  • .reduce(ops.add, -) out plates when they are complete (see log_joint in minipyro.py)

Note that this can be sensibly implemented even before we implement sampling #52

Tasks

  • Implement a Joint funsor with Tensor, Gaussian, and various Delta terms
  • Implement += Tensor transition
  • Implement += Gaussian transition
  • Implement += Delta transition
  • Implement .reduce(ops.logaddexp, -) to eliminate random variables
  • Implement .reduce(ops.sample, -) to sample random variables
  • Implement .reduce(ops.add, -) to eliminate plate variables
  • Implement += Joint to fuse two state machines
  • Add tests
  • Use Joint in examples
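
A plain-tensor analogue of these transition operations, to make the intended semantics concrete (ordinary torch tensors stand in here for the Tensor/Gaussian/Delta factors a Joint would track):

import torch

# Two log-prob factors: a prior over a discrete variable k (size 3) and a
# likelihood over k and a plate dimension i (size 2).
log_prior = torch.randn(3)           # input: k
log_likelihood = torch.randn(3, 2)   # inputs: k, i

# "+=" transitions just accumulate factors; the joint log density is their sum.
joint = log_prior[:, None] + log_likelihood

# ".reduce(ops.logaddexp, k)" eliminates k once it is no longer live.
marginal = torch.logsumexp(joint, dim=0)   # remaining input: i

# ".reduce(ops.add, i)" eliminates the plate i once it is complete.
log_evidence = marginal.sum()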

Add SMC example

We can use pyro.markov to trigger an optional resampling step based on perplexity. This behavior would be handled by an inference handler.

Outline

  • Add storyboard example using pyro.markov to perform SMC
  • Implement SMC handler to interpret pyro.sample statements as multi-particle samples
  • Implement resampling by interpreting pyro.markov statements
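
A rough sketch of the perplexity-triggered resampling step the SMC handler would perform; maybe_resample and the 0.5 threshold are illustrative assumptions, not part of the minipyro API:

import torch

def maybe_resample(log_weights, particles, threshold=0.5):
    # particles: tensor whose leading dim indexes particles.
    # Resample when the normalized perplexity of the weights drops below threshold.
    n = log_weights.numel()
    probs = torch.softmax(log_weights, dim=0)
    perplexity = torch.exp(-(probs * probs.clamp_min(1e-38).log()).sum())
    if perplexity / n < threshold:
        idx = torch.multinomial(probs, n, replacement=True)
        return torch.zeros_like(log_weights), particles[idx]
    return log_weights, particles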

Add HMC and NUTS algorithms / interpretations

Blocked by Pyro HMC refactoring pyro-ppl/pyro#1816

Initially we can simply wrap Pyro's HMC and NUTS implementations. After NumPyro's distributions are closer to torch.distributions and after Funsor's Numpy support improves, we can also implement a NumPyro-based HMC and NUTS.

Tasks

  • Add shim for Pyro's HMC and NUTS
  • Add an examples/nuts.py
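
For the first task, the shim would essentially hand a model off to Pyro's existing samplers. A minimal sketch of the Pyro side being wrapped (a funsor shim would additionally build the model callable from a funsor log-density):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

mcmc = MCMC(NUTS(model), num_samples=100, warmup_steps=100)
mcmc.run(torch.randn(20))
samples = mcmc.get_samples()["loc"]   # (100,) posterior samples of loc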

Implement an Expectation funsor

Expectation combines ideas from TraceEnum_ELBO's Dice factors and Funsor's Integrate and plate:

Expectation((p1, p2, ..., pn),  # log measures
            (c1, c2, ..., cn),  # costs
            sum_vars=frozenset(...),
            prod_vars=frozenset(...))

The optimizer should perform one instance of Tensor Variable Elimination for each cost term (against all p terms).

If prod_vars is empty (i.e. no plates), then

Expectation((p1, p2, ..., pn),
            (c1, c2, ..., cn),
            sum_vars=sum_vars,
            prod_vars=frozenset())  # empty
  = Integrate(p1 + p2 + ... + pn,
              c1 + c2 + ... + cn,
              reduced_vars=sum_vars)
  = Contract(p1.exp() * p2.exp() * ... * pn.exp(),
             c1 + c2 + ... + cn,
             reduced_vars=sum_vars)
  = (p1.exp() * p2.exp() * ... * pn.exp()
     * (c1 + c2 + ... + cn)).reduce(ops.add, sum_vars)

If prod_vars is not empty, we could naively evaluate via TVE:

Expectation(log_probs, costs, sum_vars, prod_vars)
  = sum(sum_product(ops.add, ops.mul,
                    [p.exp() for p in log_probs] + [c],
                    sum_vars, prod_vars)
        for c in costs)

This can be used in the elbo computation in minipyro.
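
A quick numeric check of the plate-free identity above, with a single discrete sum variable and two cost terms (plain torch standing in for the funsor reductions):

import torch

p = torch.log_softmax(torch.randn(4), dim=0)   # normalized log measure over k
c1, c2 = torch.randn(4), torch.randn(4)        # cost terms over k

integrate_form = (p.exp() * (c1 + c2)).sum()                  # Integrate / Contract form
per_cost_form = (p.exp() * c1).sum() + (p.exp() * c2).sum()   # one pass per cost term
assert torch.allclose(integrate_form, per_cost_form)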

Fix delayed semisupervised vae example

Let's try to make examples/ss_vae_delayed.py do actual interesting work.

Outline

  • Make it work using simple eager-sampling inference #95
  • Get the example working with delayed inference (but with quadratic growth)
  • Implement a moment-matching inference helper to combine samples of loc,scale #115 #90
  • Get the example working with delayed inference + moment-matching approximation

Implement Beta-Binomial conjugacy

We should be able to implement Beta-Binomial conjugacy via an operation Beta.binary(ops.add, -).

Outline

  • #120 Implement Beta and Binomial distributions
  • Implement conjugacy relationship in Joint.eager_reduce()
  • Add a Beta-Binomial example in examples
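
For reference, the closed-form update the conjugacy rule would implement: observing k successes out of n Binomial trials turns Beta(a, b) into Beta(a + k, b + n - k), and eliminating the latent leaves behind the Beta-Binomial log marginal. A plain-torch sketch of that arithmetic:

import torch

a, b = torch.tensor(2.0), torch.tensor(3.0)     # Beta prior
n, k = torch.tensor(10.0), torch.tensor(7.0)    # Binomial observation

post_a, post_b = a + k, b + n - k               # conjugate posterior Beta(post_a, post_b)

# Beta-Binomial log marginal: log C(n, k) + log B(a + k, b + n - k) - log B(a, b)
log_marginal = (torch.lgamma(n + 1) - torch.lgamma(k + 1) - torch.lgamma(n - k + 1)
                + torch.lgamma(post_a) + torch.lgamma(post_b) - torch.lgamma(a + b + n)
                - torch.lgamma(a) - torch.lgamma(b) + torch.lgamma(a + b))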

Implement a Contract op as a fused Reduce(Binary or Finitary)

Funsor currently implements variable elimination entirely with sum and product ops, but for efficiency these should be fused into a GEMM-like Contract op for at least binary contractions.

We could support finitary contractions, but binary contractions are the fundamental case: they are GEMM-like ops, inner products in Hilbert spaces, the core tensor contraction of cuTENSOR, and the basic op in an optimized variable elimination schedule.
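
Concretely, the fusion avoids materializing the intermediate product term. A small torch example of the binary case, contracting out a shared variable b:

import torch

f = torch.randn(3, 4)   # factor with inputs a, b
g = torch.randn(4, 5)   # factor with inputs b, c

naive = (f[:, :, None] * g[None, :, :]).sum(1)   # Reduce(add, Binary(mul, f, g), {b})
fused = f @ g                                    # Contract: one GEMM, no 3x4x5 intermediate
assert torch.allclose(naive, fused)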

Add examples implementing various PPL styles via Funsor

@lawmurray suggested adding examples of how to implement different styles of probabilistic programming languages on top of Funsor. We could implement this as a directory examples/ppl-zoo or examples/rosetta or something.

Tasks

(@eb8680 @lawmurray I've tried to populate with representative examples, but I think you both have clearer views of the landscape; feel free to update/add to/remove from this list)

  • a Pyro example (minipyro)
  • a Birch-style example (with delayed sampling, maybe using pyro.ground)
  • a Stan-style example (HMC on a non-normalized log-density function, blocked by #123)
  • a TFP JointDistributionCoroutine-style example wrapping minipyro with yield statements
  • an omega/Edward2-style example (with RandomVariables (and RCD?)) #130
  • an Infer.net-style example using expectation propagation #90 or #115
  • a webPPL-style example (a single Infer operator and basic ADTs and control flow)
  • a Dyna/kProbLog-style example (with a polyvariadic fix-point combinator)
  • a Church-style example
  • a Venture/Gen-style example

Fix optimizer objective function for Gaussian variable elimination

Whereas the space complexity of a joint categorical distribution grows exponentially in the number of dimensions (aka variables), joint Gaussians grow only quadratically. So, for example, the tractability theorem is weaker for Gaussian models than for discrete models. This is relevant for funsor.optimizer because opt_einsum's greedy objective is memory size assuming exponential growth (in compute_size_by_dict).

We should send a fix to opt_einsum to support a more general flop_count() function, possibly supporting non-integer sizes in their shapes. If this change is too intrusive upstream, we might simply fork their path optimizer for use in funsor. However, it would be preferable to keep funsor simple and defer to opt_einsum wherever possible.
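
To see why the exponential size model is the wrong objective here, compare the memory needed to represent a joint over n variables in the two cases (rough counts for illustration):

# Joint over n = 30 variables:
n = 30
categorical_entries = 2 ** n                 # ~1e9 entries for binary-valued variables
gaussian_entries = n + n * (n + 1) // 2      # 495 entries: info vector + symmetric precision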

Add discrete HMM example

A simple discrete eager HMM without plates seems like the simplest end-to-end example we could get working that would exercise lots of funsor machinery at once.

Tasks

  • Write a complete but xfailing example script that we can develop against
  • Figure out what's missing/broken, especially in distributions
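
For orientation, the end-to-end computation such an example would exercise is the forward algorithm. Here is a plain-torch log-space version of what the funsor script would express with Tensor factors and .reduce(ops.logaddexp, ...); the names and sizes are made up:

import torch

def hmm_log_evidence(init_logits, trans_logits, emit_logits, observations):
    # init_logits: (S,), trans_logits: (S, S), emit_logits: (S, O), observations: (T,) long
    init = torch.log_softmax(init_logits, dim=0)
    trans = torch.log_softmax(trans_logits, dim=1)
    emit = torch.log_softmax(emit_logits, dim=1)
    log_alpha = init + emit[:, observations[0]]
    for obs in observations[1:]:
        log_alpha = torch.logsumexp(log_alpha[:, None] + trans, dim=0) + emit[:, obs]
    return torch.logsumexp(log_alpha, dim=0)

obs = torch.tensor([0, 2, 1, 1])
print(hmm_log_evidence(torch.randn(3), torch.randn(3, 3), torch.randn(3, 4), obs))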

Add sensor fusion example

It would be nice to have an example of sensor fusion or multiple target tracking, where we approximate a joint distribution on an unknown number of objects via a multi-Bernoulli process. The idea is to accumulate knowledge by sequentially updating a joint hypothesis density from observations that are frames or collections of single-object observations.

See also pyro.contrib.tracking.

Clean up substitution

#148 introduces several major changes to substitution as part of the implementation of alpha-conversion. These changes should allow us to further simplify a lot of the old substitution logic:

  • #152 Remove eager_subs methods from all terms (e.g. Binary, Joint) that do not introduce fresh variables
  • #152 Remove conditional xfails and extra test runs added temporarily as part of #148
  • #153 Factor out non-fresh substitution logic from remaining eager_subs methods (e.g in Delta) so that terms only have to be responsible for substituting into their fresh variables
  • #153 Make Subs compatible with FUNSOR_DEBUG=1 again
  • #154 Update Joint and associated patterns so that Joint.eager_subs can be removed
  • #155 Make the new substitute function into an interpreter, allowing tail-call optimized substitution via #145

Add funsor.numpy backend

Currently funsor.torch provides a PyTorch backend via Tensor, @function and a few other helpers. It should be straightforward to provide similar adapters for numpy and jax, say as a new module funsor.numpy and a class Array.

Outline

  • #74 support numpy in funsor.ops, e.g. we will need to call np.log(x) rather than x.log() based on type.
  • implement an align_arrays function similar to align_tensors (see the sketch after this list)
  • implement a funsor.numpy.Array class
  • register np.array with to_funsor
  • fork test_torch.py to a new test_numpy.py
  • Add compatibility with the tail-call optimized interpreter from #145 and remove the conditional xfail added to test_numpy.py in #155
  • implement Array.eager_reduce()
  • optionally add a Function wrapper and @funsor.numpy.function decorator
  • refactor (Tensor, Array) to share a superclass, e.g. Tensor -> {TorchTensor, NumpyTensor}
  • refactor Gaussian to use Tensor rather than torch.Tensor internally.
    This requires wrapping many linear algebra operations as functions, e.g. cholesky_solve.
  • Make unary ops work with Array input. (#303)
  • Enable test_tensor.py with FUNSOR_USE_TCO=1
  • Support numpy docstring in funsor.util.getargspec
  • Add numpy backend tests for test_gaussian.
  • Add numpy backend for einsum.
  • Consider disabling some of the overhead caused by normalize.
  • Make funsor.distributions backend-agnostic
  • Add support for jax.numpy.
  • update funsor.minipyro to support numpy backend
  • add inference examples using numpy
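
A sketch of the align_arrays helper mentioned above, analogous to funsor.torch's align_tensors; here inputs map names to plain int sizes for simplicity, whereas funsor's real inputs map to Domain objects:

from collections import OrderedDict
import numpy as np

def align_arrays(*args):
    """Each arg is an (inputs, data) pair; broadcast all data to one shared input order."""
    inputs = OrderedDict()
    for x_inputs, _ in args:
        inputs.update(x_inputs)
    names = list(inputs)
    aligned = []
    for x_inputs, data in args:
        old = list(x_inputs)
        perm = sorted(range(len(old)), key=lambda i: names.index(old[i]))
        data = np.transpose(data, perm)                      # follow the shared order
        shape = tuple(x_inputs.get(k, 1) for k in names)     # singleton dims for missing inputs
        aligned.append(data.reshape(shape))
    return inputs, aligned

inputs, (x, y) = align_arrays((OrderedDict(a=2), np.ones(2)),
                              (OrderedDict(b=3, a=2), np.ones((3, 2))))
print(list(inputs), x.shape, y.shape)   # ['a', 'b'] (2, 1) (2, 3)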

[example] Using funsor for active inference

Early this year, I was interested in using Pyro for active inference (but my TODO list is quite long, so I set it aside). Models based on active inference are quite similar to HMMs. An interesting application of active inference is reinforcement learning, where we keep getting feedback (observations) from the environment and improve the model/guide to best reflect those observations. By setting desired priors on observation nodes, we can train the model/guide to achieve a goal without having to use rewards as in RL. Because it is just an extended version of an HMM, I think funsor is a good tool for it.

Here are some references which I found clearest for understanding the theory behind this:

The third one does inference by message passing on a bipartite graph, which seems closest to funsor. I don't know the current state of funsor, so I'll go step-by-step to understand it better. Maybe I can contribute to funsor a bit along the way. It would be a great help if some of you are also interested in this direction so we can discuss it further. :)

cc @eb8680 @fritzo @neerajprad

Fix Kalman filter example

examples/kalman_filter.py and test/test_engine.py::test_hmm_gaussian_gaussian are currently failing with Normal simplification errors

Implement custom sum-product optimizer

Tensor variable elimination in Pyro depends on opt_einsum and its tensor contraction optimizer. We initially wanted to reuse their optimizer, but the assumptions it makes about the input contraction (particularly the assumption that any pair of terms or intermediate results can be contracted together) are not ideal for rewriting funsor sum-product expressions (#2 ) or may leave exploitable structure on the table.

This issue is for tracking the implementation of a new Funsor-specific sum-product optimizer. Perhaps we can initially adapt tree_engine into an optimizer?

Add backend dispatchers for operations in funsor.ops

It seems to me that the simplest way to support multiple backends in funsor.ops would be to use multiple dispatch as follows:

from functools import singledispatch

import numpy as np
import torch

_builtin_abs = abs

@singledispatch
def abs(x):
    return _builtin_abs(x)

@abs.register(torch.Tensor)
def abs_torch(x):
    return x.abs()

@abs.register(np.ndarray)
def abs_np(x):
    return np.abs(x)

What do you think? Would that approach cause any inadvertent complications?

Add plated einsum example

Add an example implementing plated einsum via funsors. Funsors are arguably a cleaner DSL for expressing these kinds of dynamic programs, and the example could demonstrate both how to implement general plated einsum, and how to express various specific problems using funsor syntax rather than einsum syntax.

Outline

  • add a non-plated einsum example
  • add plated einsum implementation

Implement a convolution operation ops.conv

Convolution is the basic operation in Gaussian process kernel construction; it applies only to real variables.

Tasks

  • add ops.conv
  • register Delta-Delta
  • register Delta-Gaussian
  • register Gaussian-Gaussian

Implement a mixture interpretation

Joint distributions with continuous parts implement only lazy reduction wrt mixture components, as encoded in the reduction rule

|- x                  i |- x + Joint(y) -> Joint(y')
--------------------------------------------------------------------------- eager
|- x + Reduce(logaddexp, Joint(y), {i}) -> Reduce(logaddexp, Joint(y), {i})

This results in exponential growth of the number of mixture components in e.g. switching linear dynamical systems.

We could perform inference by an approximating interpretation that replaces large mixtures with, e.g., mixtures of a fixed number of components:

                  num_elements(S) > num_elements({k})
--------------------------------------------------------------------- mixture(k)
|- Reduce(logaddexp, Joint(y), S) -> Reduce(logaddexp, Joint(y), {k})

Questions

  • What approximation strategy should we use?
    If minimizing KL(p||q) we could defer to expectation propagation #90 (a moment-matching collapse is sketched below).
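
For the simplest case, collapsing a 1-D Gaussian mixture to a single Gaussian by moment matching (the KL(p||q)-minimizing choice within the Gaussian family) looks like this in plain torch; a mixture(k) interpretation with k = 1 could apply exactly this arithmetic:

import torch

log_w = torch.log_softmax(torch.randn(4), dim=0)   # mixture log weights
means = torch.randn(4)
variances = torch.rand(4) + 0.1

w = log_w.exp()
mean = (w * means).sum()
variance = (w * (variances + means ** 2)).sum() - mean ** 2   # law of total variance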

imm derivation

Implement unbiased wrappers for biased interpretations

We should be able to wrap a biased approximate interpretation in an unbiased interpretation using rejection sampling with a biased proposal. For example, we should be able to use moment_matching #115, linearize #89, and EP #90 as proposal distributions.

Kevin Murphy points out:

I think these fast, deterministic locally approximating methods [like moment matching and Gaussian approximation] could be really useful as proposals for MC methods. Here is the paper by Nando that I mentioned (I am sure there are many more papers like this :)

R. van der Merwe, A. Doucet, N. de Freitas, and E. A. Wan, “The Unscented Particle Filter,” in Advances in Neural Information Processing Systems 13, T. K. Leen, T. G. Dietterich, and V. Tresp, Eds. MIT Press, 2001, pp. 584–590 [Online]. Available: http://papers.nips.cc/paper/1818-the-unscented-particle-filter.pdf

Fix funsor.minipyro and add tests

funsor/minipyro.py is currently only a storyboard: it does not function but does serve to demonstrate syntax and computations we'd like to support in funsor.

Tasks

  • add tests for eager sampling
  • add tests for delayed sampling
  • add tests for trace
  • add tests for replay
  • add tests for block
  • add tests for pyro.plate
  • add tests for log_joint
  • add tests for barrier
  • add tests for ground

Implement Monte Carlo interpretation of .logaddexp()

It appears that the natural way to express Monte Carlo sampling in funsor is with a nonstandard interpretation of .logaddexp() reductions. (Note that for numerical stability, we represent densities in log-space; hence expectation is represented by .logaddexp() reductions rather than .sum() reductions).

To perform sampling on factor graphs, we will orthogonally need #50.

Tasks

  • #75 Implement monte carlo .sample() methods
  • Implement a funsor Integrate(log_measure, integrand, reduced_vars)
  • Use Integrate() in test_samplers.py to test Gaussian moments
  • Implement a monte_carlo interpretation that uses samplers in Integrate()
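
A plain-torch analogue of the monte_carlo interpretation of Integrate over a single discrete variable, contrasting it with the exact logaddexp reduction:

import torch

log_measure = torch.log_softmax(torch.randn(5), dim=0)   # normalized log measure over k
integrand = torch.randn(5)

exact = (log_measure.exp() * integrand).sum()             # eager Integrate

num_samples = 100000
ks = torch.multinomial(log_measure.exp(), num_samples, replacement=True)
monte_carlo = integrand[ks].mean()                        # monte_carlo Integrate
print(exact.item(), monte_carlo.item())                   # close for large num_samples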

Support plates and convolution dims in Funsors?

@eb8680 suggested adding metadata to Domain objects or .input dicts to declare which inputs are plate indices. This is similar to @jpchen's suggestion of a Cat funsor, but would also support cat-broadcasting. From a mathematical perspective, adding funsors with different plate indices would be a direct sum.

I believe the best way to implement this is to first implement correct math in funsor.minipyro.log_joint and then to refactor it into Domain or Funsor.

Tasks

  • #56 sketch plate math in funsor.minipyro.log_joint
  • #57 fix minipyro and add tests
  • #109 add plate-aware contract/expectation ops
  • Add a funsor Conv(x, time_dim, curr_dim, next_dim, new_dim) satisfying
    Conv(f, "time", "curr", "next", "curr")(curr=x)
        = f(curr=x(time=arange("time", f.inputs["time"].dtype - 1)),
            next=x(time=arange("time", f.inputs["time"].dtype - 1) + 1))
  • #161 Implement log(time) convolution op following this paper
  • #292 Generalize sequential_sum_product to handle longer time lags
  • Support time dims in partial_sum_product (see convolve-method branch)
  • Add a state space example using Conv applied to a contiguous stacked nn output

Implement alpha conversion logic

Many places in funsor raise NotImplementedError when a bound variable appears free in an expression being substituted. This should be resolved by alpha-converting the binder's expression before substitution (i.e. renaming the bound variable).

@eb8680 suggested handling this uniformly by refactoring e.g. the Funsor base class to know about subterms and bound variables. Probably the easiest way to resolve this issue is to refactor first.

Tasks

  • Refactor to let Funsor base class know about variable binding and subterms
  • Implement alpha conversion once in the Funsor base class
  • Make compatible with funsor.adjoint
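
As a reminder of why renaming is needed, here is a tiny capture example in plain Python, with a closure standing in for a funsor binder; the helper names are made up for illustration:

def reduce_add(term, bound, size):
    # stands in for Reduce(ops.add, term, {bound}): sums term over the bound variable
    return lambda env: sum(term({**env, bound: i}) for i in range(size))

f = lambda env: env["i"] * env["x"]      # inputs: i (to be bound), x (free)
g = reduce_add(f, "i", 3)                # g(x) = sum_i i * x = 3 * x

# Substituting x := i + 1 must not let the free "i" be captured by the binder.
# Alpha-convert the bound variable to a fresh name first, then substitute:
f_renamed = lambda env: env["_i0"] * env["x"]
g_safe = reduce_add(lambda env: f_renamed({**env, "x": env["i"] + 1}), "_i0", 3)
assert g_safe({"i": 4}) == 15            # 3 * (4 + 1); a captured "i" would give 8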

Add a forward-filter-backward-sample example

This example should demonstrate how to use funsor.adjoint to perform exact inference on a forward pass, followed by Monte Carlo on the backward pass. A discrete HMM example would be nice.

Blocked by #164

Memoize doesn't work correctly as a handler?

Memoize was recently made the non-default behavior and moved into handlers.py. However IIUC it cannot correctly be used as a handler:

  1. it would always need to be at the top of the handler stack, i.e. be applied last
  2. if any other handler is applied first, Memoize would not be called when constructing Funsor objects.

Am I missing something? I think we may want to simply move it back to being the default behavior (I cannot imagine a circumstance where we would not want to memoize).
