lindermanlab / ssm-jax
Bayesian learning and inference for state space models (SSMs) using Google Research's JAX as a backend
License: MIT License
Implement a generic SMC algorithm to sample the latent states of a generic SSM. @andrewwarrington already has a working prototype; it just needs to be polished and rolled into main.
See Section 23.5 of Machine Learning: A Probabilistic Perspective (Murphy, 2012).
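A minimal bootstrap-filter sketch of what the generic algorithm computes (this is not the referenced prototype; transition_sample and emission_log_prob are assumed hooks on a generic SSM, and resampling is plain multinomial):

import jax
import jax.numpy as np
import jax.random as jr

def smc(key, data, init_particles, transition_sample, emission_log_prob):
    num_particles = init_particles.shape[0]
    def step(carry, y):
        key, particles = carry
        key, k_prop, k_res = jr.split(key, 3)
        # Propagate each particle through the dynamics.
        particles = transition_sample(k_prop, particles)
        # Weight by the emission likelihood, then resample.
        log_w = emission_log_prob(particles, y)
        idx = jr.categorical(k_res, log_w, shape=(num_particles,))
        log_ml = jax.scipy.special.logsumexp(log_w) - np.log(num_particles)
        return (key, particles[idx]), log_ml
    (_, particles), log_mls = jax.lax.scan(step, (key, init_particles), data)
    return particles, np.sum(log_mls)  # final particles and a log marginal likelihood estimate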
SSM's hidden Markov model (HMM) objects expose a function to compute the marginal likelihood of the data, summing over the discrete latent states. This function can be automatically differentiated with jax.grad. Use Tensorflow Probability's Hamiltonian Monte Carlo (HMC) functionality to perform Bayesian inference over HMM parameters, using the marginal likelihood and a prior on parameter values.
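A hedged sketch of that workflow, treating the unconstrained HMM parameters as the HMC state; params_to_hmm, prior_log_prob, init_params, and data are hypothetical stand-ins (see the unconstrained-parameter getters/setters in #32), and marginal_likelihood is the function described above:

import jax.random as jr
from tensorflow_probability.substrates import jax as tfp

def target_log_prob(unconstrained_params):
    hmm = params_to_hmm(unconstrained_params)     # hypothetical helper
    return prior_log_prob(unconstrained_params) + hmm.marginal_likelihood(data)

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob,
    step_size=1e-3,
    num_leapfrog_steps=10)

samples, accepted = tfp.mcmc.sample_chain(
    num_results=1000,
    num_burnin_steps=500,
    current_state=init_params,                    # hypothetical initial point
    kernel=kernel,
    trace_fn=lambda _, pkr: pkr.is_accepted,
    seed=jr.PRNGKey(0))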
In https://github.com/lindermanlab/ssm-jax-refactor/blob/main/ssm/lds/models.py#L297, the docstring should say "Defaults to a random matrix".
I noticed that this project has been superseded by dynamax. However, I cannot find any place where dynamax incorporates SLDS or rSLDS. Will they be included in a future release, or am I overlooking something?
I am also trying to fit my neural data with the rSLDS framework. As mentioned in issue #163, is it possible to get the variance explained by each latent, so that a reasonable number of states and dimensions could be specified? Thank you so much in advance!
Hey @schlagercollin, I'm running into an issue with the new ABC imports when running in Colab. Is this a Python versioning issue?
TypeError Traceback (most recent call last)
in ()
----> 6 from ssm.lds import GaussianLDS
...
/usr/local/lib/python3.7/dist-packages/ssm/utils.py in ()
101 z2: Sequence[int],
102 K1: Optional[int] = None,
--> 103 K2: Optional[int] = None,
104 ):
105 """
TypeError: 'ABCMeta' object is not subscriptable
Replace jax.experimental.optimizers with optax (a separate package; there is no jax.optax), e.g. in https://github.com/lindermanlab/ssm-jax-refactor/blob/main/ssm/inference/laplace_em.py#L7.
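For reference, the migration is mechanical; a sketch of the optax equivalent of the old optimizers.adam triple (params and compute_grads are placeholders):

# old:
#   from jax.experimental import optimizers
#   opt_init, opt_update, get_params = optimizers.adam(1e-3)
import optax

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
grads = compute_grads(params)                   # placeholder gradient function
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)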
SSM's Gaussian linear dynamical system (LDS) objects expose a function to compute the marginal likelihood of the data, integrating over the continuous latent states. This function can be automatically differentiated with jax.grad. Use Tensorflow Probability's Hamiltonian Monte Carlo (HMC) functionality to perform Bayesian inference over LDS parameters, using the marginal likelihood and a prior on parameter values.
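The same recipe as the HMM issue above should carry over; in particular, the gradients HMC needs are available directly (params_to_lds and data are the same kind of hypothetical stand-ins):

import jax

grad_fn = jax.grad(lambda p: params_to_lds(p).marginal_likelihood(data))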
@schlagercollin has been working on generic getters/setters for the unconstrained parameters of a model. This is a key component necessary for many downstream feature requests.
HMMs with exponential family emissions admit a simple Gibbs sampling algorithm: alternate between the following two steps (a sketch follows the list):
1. HMMPosterior.sample()
2. ExponentialFamilyEmissions.m_step(), but using conditional.sample() instead of conditional.mode()
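A minimal sketch of one sweep, assuming the method names above; the e_step entry point and the sample=True flag are hypothetical stand-ins for "use conditional.sample() instead of conditional.mode()":

import jax.random as jr

def gibbs_sweep(key, hmm, data):
    key_z, key_theta = jr.split(key)
    posterior = hmm.e_step(data)            # assumed: runs message passing
    states = posterior.sample(seed=key_z)   # step 1: z ~ p(z | y, theta)
    # Step 2: conjugate update that draws parameters rather than taking the mode.
    hmm = hmm.m_step(data, states, sample=True, key=key_theta)
    return hmm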
Hi, I'm excited to start using this codebase.
I followed the "installation for development" instructions in the README, but then from ssm.hmm import GaussianHMM gives me an error about tree_multimap. It looks like jax.tree_util.tree_multimap was removed in jax v0.3.16 (changelog) and the latest release is 0.3.20. It looks like tree_multimap was simply replaced by jax.tree_util.tree_map and that the functionality is the same (jax issue, PR).
In this repo, tree_multimap is only used in ssm.utils.tree_all_equal, so it should be straightforward to remove.
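For reference, the rename is a drop-in change; a sketch of what the call in tree_all_equal might become (the lambda is a guess at its contents):

import jax
import jax.numpy as np

# old: jax.tree_util.tree_multimap(lambda x, y: np.allclose(x, y), tree_a, tree_b)
equal_leaves = jax.tree_util.tree_map(lambda x, y: np.allclose(x, y), tree_a, tree_b)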
Best,
Jack
In https://github.com/lindermanlab/ssm-jax-refactor/blob/main/setup.py#L21, replace jax==0.2.21 with jax>=0.2.21, since in Colab the pinned version uninstalls the default jax 0.2.25.
In https://github.com/lindermanlab/ssm-jax-refactor/blob/main/notebooks/bernoulli-hmm-example.ipynb, you create an HMM with a random Bernoulli observation model. How can I extract and plot the underlying nstates x ndims matrix of probabilities? (It's buried behind some TFP class.)
Similarly, in https://github.com/lindermanlab/ssm-jax-refactor/blob/main/notebooks/gaussian-hmm-example.ipynb, how do I extract the observation parameters of the learned model?
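Not an authoritative answer, but if the emissions wrap a tfd.Bernoulli (possibly inside a tfd.Independent), its probability matrix should be reachable with the standard TFP accessor; the attribute path below is a guess at ssm-jax internals:

dist = hmm._emissions._distribution      # guessed attribute path
if hasattr(dist, "distribution"):        # unwrap tfd.Independent if present
    dist = dist.distribution
probs = dist.probs_parameter()           # (num_states, num_dims)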
Dear all,
I am trying to run the example code "GaussianHMM". However, I got an error saying "TypeError: `model` must be convertible to `dict` (saw: DeviceArray)."
I searched for this error in the JAX community but could not find a solution. Could you please help me out? Thank you very much!
from ssm.hmm import GaussianHMM
import jax.random as jr
# create a true HMM model
hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(0))
states, data = hmm.sample(key=jr.PRNGKey(1), num_steps=500, num_samples=5)
# create a test HMM model
test_hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(32))
# fit it to our sampled data
log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")
Initializing...
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[6], line 12
9 test_hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(32))
11 # fit it to our sampled data
---> 12 log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")
File ~/tmp/ssm-jax/ssm/utils.py:254, in ensure_has_batch_dim.<locals>.ensure_has_batch_dim_decorator.<locals>.wrapper(*args, **kwargs)
250 if key in bound_args.arguments and bound_args.arguments[key] is not None:
251 bound_args.arguments[key] = \
252 tree_map(lambda x: x[None, ...], bound_args.arguments[key])
--> 254 return f(**bound_args.arguments)
File ~/tmp/ssm-jax/ssm/hmm/base.py:201, in HMM.fit(self, data, covariates, metadata, method, num_iters, tol, initialization_method, key, verbosity)
199 if initialization_method is not None:
200 if verbosity >= Verbosity.LOUD : print("Initializing...")
--> 201 self.initialize(key, data, method=initialization_method)
202 if verbosity >= Verbosity.LOUD: print("Done.", flush=True)
204 if method == "em":
File ~/tmp/ssm-jax/ssm/utils.py:254, in ensure_has_batch_dim.<locals>.ensure_has_batch_dim_decorator.<locals>.wrapper(*args, **kwargs)
250 if key in bound_args.arguments and bound_args.arguments[key] is not None:
251 bound_args.arguments[key] = \
252 tree_map(lambda x: x[None, ...], bound_args.arguments[key])
--> 254 return f(**bound_args.arguments)
File ~/tmp/ssm-jax/ssm/hmm/base.py:132, in HMM.initialize(self, key, data, covariates, metadata, method)
129 dummy_posteriors = DummyPosterior(one_hot(assignments, self._num_states))
131 # Do one m-step with the dummy posteriors
--> 132 self._emissions.m_step(data, dummy_posteriors)
File ~/tmp/ssm-jax/ssm/hmm/emissions.py:161, in ExponentialFamilyEmissions.m_step(self, dataset, posteriors, covariates, metadata)
145 def m_step(self, dataset, posteriors, covariates=None, metadata=None) -> ExponentialFamilyEmissions:
146 """Update the emissions distribution using an M-step.
147
148 Operates over a batch of data (posterior must have the same batch dim).
(...)
159 emissions (ExponentialFamilyEmissions): updated emissions object
160 """
--> 161 conditional = self._emissions_distribution_class.compute_conditional(
162 dataset, weights=posteriors.expected_states, prior=self._prior)
163 self._distribution = self._emissions_distribution_class.from_params(
164 conditional.mode())
165 return self
File ~/tmp/ssm-jax/ssm/distributions/expfam.py:98, in ExponentialFamilyDistribution.compute_conditional(cls, data, weights, prior)
95 stats = tree_map(np.add, stats, prior.natural_parameters)
97 # Compute the conditional distribution given the stats
---> 98 return cls.compute_conditional_from_stats(stats)
File ~/tmp/ssm-jax/ssm/distributions/expfam.py:75, in ExponentialFamilyDistribution.compute_conditional_from_stats(cls, stats)
73 @classmethod
74 def compute_conditional_from_stats(cls, stats):
---> 75 return get_prior(cls).from_natural_parameters(stats)
File ~/tmp/ssm-jax/ssm/distributions/niw.py:69, in NormalInverseWishart.from_natural_parameters(cls, natural_params)
67 loc = np.einsum("...i,...->...i", s2, 1 / mean_precision)
68 scale = s3 - np.einsum("...,...i,...j->...ij", mean_precision, loc, loc)
---> 69 return cls(loc, mean_precision, df, scale)
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:474, in JointDistributionNamed.__new__(cls, *args, **kwargs)
470 model = kwargs.get('model')
472 if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
473 for d in tf.nest.flatten(model)):
--> 474 return _JointDistributionNamed(*args, **kwargs)
475 return super(JointDistributionNamed, cls).__new__(cls)
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:323, in _JointDistributionNamed.__init__(self, model, batch_ndims, use_vectorized_map, validate_args, experimental_use_kahan_sum, name)
287 def __init__(self,
288 model,
289 batch_ndims=None,
(...)
292 experimental_use_kahan_sum=False,
293 name=None):
294 """Construct the `JointDistributionNamed` distribution.
295
296 Args:
(...)
321 Default value: `None` (i.e., `"JointDistributionNamed"`).
322 """
--> 323 super(_JointDistributionNamed, self).__init__(
324 model,
325 batch_ndims=batch_ndims,
326 use_vectorized_map=use_vectorized_map,
327 validate_args=validate_args,
328 experimental_use_kahan_sum=experimental_use_kahan_sum,
329 name=name or 'JointDistributionNamed')
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_sequential.py:362, in _JointDistributionSequential.__init__(self, model, batch_ndims, use_vectorized_map, validate_args, experimental_use_kahan_sum, name)
360 self._model_trackable = model
361 self._model = self._no_dependency(model)
--> 362 self._build(model)
364 super(_JointDistributionSequential, self).__init__(
365 dtype=None, # Ignored; we'll override.
366 batch_ndims=batch_ndims,
(...)
370 experimental_use_kahan_sum=experimental_use_kahan_sum,
371 name=name)
373 # If the model consists entirely of prebuilt distributions with no
374 # dependencies, cache them directly to avoid a sample call down the road.
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:334, in _JointDistributionNamed._build(self, model)
332 """Creates `dist_fn`, `dist_fn_wrapped`, `dist_fn_args`, `dist_fn_name`."""
333 if not _is_dict_like(model):
--> 334 raise TypeError('`model` must be convertible to `dict` (saw: {}).'.format(
335 type(model).__name__))
336 [
337 self._dist_fn,
338 self._dist_fn_wrapped,
339 self._dist_fn_args,
340 self._dist_fn_name, # JointDistributionSequential doesn't have this.
341 ] = _prob_chain_rule_model_flatten(model)
TypeError: `model` must be convertible to `dict` (saw: DeviceArray).
Use jax.grad to implement a generic extended Kalman filter (EKF) for SSMs with linear Gaussian dynamics and nonlinear/non-Gaussian emissions. See Section 18.5 of Machine Learning: A Probabilistic Perspective (Murphy, 2012).
To implement the EKF, we will need the unconstrained parameters of a model. See #32
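For what it's worth, a minimal EKF step under those assumptions might look like this, using jax.jacfwd to linearize the emission mean h at the predicted state (all names illustrative, not ssm-jax API):

import jax
import jax.numpy as np

def ekf_step(mu, Sigma, y, A, Q, h, R):
    # Predict under the linear Gaussian dynamics.
    mu_pred = A @ mu
    Sigma_pred = A @ Sigma @ A.T + Q
    # Linearize the emission function around the predicted mean.
    H = jax.jacfwd(h)(mu_pred)
    S = H @ Sigma_pred @ H.T + R              # innovation covariance
    K = np.linalg.solve(S, H @ Sigma_pred).T  # Kalman gain
    mu_new = mu_pred + K @ (y - h(mu_pred))
    Sigma_new = Sigma_pred - K @ S @ K.T
    return mu_new, Sigma_new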
In https://github.com/lindermanlab/ssm-jax-refactor/blob/main/ssm/lds/emissions.py#L263, you take the Gaussian expected sufficient statistics E[z_t y_t], and then sample from them, before fitting the Poisson model on this sampled data (IIUC). Is the sampling step necessary? Can you use weighted MLE?
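For concreteness, the weighted MLE alluded to in the question has a closed form for Poisson emissions (an illustrative sketch, not ssm-jax code):

import jax.numpy as np

def weighted_poisson_mle(weights, data):
    # weights: (T, K) posterior state probabilities; data: (T, D) counts.
    # Rate for state k: sum_t w_{tk} y_{td} / sum_t w_{tk}.
    return np.einsum("tk,td->kd", weights, data) / weights.sum(axis=0)[:, None]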
Hi, I've been trying to fit a PoissonHMM to simulations made in NEST (https://github.com/nest/nest-simulator). Given the length of the simulations and the high spike counts, I thought the refactor would improve the running time relative to the other version of SSM, where everything works but can take a while (not long enough to worry about). However, when using ssm-jax I encounter numerical instability in the EM update step; the following assertion is raised no matter the number of iterations, even on a small sample of the data:
assert np.isfinite(lp), "NaNs in marginal log probability"
I was wondering whether there is a known limitation with large spike counts, or whether there is something I am missing.
Use jax.grad to implement an unscented Kalman filter (UKF) for SSMs with linear Gaussian dynamics and nonlinear/non-Gaussian emissions. See Section 18.5 of Machine Learning: A Probabilistic Perspective (Murphy, 2012).
To implement the UKF, we will need the unconstrained parameters of a model. See #32
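The core of the UKF is the unscented transform; a mean-only sketch with the standard scaled sigma points (illustrative, not ssm-jax API):

import jax
import jax.numpy as np

def unscented_mean(mu, Sigma, h, alpha=1e-3, kappa=0.0):
    n = mu.shape[0]
    lam = alpha ** 2 * (n + kappa) - n
    L = np.linalg.cholesky((n + lam) * Sigma)
    points = np.vstack([mu[None], mu + L.T, mu - L.T])  # (2n + 1, n) sigma points
    w = np.full(2 * n + 1, 1.0 / (2 * (n + lam)))
    w = w.at[0].set(lam / (n + lam))
    return w @ jax.vmap(h)(points)                      # weighted mean of h(points)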
Is there a way to do this? Essentially, I'd like to pass a flag to fit(..., method="em") that would tell the code to optimize the transition matrix and keep the emissions fixed (or vice versa!).
LDSs with Gaussian emissions admit a simple Gibbs sampling algorithm: alternate between the following two steps:
1. LDSPosterior.sample()