One benefit of Rust's trait and generics system is that powerful (probabilistic) language extensions can be defined in the source language as procedural macros, enabling systems-level probabilistic architecture at low abstraction cost. Ferric, for instance, contains a successful implementation of Rust blocks with sampling, conditioning, and query statements.
It should be feasible to generalize this motif to include a domain-specific language for generative functions using the syn crate and procedural macros. A first pass would omit non-sampling statements and leverage the ChoiceHashMap as the ChoiceBuffer representation.
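To make the choice-buffer idea concrete, here is a minimal sketch that models a ChoiceHashMap as a thin wrapper over the standard library HashMap, keyed by address strings. Only set_value appears in this proposal; new, get_value, and has_value are assumed names for illustration, and the actual gen-rs type may be shaped differently.

```rust
use std::collections::HashMap;
use std::rc::Rc;

/// Hypothetical stand-in for a choice buffer: addresses map to shared values.
pub struct ChoiceHashMap<V> {
    values: HashMap<String, Rc<V>>,
}

impl<V> ChoiceHashMap<V> {
    /// Create an empty buffer.
    pub fn new() -> Self {
        ChoiceHashMap { values: HashMap::new() }
    }

    /// Record (or overwrite) the value stored at `addr`.
    pub fn set_value(&mut self, addr: &str, value: &Rc<V>) {
        self.values.insert(addr.to_string(), Rc::clone(value));
    }

    /// Look up the value at `addr`, if one has been recorded.
    pub fn get_value(&self, addr: &str) -> Option<Rc<V>> {
        self.values.get(addr).cloned()
    }

    /// Report whether `addr` holds a value.
    pub fn has_value(&self, addr: &str) -> bool {
        self.values.contains_key(addr)
    }
}
```

Storing Rc<V> rather than V keeps copies of a buffer cheap, which matters if constraints and traces end up sharing the same values.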
Proposed Syntax
First-order Generative Functions
A first-order generative function takes any number of non-generative structs, defined by the user or in external crates, as arguments.
use types_2d::*;
use gen_rs_macros::gen;
gen!{fn pointed_model(bounds: Bounds, obs_cov: DMatrix<f64>) -> Point {
    let latent ~ uniform_2d(&bounds);
    let obs = {"observation"} ~ mvnormal(&latent, &obs_cov);
    obs
}};
Generated Rust
TODO
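Until the expansion is specified, one possible shape can be sketched by hand. The sketch below is an assumption, not the planned output of gen!. It reduces the model to one dimension to avoid nalgebra, and every name in it (Dist, Lcg, Uniform, Ctx, sample_at) is hypothetical. Each ~ statement becomes a call that either reuses a constrained value and accumulates its log density, or samples a fresh value. An address defaults to the variable name unless one is given in braces.

```rust
use std::collections::HashMap;

/// Hypothetical distribution interface (not the gen-rs trait).
trait Dist {
    fn sample(&self, rng: &mut Lcg) -> f64;
    fn logpdf(&self, x: f64) -> f64;
}

/// Tiny deterministic generator so the sketch needs no external crates.
struct Lcg(u64);

impl Lcg {
    /// Advance the state and return a value in [0, 1).
    fn next_f64(&mut self) -> f64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (self.0 >> 11) as f64 / (1u64 << 53) as f64
    }
}

struct Uniform {
    lo: f64,
    hi: f64,
}

impl Dist for Uniform {
    fn sample(&self, rng: &mut Lcg) -> f64 {
        self.lo + (self.hi - self.lo) * rng.next_f64()
    }

    fn logpdf(&self, x: f64) -> f64 {
        if x >= self.lo && x <= self.hi {
            -(self.hi - self.lo).ln()
        } else {
            f64::NEG_INFINITY
        }
    }
}

/// Execution context: records every choice and scores the constrained ones.
struct Ctx {
    rng: Lcg,
    constraints: HashMap<String, f64>,
    choices: HashMap<String, f64>,
    log_weight: f64,
}

impl Ctx {
    /// What `let x = {"addr"} ~ dist` could expand to.
    fn sample_at(&mut self, addr: &str, dist: &dyn Dist) -> f64 {
        let x = match self.constraints.get(addr).copied() {
            // Constrained address: reuse the value and add its log density.
            Some(v) => {
                self.log_weight += dist.logpdf(v);
                v
            }
            // Unconstrained address: draw a fresh value.
            None => dist.sample(&mut self.rng),
        };
        self.choices.insert(addr.to_string(), x);
        x
    }
}

// Hand expansion of a 1-D analogue of pointed_model:
//   let latent ~ uniform(lo, hi);
//   let obs = {"observation"} ~ uniform(latent - 1., latent + 1.);
//   obs
fn model(ctx: &mut Ctx, lo: f64, hi: f64) -> f64 {
    let latent = ctx.sample_at("latent", &Uniform { lo, hi });
    let obs = ctx.sample_at("observation", &Uniform { lo: latent - 1.0, hi: latent + 1.0 });
    obs
}
```

Threading an explicit context through the function body is only one option; the macro could just as well emit a struct implementing a generative-function trait.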
Higher-order Generative Functions
A generative function that takes a trace &U (where U: Trace) as its first argument is defined as a higher-order generative function (HOGF). HOGFs in gen-rs are compatible with proposal-based MCMC. The reference operator can be implicitly translated into the Weak<U> type with inline (memory-safe) Weak::upgrade semantics. This prevents reference cycles while allowing HOGF proposals to generate traces conditioned on the values of other traces, at the cost of potential runtime panics if the observed trace is dropped while proposing conditioned inference moves.
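The Weak-handle mechanism can be illustrated with the standard library alone. In this minimal sketch, ToyTrace and DriftProposal are hypothetical types, not part of gen-rs. The proposal holds a non-owning handle to a trace and reads it through Weak::upgrade, which returns None once the trace has been dropped.

```rust
use std::rc::{Rc, Weak};

/// Hypothetical minimal trace holding a single latent value.
struct ToyTrace {
    latent: f64,
}

/// A proposal that observes another trace without owning it.
struct DriftProposal {
    observed: Weak<ToyTrace>,
}

impl DriftProposal {
    /// Read the observed latent, or None if the trace has been dropped.
    fn center(&self) -> Option<f64> {
        self.observed.upgrade().map(|t| t.latent)
    }
}
```

A macro that unwraps the result of upgrade inline would turn the None case into the runtime panic described above; returning Option, as here, pushes that decision to the caller.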
gen!{fn drift_proposal<U: Trace>(trace: &U, a_param: f64, b_param: f64) {
    let noise ~ beta(a_param, b_param);
    let latent ~ mvnormal(&trace["latent"], &DMatrix::from_diagonal(&dvector![noise, noise]));
}};
This is pseudo-syntax; a trace type like DynamicTrace would need to be defined for composability.
Example Inference
use std::rc::Rc;

let bounds = Bounds { xmin: -5., xmax: 5., ymin: -5., ymax: 5. };
let obs_cov = dmatrix![1., -3./5.; -3./5., 2.];
let obs = dvector![0., 0.];
// Constrain the address that pointed_model assigns to its observation.
let mut constraints = ChoiceHashMap::<Point>::new();
constraints.set_value("observation", &Rc::new(obs));
let mut trace = pointed_model.generate((bounds, obs_cov), constraints);
let a = 2.;
let b = 5.;
for _ in 0..1000 {
    // The acceptance flag is unused in this example.
    let (new_trace, _accepted) = gen_rs::mh(&pointed_model, trace, &drift_proposal, (a, b));
    trace = new_trace;
}
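The accept/reject decision inside a Metropolis-Hastings move like the one above can be sketched as a small helper. The name mh_step is hypothetical and is not the gen_rs::mh implementation. log_alpha stands for the log acceptance ratio, which a real kernel would presumably compute from the model and proposal scores, and u is a uniform draw in [0, 1) supplied by the caller so that the function stays deterministic.

```rust
/// Hypothetical helper: choose between the current and proposed states.
/// Accepts with probability min(1, exp(log_alpha)), given `u` uniform in [0, 1).
fn mh_step<T>(current: T, proposed: T, log_alpha: f64, u: f64) -> (T, bool) {
    if u.ln() < log_alpha {
        (proposed, true)
    } else {
        (current, false)
    }
}
```

Comparing in log space avoids overflow in exp(log_alpha), and a proposal with log_alpha equal to negative infinity is always rejected.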
Notes
- ThreadRng should be initialized uniquely for each generative function.
- union types can be used to support polymorphic choice values, but may require explicit type annotations (depending on how fancy pattern matching can be with Distribution's generic types).
- It's worth thinking carefully about how much functionality could/should be encapsulated in an AST representation, versus inlined directly via source-to-source translation. Opting for a more verbose inline syntax leverages the common GFI and supports further optimization passes, at the cost of increased procedural macro complexity.