pyro-ppl / funsor
Functional tensors for probabilistic programming
Home Page: https://funsor.pyro.ai
License: Apache License 2.0
We should resurrect .align() and Align, which were deleted during the v3 refactoring #17.
These are convenience funsors so that e.g. we can define
import math
import funsor
from funsor.domains import reals

@funsor.of_shape(reals(), reals(), reals())
def normal(loc, scale, value):
    return (-((value - loc) ** 2) / (2 * scale ** 2) - scale.log()
            - math.log(math.sqrt(2 * math.pi)))
and ensure its inputs are ordered (loc, scale, value) rather than (value, loc, scale).
First PR:
- Align to simply re-order its inputs. Align.eager_*() methods do not need to preserve alignment; they can simply defer to self.arg.eager_*().
- Funsor.align() call Align
- .align() in of_funsor
- Funsor.align() tests to test_terms.py

Second PR:
- Tensor.align() call torch.permute() under the hood (see the sketch after this list)
- Tensor.align() tests to test_torch.py
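As a rough sketch of the second PR, aligning a Tensor's named inputs amounts to a torch.permute() of its data. This is plain PyTorch, not the funsor API, and align_data is a hypothetical helper used only for illustration:

import torch

def align_data(data, names, new_names):
    # permute data dims so that named inputs appear in the requested order
    assert set(names) == set(new_names)
    return data.permute(*[names.index(name) for name in new_names])

x = torch.randn(2, 3, 4)  # dims named ("value", "loc", "scale")
y = align_data(x, ["value", "loc", "scale"], ["loc", "scale", "value"])
assert y.shape == (3, 4, 2)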
We plan to start using Funsor for delayed sampling in Pyro. Pyro programs should work with and without delayed sampling (i.e. with any choice of inference). To achieve this mixture of PyTorch and Funsor, we will need some helpers:
- funsor.ops that dispatch on type
- to_funsor to convert things to funsors
- to_data to convert things to non-funsors
- @funsor.torch.function to either PyTorch or Funsor

Triaged:
- pyro.ground to eliminate delayed variables
To perform full exact inference of Gaussian distributions, we will need some ability to recognize affine transforms. This might either rely on autograd (a small autograd-based sketch follows the task list below), or could simply pattern match +, *, and einsum patterns.
- Binary and Variable patterns #116
- Normal(Affine, ...) to an appropriate Gaussian #119
- Contract and einsum #157
- Contraction and Einsum #249
- GetitemOp and ReshapeOp as linear transformations #249
- MultivariateNormal(Affine, ...) to an appropriate Gaussian #245
- funsor.affine.is_affine() to be sound but incomplete #282
- Gaussian(x=affine) as a Gaussian. In three steps: #284, #285, #286
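For the autograd-based approach, here is a minimal sketch (plain PyTorch, not funsor code; is_probably_affine is a hypothetical helper): a map is affine iff its Jacobian is constant, so we can compare Jacobians at two random points and check f(x) = J @ x + f(0).

import torch
from torch.autograd.functional import jacobian

def is_probably_affine(f, dim, atol=1e-6):
    x0, x1 = torch.randn(dim), torch.randn(dim)
    j0, j1 = jacobian(f, x0), jacobian(f, x1)
    if not torch.allclose(j0, j1, atol=atol):
        return False  # the Jacobian varies with the input, so f is not affine
    bias = f(torch.zeros(dim))
    return bool(torch.allclose(f(x1), j0 @ x1 + bias, atol=atol))

assert is_probably_affine(lambda x: 2. * x + 1., 3)  # affine: passes
assert not is_probably_affine(lambda x: x * x, 3)    # quadratic: fails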
I suspect this is due to nondeterminism somewhere, like use of a dict rather than an OrderedDict.
==== FAILURES ====
____ test_einsum[a,ab,bc,cd->] ____
equation = 'a,ab,bc,cd->'
@pytest.mark.parametrize('equation', EINSUM_EXAMPLES + XFAIL_EINSUM_EXAMPLES)
def test_einsum(equation):
inputs, outputs, operands, sizes = make_example(equation)
funsor_operands = [
funsor.Tensor(operand, OrderedDict([(d, funsor.bint(sizes[d])) for d in inp]))
for inp, operand in zip(inputs, operands)
]
expected = torch.einsum(equation, operands)
actual = naive_einsum(equation, *funsor_operands)
assert expected.shape == actual.data.shape
> assert torch.allclose(expected, actual.data)
E AssertionError: assert False
E + where False = <built-in method allclose of type object at 0x11953f470>(tensor(-0.0174), tensor(-0.0174))
E + where <built-in method allclose of type object at 0x11953f470> = torch.allclose
E + and tensor(-0.0174) = Tensor(-0.01743006706237793, OrderedDict(), 'real').data
test/test_einsum.py:76: AssertionError
==== 1 failed, 785 passed, 9 skipped, 16 xfailed in 3.28 seconds ====
make: *** [test] Error 1
Necessary for evaluating deeply nested expressions without hitting the Python interpreter stack limit, like marginal likelihoods for time series models.
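What avoiding the interpreter stack limit usually means in practice is a trampolined evaluation loop; here is a generic sketch (illustrative only, not funsor's interpreter): the recursive evaluator returns thunks instead of recursing, and a driver loop runs them, so evaluation depth no longer consumes Python stack frames.

def trampoline(thunk):
    result = thunk
    while callable(result):
        result = result()
    return result

def count_down(n):  # stands in for evaluating a deeply nested expression
    if n == 0:
        return 0
    return lambda: count_down(n - 1)

assert trampoline(count_down(100000)) == 0  # plain recursion would overflow the stack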
One thing blocking the use of the opt_einsum-based optimizer for continuous variable elimination is that Normal currently implements only a limited contraction interface:
- .reduce(ops.logaddexp, self.value), which just returns 0 = log(1) because it's normalized, and
- .binary(ops.add, x), where x is ground wrt the current contraction.

As a workaround we're using a tree_engine that performs Birch-style variable elimination on trees, starting from observations (leaves) and moving backwards. While this will work to demo the Kalman filter example, it is not a long-term solution.
- x in Normal.binary() ~~Blocked by #51~~
It appears that the simplest way to implement an inference strategy is via a Joint
funsor that keeps track of various Tensor
, Gaussian
, and Delta
factors in a canonical, searchable data structure. This Joint
object is really a state machine / interpreter disguised as a Funsor. Its denotational semantics are defined by its .eager_subs()
method, and its operational semantics are defined by the transition operations we call on Joint
objects during inference:
- += a log prob term with either a lazy or eager value
- .reduce(ops.logaddexp, -) out variables when they are no longer live
- .reduce(ops.sample, -) variables when they need a ground value (e.g. for control flow)
- .reduce(ops.add, -) out plates when they are complete (see log_joint in minipyro.py)

Note that this can be sensibly implemented even before we implement sampling #52 (a toy sketch of these transitions follows the task list below).

- Joint funsor with Tensor, Gaussian, and various Delta terms
- += Tensor transition
- += Gaussian transition
- += Delta transition
- .reduce(ops.logaddexp, -) to eliminate random variables
- .reduce(ops.sample, -) to sample random variables
- .reduce(ops.add, -) to eliminate plate variables
- += Joint to fuse two state machines
- Joint in examples
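As a toy illustration of these transitions (plain PyTorch, purely discrete, not the proposed Joint implementation), with two variables "a" and "b" of sizes 4 and 3:

import torch
torch.manual_seed(0)
log_joint = torch.zeros(1, 1)                     # dims (a, b), both still trivial
log_joint = log_joint + torch.randn(4, 1)         # "+= Tensor": a factor f(a)
log_joint = log_joint + torch.randn(4, 3)         # "+= Tensor": a factor f(a, b)
log_joint = log_joint.logsumexp(0, keepdim=True)  # .reduce(ops.logaddexp, "a"): a is no longer live
log_z = log_joint.logsumexp(1).squeeze()          # eliminate "b" too, leaving the log normalizer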
We can use pyro.markov to trigger an optional resampling step based on perplexity. This behavior would be handled by an inference handler; a small ESS-based resampling sketch follows the task list below.
- pyro.markov to perform SMC
- pyro.sample statements as multi-particle samples
- pyro.markov statements
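Here is a plain-PyTorch sketch of the perplexity/ESS-triggered resampling step (illustrative only; maybe_resample is a hypothetical helper): resample particle indices when the effective sample size of the normalized weights drops below half the number of particles.

import torch

def maybe_resample(log_weights):
    probs = log_weights.log_softmax(0).exp()
    ess = 1.0 / (probs ** 2).sum()  # effective sample size of the weights
    if ess < 0.5 * len(probs):
        idx = torch.multinomial(probs, len(probs), replacement=True)
        return idx, torch.zeros_like(log_weights)  # weights reset after resampling
    return torch.arange(len(probs)), log_weights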
Blocked by Pyro HMC refactoring pyro-ppl/pyro#1816.
Initially we can simply wrap Pyro's HMC and NUTS implementations. After NumPyro's distributions are closer to torch.distributions and after Funsor's Numpy support improves, we can also implement a NumPyro-based HMC and NUTS.
Expectation combines ideas of TraceEnum_ELBO's Dice and Funsor's Integrate and plate:
Expectation((p1, p2, ..., pn),  # log measures
            (c1, c2, ..., cn),  # costs
            sum_vars=frozenset(...),
            prod_vars=frozenset(...))
The optimizer should perform one instance of Tensor Variable Elimination for each cost term (against all p terms).
If prod_vars
is empty (i.e. no plates), then
Expectation((p1, p2, ..., pn),
            (c1, c2, ..., cn),
            sum_vars=sum_vars,
            prod_vars=frozenset())  # empty
  = Integrate(p1 + p2 + ... + pn,
              c1 + c2 + ... + cn,
              reduced_vars=sum_vars)
  = Contract(p1.exp() * p2.exp() * ... * pn.exp(),
             c1 + c2 + ... + cn,
             reduced_vars=sum_vars)
  = (p1.exp() * p2.exp() * ... * pn.exp()
     * (c1 + c2 + ... + cn)).reduce(ops.add, sum_vars)
If prod_vars is not empty, we could naively evaluate via TVE:
Expectation(log_probs, costs, sum_vars, prod_vars)
  = sum(sum_product(ops.add, ops.mul,
                    [p.exp() for p in log_probs] + [c],
                    sum_vars, prod_vars)
        for c in costs)
This can be used in the elbo computation in minipyro.
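As a numeric sanity check of the no-plate identity above (plain PyTorch, not the funsor API; shapes and variable names are illustrative), with two log measures over discrete variables "a" and "b" and two costs:

import torch
torch.manual_seed(0)
p1 = torch.randn(4).log_softmax(0)      # log measure over "a"
p2 = torch.randn(4, 3).log_softmax(-1)  # log measure over "b", conditional on "a"
c1 = torch.randn(4)                     # cost depending on "a"
c2 = torch.randn(3)                     # cost depending on "b"
joint = (p1[:, None] + p2).exp()        # exp(p1 + p2), shape (4, 3)
expectation = (joint * (c1[:, None] + c2)).sum()  # reduce(ops.add, {"a", "b"})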
Let's try to make examples/ss_vae_delayed.py do actual interesting work.
Blocked for now by #51. We may also need a TruncatedNormal distribution.
This issue proposes to add versions of the examples in section 6 from Discrete-Continuous Mixtures in Probabilistic Programming: Generalized Semantics and Inference Algorithms (Wu et al. 2018)
We should be able to implement Beta-Binomial conjugacy via an operation Beta.binary(ops.add, -); a plain-Python sketch of the analytic update follows the task list below.
- Beta and Binomial distributions
- Joint.eager_reduce()
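For reference, the analytic update such a rule would implement looks like this (plain Python sketch; beta_binomial_update is a hypothetical name, not a funsor function):

import math

def beta_binomial_update(a, b, n, k):
    # posterior parameters after observing k successes in n trials
    post_a, post_b = a + k, b + (n - k)
    # log marginal likelihood log p(k | n, a, b), the Beta-Binomial pmf
    log_ml = (math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
              + math.lgamma(a + k) + math.lgamma(b + n - k) - math.lgamma(a + b + n)
              + math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b))
    return post_a, post_b, log_ml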
Funsor currently implements variable elimination entirely with sum and product ops, but for efficiency these should be fused into a GEMM-like Contract op for at least binary contractions.
We could support finitary contractions, but binary contractions are very basic: they are GEMM-like ops, inner products in Hilbert spaces, the basic tensor contraction of cutensor, and the basic op in an optimized variable elimination schedule.
The math is just wrong right now
@lawmurray suggested adding examples of how to implement different styles of probabilistic programming languages on top of Funsor. We could implement this as a directory examples/ppl-zoo or examples/rosetta or something.
(@eb8680 @lawmurray I've tried to populate with representative examples, but I think you both have clearer views of the landscape; feel free to update/add to/remove from this list)
- pyro.ground)
- yield statements
- RandomVariables (and RCD?) #130
- Infer operator and basic ADTs and control flow)

Whereas the space complexity of a joint categorical distribution grows exponentially in the number of dimensions (aka variables), joint Gaussians grow only quadratically. So for example the tractability theorem is weaker for Gaussian models than discrete models. This is relevant for funsor.optimizer
because opt_einsum's greedy objective is memory size assuming exponential growth (in compute_size_by_dict).
We should send a fix to opt_einsum to support a more general flop_count() function, possibly supporting non-integer sizes in their shapes. If this change is too intrusive upstream, we might simply fork their path optimizer for use in funsor. However, it would be preferable to keep funsor simple and defer to opt_einsum wherever possible.
Currently the optimizer handles only non-plated variable elimination, but we'd like to additionally support tensor variable elimination.
@jacobrgardner pointed out today that GPyTorch implements heavily optimized, numerically stable versions of the abstract Gaussian operations added in #37. We should look into using those to implement the funsor.Gaussian
interface.
A simple discrete eager HMM without plates seems like the simplest end-to-end example we could get working that would exercise lots of funsor machinery at once. A plain-PyTorch sketch of the underlying forward algorithm follows the task list below.
Tasks
- distributions
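For reference, the computation such an example would perform is just the forward algorithm; here is a plain-PyTorch sketch with illustrative shapes (not funsor code):

import torch
torch.manual_seed(0)
T, H, O = 10, 3, 4                             # time steps, hidden states, observation values
log_init = torch.randn(H).log_softmax(0)
log_trans = torch.randn(H, H).log_softmax(-1)  # log_trans[prev, next]
log_obs = torch.randn(H, O).log_softmax(-1)    # log_obs[state, observation]
data = torch.randint(0, O, (T,))

log_alpha = log_init + log_obs[:, data[0]]
for t in range(1, T):
    log_alpha = (log_alpha[:, None] + log_trans).logsumexp(0) + log_obs[:, data[t]]
log_marginal = log_alpha.logsumexp(0)          # log p(data)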
It would be nice to have an example of sensor fusion or multiple target tracking, where we approximate a joint distribution on an unknown number of objects via a multi-Bernoulli process. The idea is to accumulate knowledge by sequentially updating a joint hypothesis density from observations that are frames or collections of single-object observations.
See also pyro.contrib.tracking.
The funsor interface currently does not handle multivariate distributions. We plan to completely redesign the basic Funsor
class to allow more complex shapes.
See the v2 design doc for details.
This is the first handler that is parametrized by another handler.
This is required for efficient use of nn.Modules with multiple outputs #79.
See design doc for details.
#148 introduces several major changes to substitution as part of the implementation of alpha-conversion. These changes should allow us to further simplify a lot of the old substitution logic:
- eager_subs methods from all terms (e.g. Binary, Joint) that do not introduce fresh variables
- eager_subs methods (e.g. in Delta) so that terms only have to be responsible for substituting into their fresh variables
- Subs compatible with FUNSOR_DEBUG=1 again
- Joint and associated patterns so that Joint.eager_subs can be removed
- substitute function into an interpreter, allowing tail-call optimized substitution via #145

Currently funsor.torch provides a PyTorch backend via Tensor, @function and a few other helpers. It should be straightforward to provide similar adaptors for numpy and jax, say as a new module funsor.numpy and class Array.
- pyro.ops, e.g. we will need to call np.log(x) rather than x.log() based on type
- align_arrays function similar to align_tensors
- funsor.numpy.Array class
- np.array with to_funsor
- test_torch.py to a new test_numpy.py
- test_numpy.py added in #155
- Array.eager_reduce()
- Function wrapper and @funsor.numpy.function decorator
- (Tensor, Array) to share a superclass, e.g. Tensor -> {TorchTensor, NumpyTensor}
- Gaussian to Tensor rather than torch.Tensor internally
- cholesky_solve
- test_tensor.py with FUNSOR_USE_TCO=1
- funsor.util.getargspec
- test_gaussian
- einsum
- normalize
- funsor.distributions backend-agnostic
- jax.numpy
- funsor.minipyro to support numpy backend

Early this year, I was interested in using Pyro for active inference (but my TODO list is quite long, so I set it aside). Models based on active inference are quite similar to HMMs. An interesting application of active inference is in reinforcement learning, where we keep getting feedback (observations) from the environment and improve the model/guide to make it best reflect these observations. By setting desired priors on observation nodes, we can train the model/guide to achieve a goal without having to use rewards as in RL. Because it is just an extended version of an HMM, I think that funsor is a good tool for it.
Here are some references which I found clearest for understanding the theory behind it:
The third one does inference by message passing on a bipartite graph, which seems closest to funsor. I don't know what the current state of funsor is, so I'll try to go step-by-step to understand funsor more. Maybe I can contribute to funsor a bit along the way. But it would be a great help if some of you are also interested in this direction and we can discuss more about it. :)
examples/kalman_filter.py and test/test_engine.py::test_hmm_gaussian_gaussian are currently failing with Normal simplification errors.
There's a prototype pair-coded by @fritzo and @eb8680 in this branch, but it needs to be updated and documented. We might also want a polyvariadic version to express mutual recursion.
Tensor variable elimination in Pyro depends on opt_einsum
and its tensor contraction optimizer. We initially wanted to reuse their optimizer, but the assumptions it makes about the input contraction (particularly the assumption that any pair of terms or intermediate results can be contracted together) are not ideal for rewriting funsor sum-product expressions (#2) or may leave exploitable structure on the table.
This issue is for tracking the implementation of a new Funsor-specific sum-product optimizer. Perhaps we can initially adapt tree_engine
into an optimizer?
It seems to me that the simplest way to support multiple backends in funsor.ops
would be to use multiple dispatch as follows:
from functools import singledispatch
import numpy as np
import torch

_builtin_abs = abs  # keep a handle on the builtin before shadowing it

@singledispatch
def abs(x):
    return _builtin_abs(x)

@abs.register(torch.Tensor)
def abs_torch(x):
    return x.abs()

@abs.register(np.ndarray)
def abs_np(x):
    return np.abs(x)
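Usage would then look like this, with the same generic op working on plain numbers, torch tensors, and numpy arrays:

print(abs(-3))                       # built-in behavior
print(abs(torch.tensor([-1., 2.])))  # dispatches to abs_torch
print(abs(np.array([-1., 2.])))      # dispatches to abs_np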
What do you think? Will that cause any inadvertent complications?
Add an example implementing plated einsum via funsors. Funsors are arguably a cleaner DSL for expressing these kinds of dynamic programs, and the example could demonstrate both how to implement general plated einsum, and how to express various specific problems using funsor syntax rather than einsum syntax.
Convolution is the basic operation in Gaussian process kernel construction. Convolution only applies to real variables.
- ops.conv
- Delta-Delta
- Delta-Gaussian
- Gaussian-Gaussian
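For reference, two of the closed forms involved in the 1-D density view (illustrative sketch; function names are hypothetical, not funsor ops):

def conv_delta_gaussian(delta_point, g_mean, g_var):
    # convolving with a point mass just shifts the Gaussian
    return delta_point + g_mean, g_var

def conv_gaussian_gaussian(m1, v1, m2, v2):
    # the convolution of two Gaussian densities is Gaussian with summed moments
    return m1 + m2, v1 + v2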
Joint distributions with continuous parts implement only lazy reduction wrt mixture components, as encoded in the reduction rule
|- x i |- x + Joint(y) -> Joint(y')
--------------------------------------------------------------------------- eager
|- x + Reduce(logaddexp, Joint(y), {i}) -> Reduce(logaddexp, Joint(y), {i})
This results in exponential growth of the number of mixture components in e.g. switching linear dynamical systems.
We could perform inference by an approximating interpretation that reduces high-dimensional mixtures with e.g. mixtures of a fixed number of components:
num_elements(S) > num_elements({k})
--------------------------------------------------------------------- mixture(k)
|- Reduce(logaddexp, Joint(y), S) -> Reduce(logaddexp, Joint(y), {k})
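A plain-PyTorch sketch of what such a moment-matched reduction computes in the simplest 1-D Gaussian case (illustrative only, not funsor code): collapse a K-component mixture to a single Gaussian by matching mean and variance.

import torch
torch.manual_seed(0)
log_w = torch.randn(5).log_softmax(0)  # mixture log weights over the eliminated index
means = torch.randn(5)
variances = torch.rand(5) + 0.1
w = log_w.exp()
mean = (w * means).sum()
var = (w * (variances + means ** 2)).sum() - mean ** 2  # law of total variance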
KL(p||q) we could defer to expectation propagation #90.

One approach to approximate inference is expectation propagation. Can we implement EP as an interpretation?
We should be able to wrap a biased approximate interpretation in an unbiased interpretation using rejection sampling with a biased proposal. For example we should be able to use moment_matching #115, linearize #89, and EP #90 as proposal distributions.
Kevin Murphy points out:
I think these fast, deterministic locally approximating methods [like moment matching and Gaussian approximation] could be really useful as proposals for MC methods. Here is the paper by Nando that I mentioned (I am sure there are many more papers like this :)
R. van der Merwe, A. Doucet, N. de Freitas, and E. A. Wan, “The Unscented Particle Filter,” in Advances in Neural Information Processing Systems 13, T. K. Leen, T. G. Dietterich, and V. Tresp, Eds. MIT Press, 2001, pp. 584–590 [Online]. Available: http://papers.nips.cc/paper/1818-the-unscented-particle-filter.pdf
funsor/minipyro.py is currently only a storyboard: it does not function but does serve to demonstrate syntax and computations we'd like to support in funsor.
- trace
- replay
- block
- pyro.plate
- log_joint
- barrier
- ground
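For orientation, here is a minimal sketch of the effect-handler pattern these are built on (plain Python, independent of funsor; names are illustrative):

import random

HANDLER_STACK = []

class trace:
    def __enter__(self):
        HANDLER_STACK.append(self)
        self.trace = {}
        return self.trace
    def __exit__(self, *args):
        HANDLER_STACK.pop()
    def process(self, name, value):
        self.trace[name] = value  # record the sampled value at this site
        return value

def sample(name, fn):
    value = fn()
    for handler in reversed(HANDLER_STACK):
        value = handler.process(name, value)
    return value

with trace() as tr:
    sample("x", random.random)
print(tr)  # {'x': 0.123...}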
The elbo implementation in funsor/minipyro.py is currently broken.
It appears that the natural way to express Monte Carlo sampling in funsor is with a nonstandard interpretation of .logaddexp() reductions. (Note that for numerical stability we represent densities in log space; hence expectation is represented by .logaddexp() reductions rather than .sum() reductions.) A plain-PyTorch sketch of such an estimate follows the task list below.
To perform sampling on factor graphs, we will orthogonally need #50.
- .sample() methods
- Integrate(log_measure, integrand, reduced_vars)
- Integrate() in test_samplers.py to test Gaussian moments
- monte_carlo interpretation that uses samplers in Integrate()
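As a plain-PyTorch sketch of the estimate such an interpretation would compute for a single discrete variable (illustrative only, not funsor code): replace an exact logaddexp reduction with an importance-weighted estimate from a uniform proposal.

import torch
torch.manual_seed(0)
logp = torch.randn(1000)                # unnormalized log measure over "x"
exact = logp.logsumexp(0)               # eager .reduce(ops.logaddexp, "x")
n = 5000
idx = torch.randint(0, 1000, (n,))      # samples from a uniform proposal q
logq = -torch.log(torch.tensor(1000.))  # log q(x) for that proposal
estimate = (logp[idx] - logq).logsumexp(0) - torch.log(torch.tensor(float(n)))
print(exact.item(), estimate.item())    # the two values should be close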
@eb8680 suggested adding metadata to Domain
objects or .input
dicts to declare which inputs are plate indices. This is similar to @jpchen's suggestion of a Cat
funsor, but would also support cat-broadcasting. From a mathematical perspective, adding funsors with different plate indices would be a direct sum.
I believe the best way to implement this is to first implement correct math in funsor.minipyro.log_joint
and then to refactor it into Domain
or Funsor
.
- funsor.minipyro.log_joint
Conv(x, time_dim, curr_dim, next_dim, new_dim) satisfying

Conv(f, "time", "curr", "next", "curr")(curr=x)
  = f(curr=x(time=arange("time", f.inputs["time"].dtype - 1)),
      next=x(time=arange("time", f.inputs["time"].dtype - 1) + 1))

- sequential_sum_product to handle longer time lags
- partial_sum_product (see convolve-method branch)
- Conv applied to a contiguous stacked nn output
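In plain PyTorch the curr/next pairing above amounts to evaluating f on adjacent time slices; here is an illustrative sketch (conv_pairs is a hypothetical helper, not the proposed Conv funsor):

import torch

def conv_pairs(f, x):
    # x has a leading "time" dimension; pair each x[t] with x[t + 1]
    return f(x[:-1], x[1:])

x = torch.randn(10)
out = conv_pairs(lambda curr, next_: (next_ - curr) ** 2, x)  # shape (9,)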
Many places in funsor raise NotImplementedError
when a bound variable appears free in an expression being substituted. This should be resolved by alpha-converting the binder's expression before substitution (i.e. renaming the bound variable).
@eb8680 suggested handling this uniformly by refactoring e.g. the Funsor
base class to know about subterms and bound variables. Probably the easiest way to resolve this issue is to refactor first.
- Funsor base class know about variable binding and subterms
- Funsor base class
- funsor.adjoint
Currently weak simplification logic is baked into constructor methods of various Funsor types, e.g. .binary()
, .__call__()
. We'll want to move these to a handler, either Eager
or Simplify
.
This example should demonstrate how to use funsor.adjoint
to perform exact inference on a forward pass, followed by Monte Carlo on the backward pass. A discrete HMM example would be nice.
Blocked by #164
Memoize was recently made the non-default behavior and moved into handlers.py. However IIUC it cannot correctly be used as a handler: Memoize would not be called when constructing Funsor objects.

Am I missing something? I think we may want to simply move it back to being the default behavior (I cannot imagine a circumstance where we would not want to memoize).