exoplanet-dev / celerite2
Fast & scalable Gaussian Processes in one dimension
Home Page: https://celerite2.readthedocs.io
License: MIT License
Hi dfm,
I am using a Gaussian process as an interpolator for some control points, and I want to calculate the gradient of a function that takes the GP's predicted values as input. A simple example is the sum_interp_gp function below, where I simply add up the predicted values. It seems that differentiation of gp.predict has not been implemented. Are there plans to do this?
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
/tmp/ipykernel_18624/620407306.py in <module>
----> 1 cf.grad_sum_interp_gp(params)
[... skipping hidden 12 frame]
/usr/local/lib/python3.8/dist-packages/ticktack-0.1.3-py3.8.egg/ticktack/fitting.py in grad_sum_interp_gp(self, *args)
144 @partial(jit, static_argnums=(0,))
145 def grad_sum_interp_gp(self, *args):
--> 146 return grad(self.sum_interp_gp)(*args)
147
148 @partial(jit, static_argnums=(0,))
[... skipping hidden 26 frame]
/usr/local/lib/python3.8/dist-packages/ticktack-0.1.3-py3.8.egg/ticktack/fitting.py in sum_interp_gp(self, *args)
139 @partial(jit, static_argnums=(0,))
140 def sum_interp_gp(self, *args):
--> 141 mu = self.interp_gp(self.annual, *args)
142 return jnp.sum(mu)
143
[... skipping hidden 17 frame]
/usr/local/lib/python3.8/dist-packages/ticktack-0.1.3-py3.8.egg/ticktack/fitting.py in interp_gp(self, tval, *args)
133 gp = celerite2.jax.GaussianProcess(kernel, mean=mean)
134 gp.compute(self.control_points_time)
--> 135 mu = gp.predict(control_points, t=tval, return_var=False)
136 mu = (tval > self.start) * mu + (tval <= self.start) * mean
137 return mu
~/.local/lib/python3.8/site-packages/celerite2/core.py in predict(self, y, t, return_cov, return_var, include_mean, kernel)
472 if return_cov:
473 return cond.mean, cond.covariance
--> 474 return cond.mean
475
476 def condition(self, y, t=None, *, include_mean=True, kernel=None):
~/.local/lib/python3.8/site-packages/celerite2/core.py in mean(self)
125
126 mu = self.gp._zeros_like(self._xs)
--> 127 mu = self._do_dot(alpha, mu)
128
129 if self.include_mean:
~/.local/lib/python3.8/site-packages/celerite2/core.py in _do_dot(self, inp, target)
106 target = target[:, None]
107
--> 108 target = self._do_general_matmul(c, U1, V1, U2, V2, inp, target)
109
110 if is_vector:
~/.local/lib/python3.8/site-packages/celerite2/jax/celerite2.py in _do_general_matmul(self, c, U1, V1, U2, V2, inp, target)
10 class ConditionalDistribution(BaseConditionalDistribution):
11 def _do_general_matmul(self, c, U1, V1, U2, V2, inp, target):
---> 12 target += ops.general_matmul_lower(
13 self._xs, self.gp._t, c, U2, V1, inp
14 )
~/.local/lib/python3.8/site-packages/celerite2/jax/ops.py in general_matmul_lower(t1, t2, c, U, V, Y)
61
62 def general_matmul_lower(t1, t2, c, U, V, Y):
---> 63 Z, F = general_matmul_lower_p.bind(t1, t2, c, U, V, Y)
64 return Z
65
[... skipping hidden 1 frame]
~/.local/lib/python3.8/site-packages/jax/interpreters/ad.py in process_primitive(self, primitive, tracers, params)
280 if not jvp:
281 msg = f"Differentiation rule for '{primitive}' not implemented"
--> 282 raise NotImplementedError(msg)
283 primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
284 if primitive.multiple_results:
NotImplementedError: Differentiation rule for 'celerite2_general_matmul_lower' not implemented
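Until the JAX primitive has a differentiation rule, one observation may help. The GP predictive mean is linear in the observed values, μ = K_*(K + D)^{-1} y, so the gradient of sum(μ) with respect to y has a closed form: (K + D)^{-1} K_*^T 1. The same gradient holds with a constant mean, because μ = m + K_*(K + D)^{-1}(y - m).

The sketch below checks this against finite differences using a small dense Matern-3/2 GP written in plain NumPy. It is not celerite2's solver, it costs O(N^3), and the helper names (matern32, predict_mean, grad_sum_mean_wrt_y) are hypothetical. It only covers gradients with respect to the control-point values, not the kernel hyperparameters.

```python
import numpy as np


def matern32(lag, sigma=1.0, rho=1.0):
    # Matern-3/2 covariance evaluated at time lags `lag` (dense; illustration only)
    x = np.sqrt(3.0) * np.abs(lag) / rho
    return sigma**2 * (1.0 + x) * np.exp(-x)


def _matrices(t_train, t_test, diag, **kernel_kw):
    # K + D on the training times, and the test/train cross-covariance K_*
    K = matern32(t_train[:, None] - t_train[None, :], **kernel_kw) + np.diag(diag)
    K_star = matern32(t_test[:, None] - t_train[None, :], **kernel_kw)
    return K, K_star


def predict_mean(t_train, y, t_test, diag, **kernel_kw):
    # Dense GP conditional mean with zero prior mean: mu = K_* (K + D)^{-1} y
    K, K_star = _matrices(t_train, t_test, diag, **kernel_kw)
    return K_star @ np.linalg.solve(K, y)


def grad_sum_mean_wrt_y(t_train, t_test, diag, **kernel_kw):
    # mu is linear in y, so d(sum(mu))/dy = (K + D)^{-1} K_*^T 1 (K + D is symmetric)
    K, K_star = _matrices(t_train, t_test, diag, **kernel_kw)
    return np.linalg.solve(K, K_star.T @ np.ones(len(t_test)))
```

Because the result does not depend on y, it can be precomputed once per set of hyperparameters.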
Hi @dfm,
I'm trying to use numpyro to sample a GP with the SHO kernel as follows:
import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp
from celerite2.jax import GaussianProcess, terms
import numpyro.distributions as dist
from numpyro import sample
from numpyro.infer import MCMC, NUTS
prior_sigma = 1.0
def numpyro_model(x, yerr, y=None):
    mean = sample("mean", dist.Normal(0.0, prior_sigma))
    logjitter = sample("logjitter", dist.Normal(-26, 3 * prior_sigma))
    logsigma = sample("logsigma", dist.Normal(-11, 3 * prior_sigma))
    rho = sample("rho", dist.Normal(1.0, 3 * prior_sigma))
    tau = sample("tau", dist.Normal(0.1, prior_sigma))
    term = terms.SHOTerm(sigma=jnp.exp(logsigma), rho=rho, tau=tau)
    gp = GaussianProcess(term, mean=mean)
    gp.compute(x, diag=yerr**2 + jnp.exp(logjitter), check_sorted=False)
    sample("obs", gp.numpyro_dist(), obs=y)
nuts_kernel = NUTS(numpyro_model, dense_mass=True, target_accept_prob=0.9)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=True,
)
rng_key = jax.random.PRNGKey(34923)
yerr = 1e-8
mcmc.run(rng_key, x, yerr, y=y)
and I'm getting an error with a long traceback that ends:
/usr/local/lib/python3.8/site-packages/celerite2/jax/terms.py in SHOTerm(*args, **kwargs)
397 over = OverdampedSHOTerm(*args, **kwargs)
398 under = UnderdampedSHOTerm(*args, **kwargs)
--> 399 if over.Q < 0.5:
400 return over
401 return under
[... skipping hidden 2 frame]
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
---------------------------------------------------------------------------
ConcretizationTypeError Traceback (most recent call last)
<ipython-input-139-60f93f4ec4d5> in <module>
1 yerr = 1e-8
----> 2 mcmc.run(rng_key, x, yerr, y=y)
3 samples = mcmc.get_samples()
/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
596 else:
597 if self.chain_method == "sequential":
--> 598 states, last_state = _laxmap(partial_map_fn, map_args)
599 elif self.chain_method == "parallel":
600 states, last_state = pmap(partial_map_fn)(map_args)
/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
158 for i in range(n):
159 x = jit(_get_value_from_index)(xs, i)
--> 160 ys.append(f(x))
161
162 return tree_map(lambda *args: jnp.stack(args), *ys)
/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
/usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
704 vmap(random.split)(rng_key), 0, 1
705 )
--> 706 init_params = self._init_state(
707 rng_key_init_model, model_args, model_kwargs, init_params
708 )
/usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
--> 652 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
653 rng_key,
654 self._model,
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
654 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
655 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 656 (init_params, pe, grad), is_valid = find_valid_initial_params(
657 rng_key,
658 substitute(
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
395 # Handle possible vectorization
396 if rng_key.ndim == 1:
--> 397 (init_params, pe, z_grad), is_valid = _find_valid_params(
398 rng_key, exit_early=True
399 )
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
388 # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
389 # even if the init_state is a valid result
--> 390 _, _, (init_params, pe, z_grad), is_valid = while_loop(
391 cond_fn, body_fn, init_state
392 )
/usr/local/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
129 return val
130 else:
--> 131 return lax.while_loop(cond_fun, body_fun, init_val)
132
133
[... skipping hidden 9 frame]
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in body_fn(state)
365 z_grad = jacfwd(potential_fn)(params)
366 else:
--> 367 pe, z_grad = value_and_grad(potential_fn)(params)
368 z_grad_flat = ravel_pytree(z_grad)[0]
369 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
[... skipping hidden 8 frame]
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
247 )
248 # no param is needed for log_density computation because we already substitute
--> 249 log_joint, model_trace = log_density_(
250 substituted_model, model_args, model_kwargs, {}
251 )
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
60 """
61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
63 log_joint = jnp.zeros(())
64 for site in model_trace.values():
/usr/local/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
173
/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
<ipython-input-137-39666cc8f7df> in numpyro_model(x, yerr, y)
10 tau = sample("tau", dist.Normal(0.1, prior_sigma))
11
---> 12 term = jTerms.SHOTerm(sigma=jnp.exp(logsigma), rho=rho, tau=tau)
13 gp = jGP(term, mean=mean)
14 gp.compute(x, diag = yerr**2 + jnp.exp(logjitter), check_sorted=False)
/usr/local/lib/python3.8/site-packages/celerite2/jax/terms.py in SHOTerm(*args, **kwargs)
397 over = OverdampedSHOTerm(*args, **kwargs)
398 under = UnderdampedSHOTerm(*args, **kwargs)
--> 399 if over.Q < 0.5:
400 return over
401 return under
[... skipping hidden 2 frame]
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function.
The error occurred while tracing the function body_fn at /usr/local/lib/python3.8/site-packages/numpyro/infer/util.py:315 for while_loop. This concrete value was not available in Python because it depends on the value of the argument state[1].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
It runs just fine if I replace terms.SHOTerm with terms.UnderdampedSHOTerm and constrain the hyperparameters to be in the underdamped regime. Any idea what's going on here?
Thanks!
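The traceback ends at a Python if on over.Q < 0.5 inside SHOTerm. The frames above show numpyro calling value_and_grad inside lax.while_loop while it searches for initial parameters. In that context Q is an abstract tracer, so Python cannot turn the comparison into a concrete bool.

Below is a minimal sketch of the failure and of the usual traced alternative, jnp.where. The functions pick_python, pick_traced and jit_fails are hypothetical toy stand-ins, not celerite2's code.

```python
import jax
import jax.numpy as jnp


def pick_python(q, over_value, under_value):
    # Python control flow needs a concrete bool, so it fails under jit/grad tracing
    if q < 0.5:
        return over_value
    return under_value


def pick_traced(q, over_value, under_value):
    # jnp.where keeps the selection inside the traced computation
    return jnp.where(q < 0.5, over_value, under_value)


def jit_fails(fn, *args):
    # True if tracing `fn` under jit raises JAX's concretization error
    try:
        jax.jit(fn)(*args)
    except jax.errors.ConcretizationTypeError:
        return True
    return False
```

jnp.where evaluates both branches. For a kernel, each regime's formula must therefore stay finite, with finite gradients, for every Q; otherwise NaNs from the unused branch can leak into the gradient. Restricting to UnderdampedSHOTerm, as described above, avoids the branch entirely.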
Hi there,
I'm trying to model a light curve (~10^5 points) using celerite2 and PyMC3 as a SHOTerm plus a mean model composed of a sum of sinusoids (where I have fairly tight priors on frequencies and amplitudes). If I'm reading the documentation for GaussianProcess correctly, it looks like the mean parameter needs to be either a scalar or an object that can be called with a vector of time values.
Assuming I have that correct, I've implemented this model as:
import numpy as np
import pymc3 as pm
import pymc3_ext as pmx
import aesara_theano_fallback.tensor as tt
from celerite2.theano import terms, GaussianProcess
with pm.Model() as model:
    # A jitter term describing excess white noise
    log_jitter = pm.Normal("log_jitter", mu=np.log(err_array.mean()), sigma=5.0)
    # SHOTerm
    log_sigma = pm.Uniform("log_sigma", lower=-15, upper=15)
    sigma = pm.Deterministic("sigma", tt.exp(log_sigma))
    logtau = pm.Uniform("log_tau", lower=-5, upper=5)
    tau = pm.Deterministic("tau", tt.exp(logtau))
    # We want Q < 1/sqrt(2)
    Q = pm.Uniform("Q", lower=0.0, upper=1/np.sqrt(2.0) + 0.00001, testval=1/np.sqrt(2.0))
    rho = pm.Deterministic("rho", tau * np.pi / Q)
    kernel = terms.SHOTerm(sigma=sigma, rho=rho, tau=tau)
    # Now the non-constant mean
    # The mean flux of the light curve
    mean_flux = pm.Normal("mean_flux", mu=1.0, sigma=np.std(flux))
    f, f_err = lists_of_frequencies_and_errorbars
    a, a_err = lists_of_amplitudes_and_errorbars
    # Make them PyMC3 variables
    fs = [pm.Uniform(f"f{i}", lower=f[i] - 3*f_err[i], upper=f[i] + 3*f_err[i]) for i in range(len(f))]
    amps = [pm.Uniform(f"a{i}", lower=a[i] - 3*a_err[i], upper=a[i] + 3*a_err[i]) for i in range(len(a))]
    phases = [pmx.Angle(f"phi{i}") for i in range(len(f))]
    # Making a callable for celerite2
    mean = lambda x: tt.sum([a * tt.sin(2.0*np.pi*f*x + phi) for a, f, phi in zip(amps, fs, phases)], axis=0) + mean_flux
    # And add it to the model so we can track it
    pm.Deterministic("mean", mean(time_array))
    gp = GaussianProcess(
        kernel,
        t=time_array,
        diag=err_array ** 2.0 + tt.exp(2 * log_jitter),
        mean=mean,
        quiet=True,
    )
    # Compute the Gaussian Process likelihood and add it into the PyMC3 model
    gp.marginal("gp", observed=flux_array)
    # Compute the mean model prediction of just the GP
    pm.Deterministic("pred", gp.predict(flux_array, include_mean=False))
When I run pmx.optimize(), this seems to go fairly well, taking ~45 seconds on my laptop with a 10-frequency mean model (35 parameters in total), but when I try to sample using pmx.sample with more than one core, I get the warning message:
Could not pickle model, sampling singlethreaded.
Sequential sampling (4 chains in 1 job)
I don't get this message when I use a constant mean, so my guess is it has something to do with the callable that gets passed to GaussianProcess. The model takes quite a while to sample, so I'd really like to run this in parallel.
Is there a workaround? Am I perhaps completely misreading the GaussianProcess docstring? Oh, and for completeness, running
from aesara_theano_fallback import __version__ as tt_version
from celerite2 import __version__ as c2_version
pm.__version__, pmx.__version__, tt_version, c2_version
yields
'3.11.4', '0.1.0', '0.0.4', '0.2.0'
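A guess at the cause: the warning suggests PyMC3 tries to pickle the model before sending it to worker processes. The standard pickle module cannot serialize a lambda, because it stores functions by importable name.

A common workaround is a class with __call__, defined at module level in an importable .py file. Classes defined in a notebook's __main__ may still fail to unpickle in spawned workers.

The sketch below (SinusoidMean is a hypothetical name) uses NumPy so it runs on its own. In the model, the attributes would be the PyMC3 variables and np.sin would become tt.sin. Whether the Theano graph it holds then pickles cleanly is untested here.

```python
import pickle

import numpy as np


class SinusoidMean:
    """Picklable mean model: a constant offset plus a sum of sinusoids."""

    def __init__(self, amps, freqs, phases, offset):
        self.amps = amps
        self.freqs = freqs
        self.phases = phases
        self.offset = offset

    def __call__(self, x):
        # Sum a * sin(2*pi*f*x + phi) over all components, then add the offset
        x = np.asarray(x, dtype=float)
        total = np.zeros_like(x)
        for a, f, phi in zip(self.amps, self.freqs, self.phases):
            total = total + a * np.sin(2.0 * np.pi * f * x + phi)
        return total + self.offset
```

An instance survives pickle.dumps/pickle.loads, whereas pickle.dumps on an equivalent lambda raises an error.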
Hi folks, I just discovered this updated version of celerite and am excited to add celerite2's PyMC3 GP to the Eureka! package (as I'm finding more and more cases where using a GP would be beneficial). We already have the original celerite GP working (at least kinda) as well as the george GP for standard Python minimizers and samplers (emcee, scipy.optimize.minimize, dynesty), but we also have a PyMC3 version of our fitting code, which allows one to use starry (ideal for eclipse mapping), and in general PyMC3's NUTS sampler has allowed for much faster fits (for me at least). I know that jax and PyMC4 offer advantages over the deprecated PyMC3 and theano implementation, but over the past year I've already spent more than fifty hours getting PyMC3 versions of all our astrophysical and systematic models implemented, I need to be able to use starry for the astrophysical model, and I'm looking for a fairly quick way to get GPs implemented. I was about to try PyMC3's built-in GP, but I saw in the old exoplanet docs that exoplanet had a faster GP implementation for PyMC3 sampling and then noticed that the code had been migrated here to celerite2.
However, by default I am unable to access the pymc3 submodule of celerite2 that is mentioned in the documentation. I've tried installing the version on the main branch on GitHub (seeing that v0.2.1 doesn't have pymc3 code), but I get an error that 'celerite2' has no attribute 'pymc3':
> pip install celerite2[pymc3]@git+https://github.com/exoplanet-dev/celerite2
> ipython
Python 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:23:19)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.14.0 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import celerite2
In [2]: celerite2.__version__
Out[2]: '0.3.0rc2.dev11+g74b0705'
In [3]: celerite2.pymc3
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[3], line 1
----> 1 celerite2.pymc3
AttributeError: module 'celerite2' has no attribute 'pymc3'
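This AttributeError is ordinary Python behavior whenever a package's __init__.py does not import a submodule. import celerite2 alone does not load celerite2.pymc3. An explicit import celerite2.pymc3 (or from celerite2 import pymc3) in user code would load it, assuming the submodule and its dependencies were actually installed.

The self-contained sketch below demonstrates this with a throwaway package, demo_pkg (a hypothetical name), built in a temporary directory.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def demo_submodule_import():
    # Build a package whose __init__.py does NOT import its submodule `sub`
    with tempfile.TemporaryDirectory() as root:
        pkg = Path(root) / "demo_pkg"
        pkg.mkdir()
        (pkg / "__init__.py").write_text("")
        (pkg / "sub.py").write_text("VALUE = 42\n")
        sys.path.insert(0, root)
        importlib.invalidate_caches()
        try:
            import demo_pkg
            before = hasattr(demo_pkg, "sub")  # submodule not loaded yet
            importlib.import_module("demo_pkg.sub")  # same as `import demo_pkg.sub`
            after = hasattr(demo_pkg, "sub")  # importing it binds it on the parent
            return before, after, demo_pkg.sub.VALUE
        finally:
            sys.path.remove(root)
```

Calling demo_submodule_import() should return (False, True, 42).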
Looking at the package on GitHub, I see the pymc3 folder under python/celerite2/pymc3, but it isn't imported in the __init__.py file. Just adding from celerite2 import pymc3 to the python/celerite2/__init__.py file wasn't enough to solve the problem either, giving me the following error message when installing:
Building wheels for collected packages: celerite2
Building wheel for celerite2 (pyproject.toml) ... error
error: subprocess-exited-with-error
× Building wheel for celerite2 (pyproject.toml) did not run successfully.
│ exit code: 1
╰─> [27 lines of output]
running bdist_wheel
running build
running build_py
copying python/celerite2/__init__.py -> build/lib.macosx-10.9-x86_64-cpython-39/celerite2
copying python/celerite2/celerite2_version.py -> build/lib.macosx-10.9-x86_64-cpython-39/celerite2
running egg_info
writing python/celerite2.egg-info/PKG-INFO
writing dependency_links to python/celerite2.egg-info/dependency_links.txt
writing requirements to python/celerite2.egg-info/requires.txt
writing top-level names to python/celerite2.egg-info/top_level.txt
reading manifest template 'MANIFEST.in'
warning: no directories found matching 'c++/vendor/eigen/Eigen'
adding license file 'LICENSE'
writing manifest file 'python/celerite2.egg-info/SOURCES.txt'
running build_ext
clang -Wno-unused-result -Wsign-compare -Wunreachable-code -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /Users/tjbell1/miniconda3/envs/eureka_starry/include -fPIC -O2 -isystem /Users/tjbell1/miniconda3/envs/eureka_starry/include -I/Users/tjbell1/miniconda3/envs/eureka_starry/include/python3.9 -c flagcheck.cpp -o flagcheck.o -std=c++17
building 'celerite2.driver' extension
clang -Wno-unused-result -Wsign-compare -Wunreachable-code -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /Users/tjbell1/miniconda3/envs/eureka_starry/include -fPIC -O2 -isystem /Users/tjbell1/miniconda3/envs/eureka_starry/include -Ic++/include -Ic++/vendor/eigen -Ipython/celerite2 -I/private/var/folders/f3/_h5lxm511fv9sjb5d3xrj0jm0000gq/T/pip-build-env-u9cav0rq/overlay/lib/python3.9/site-packages/pybind11/include -I/Users/tjbell1/miniconda3/envs/eureka_starry/include/python3.9 -c python/celerite2/driver.cpp -o build/temp.macosx-10.9-x86_64-cpython-39/python/celerite2/driver.o -std=c++17 -mmacosx-version-min=10.14 -fvisibility=hidden -g0
In file included from python/celerite2/driver.cpp:6:
In file included from python/celerite2/driver.hpp:8:
In file included from c++/include/celerite2/celerite2.h:4:
In file included from c++/include/celerite2/core.hpp:4:
c++/include/celerite2/forward.hpp:4:10: fatal error: 'Eigen/Core' file not found
#include <Eigen/Core>
^~~~~~~~~~~~
1 error generated.
error: command '/usr/bin/clang' failed with exit code 1
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for celerite2
Failed to build celerite2
ERROR: Could not build wheels for celerite2, which is required to install pyproject.toml-based projects
I was able to find the referenced Eigen package at https://eigen.tuxfamily.org/, so I downloaded the latest version and put the entire contents of that package in celerite2's (previously empty) c++/vendor/eigen folder. With that package downloaded and my edit to the __init__.py file, I'm now able to import and use celerite2.pymc3.
I thought I'd document my troubleshooting process here to help anyone else trying to do the same thing or to help the devs figure out what changes need to be made to the documentation and/or package.
See exoplanet-dev/exoplanet-core#2 for details.
I think that this must be possible using a forward and backward pass. At the very least, this must be possible when the predictive kernel is the same as the base kernel (by comparison with state-space models), but I feel like it must be possible for general semi-separable matrices.
Hi Dan,
Just a question: there is a great tutorial on using the numpyro backend with JAX.
Is there any documentation on using JAX mean models too? E.g., any pitfalls with definitions, gradients, jit, etc.?
All the best,
Ben
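As a rough illustration of the general JAX rules that apply to mean models (this is not celerite2-specific documentation), write the mean as a pure function of its parameters and the times. Use only jax.numpy operations, with no Python branching on traced values, no in-place mutation and no plain NumPy calls, so that jit and grad can trace it. The names sinusoid_mean and sq_resid are hypothetical.

```python
import jax
import jax.numpy as jnp


def sinusoid_mean(params, t):
    # Pure jax.numpy function of (params, t), so it is safe under jit and grad
    return params["offset"] + params["amp"] * jnp.sin(2.0 * jnp.pi * params["freq"] * t)


@jax.jit
def sq_resid(params, t, y):
    # Toy objective: sum of squared residuals against the mean model
    r = y - sinusoid_mean(params, t)
    return jnp.sum(r**2)


# Gradient with respect to the parameter dict (a pytree), compiled with jit
grad_fn = jax.jit(jax.grad(sq_resid))
```

Parameters passed as a dict (or any pytree) come back from jax.grad with the same structure, which keeps bookkeeping simple.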
Things simplify a bit if we replace P by t and c.
Is your feature request related to a problem? Please describe.
Hi, could you discuss the different priors you use for the GP hyperparameters, whether for the SHO kernel in the transit case studies or the rotation kernel in the stellar variability case? They tend to differ in functional form as well as in starting values. It would be useful to have some discussion of this.
I am happily using celerite2 with no issues on my local computer, but I encountered a problem when importing celerite2 on Read the Docs. I found a workaround but am posting here in case anyone has the same issue.
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/config.py", line 358, in eval_config_file
exec(code, namespace) # NoQA: S102
^^^^^^^^^^^^^^^^^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/checkouts/latest/docs/conf.py", line 24, in <module>
import jtow
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/checkouts/latest/jtow/__init__.py", line 7, in <module>
from . import jtow
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/checkouts/latest/jtow/jtow.py", line 53, in <module>
from tshirt.pipeline import phot_pipeline
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/__init__.py", line 2, in <module>
from .pipeline import phot_pipeline, spec_pipeline, analysis, prep_images
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/__init__.py", line 1, in <module>
from . import phot_pipeline
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/phot_pipeline.py", line 51, in <module>
from .instrument_specific import rowamp_sub
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/instrument_specific/__init__.py", line 3, in <module>
from . import rowamp_sub
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/instrument_specific/rowamp_sub.py", line 16, in <module>
import celerite2
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/celerite2/__init__.py", line 4, in <module>
from . import terms
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/celerite2/terms.py", line 22, in <module>
from . import driver
ImportError: cannot import name 'driver' from partially initialized module 'celerite2' (most likely due to a circular import) (/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/celerite2/__init__.py)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/cmd/build.py", line 293, in build_main
app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/application.py", line 211, in __init__
self.config = Config.read(self.confdir, confoverrides or {}, self.tags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/config.py", line 181, in read
namespace = eval_config_file(filename, tags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/config.py", line 371, in eval_config_file
raise ConfigError(msg % traceback.format_exc()) from exc
sphinx.errors.ConfigError: There is a programmable error in your configuration file:
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/config.py", line 358, in eval_config_file
exec(code, namespace) # NoQA: S102
^^^^^^^^^^^^^^^^^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/checkouts/latest/docs/conf.py", line 24, in <module>
import jtow
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/checkouts/latest/jtow/__init__.py", line 7, in <module>
from . import jtow
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/checkouts/latest/jtow/jtow.py", line 53, in <module>
from tshirt.pipeline import phot_pipeline
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/__init__.py", line 2, in <module>
from .pipeline import phot_pipeline, spec_pipeline, analysis, prep_images
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/__init__.py", line 1, in <module>
from . import phot_pipeline
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/phot_pipeline.py", line 51, in <module>
from .instrument_specific import rowamp_sub
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/instrument_specific/__init__.py", line 3, in <module>
from . import rowamp_sub
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/instrument_specific/rowamp_sub.py", line 16, in <module>
import celerite2
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/celerite2/__init__.py", line 4, in <module>
from . import terms
File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/celerite2/terms.py", line 22, in <module>
from . import driver
ImportError: cannot import name 'driver' from partially initialized module 'celerite2' (most likely due to a circular import) (/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/celerite2/__init__.py)
More information is available from this failed build:
https://readthedocs.org/api/v2/build/22591584.txt
This was with Python 3.11.
Hi Dan and celerite team,
I have the following versions:
exoplanet.__version__ = '0.6.0'
celerite2.__version__ = '0.3.1'
pymc.__version__ = '5.10.4'
I want to model in-transit spot occultations with celerite, and the GP should only be conditioned on in-transit data. A separate GP with a different kernel and hyperparameters is used to describe a longer-timescale correlated signal. I assume they have to be separate GPs because the datasets will have different lengths.
For a single GP (gp1), I followed the exoplanet examples, created a function that outputs the transit light curve, and passed that to the mean keyword in celerite2.pymc.GaussianProcess. For a second GP (gp2), my expectation was that I could create a mean function that is the sum of a light curve model and gp1.predict(). However, this doesn't seem to work: it throws AttributeError: '_CeleriteOp' object has no attribute 'rev_op'. If I add an .eval(), it runs, but the maximum likelihood model isn't a good fit; I suspect something goes wrong behind the scenes.
Is it possible to use the mean of a GP as a mean function in another GP?
Here's what I'm working with so far:
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import pymc as pm
import pymc_ext as pmx
import exoplanet as xo
import pytensor.tensor as pt
from celerite2.pymc import GaussianProcess, terms
np.random.seed(123)
period = np.random.uniform(3,10)
t = np.arange(-0.2, 0.2, 2/60/24)
# The light curve calculation requires an orbit
orbit = xo.orbits.KeplerianOrbit(period=period, t0=0, b=0, duration=0.15, ror=0.1)
# Compute a limb-darkened light curve using starry
u = [0.3, 0.2]
light_curve = np.sum(
xo.LimbDarkLightCurve(u[0], u[1])
.get_light_curve(orbit=orbit, r=0.1, t=t, texp=2/60/24)
.eval(),
axis=-1
)
# Create simulated data
yerr = 3e-4
y = light_curve.copy()  # copy so the in-place edits below don't modify light_curve
M = (t > -0.5*0.15) & (t < 0.5*0.15) # transit mask
y += yerr * np.random.randn(len(y)) # add noise
y += 0.01*t # add linear term
y += 1
# add some spot occultations
locs = [-0.005, 0.03]
widths = [0.008, 0.01]
amps = [0.002, 0.001]
for i in range(len(locs)):
    m = (t > (locs[i]-widths[i])) & (t < (locs[i]+widths[i]))
    y[m] += amps[i] * np.exp(-(t[m]-locs[i])**2/widths[i]**2)
with pm.Model() as model:
    mean = pm.Normal("mean", mu=1, sigma=0.002, initval=1)
    # The time of a reference transit for each planet
    t0 = pm.Normal("t0", mu=0, sigma=0.01, initval=0)
    u = xo.quad_limb_dark("u", initval=[0.3, 0.2])
    log_dur = pm.Normal("log_dur", mu=np.log(0.13), sigma=0.1, initval=np.log(0.13))
    dur = pm.Deterministic("dur", pt.exp(log_dur))
    log_ror = pm.Normal("logr", mu=np.log(0.1), sigma=0.1, initval=np.log(0.1))
    ror = pm.Deterministic("r", pt.exp(log_ror))
    b = xo.impact_parameter("b", ror=ror, initval=0.3)
    star = xo.LimbDarkLightCurve(u[0], u[1])
    # Set up a Keplerian orbit for the planets
    orbit = xo.orbits.KeplerianOrbit(period=period, t0=t0, b=b, duration=dur, ror=ror)
    # Compute the model light curve using starry
    def _mean_fn(orbit, mean, r, star, t):
        return pt.sum(star.get_light_curve(
            orbit=orbit, r=r, t=t, texp=2/60/24),
            axis=-1
        ) + mean
    mean_fn = partial(_mean_fn, orbit, mean, ror, star)
    pm.Deterministic("light_curves", mean_fn(t))
    # GP parameters for the linear trend and white noise
    log_sigma = pm.Normal("log_sigma", mu=np.log(0.5*yerr), sigma=0.1)
    sigma = pm.Deterministic("sigma", pt.exp(log_sigma))
    log_rho_gp = pm.Normal("log_rho_gp", mu=7, sigma=0.5, initval=7)
    rho_gp = pm.Deterministic("rho_gp", pt.exp(log_rho_gp))
    log_sigma_gp = pm.Normal("log_sigma_gp", mu=-4, sigma=0.5, initval=-4)
    sigma_gp = pm.Deterministic("sigma_gp", pt.exp(log_sigma_gp))
    kernel = terms.Matern32Term(rho=rho_gp, sigma=sigma_gp)
    gp = GaussianProcess(kernel, t=t, diag=yerr**2 + sigma**2,
                         mean=mean_fn, quiet=True)
    pm.Deterministic("gp_preds", gp.predict(y, include_mean=False))
    gp.marginal("obs", observed=y)
    ######################################################################
    # problematic part
    ######################################################################
    # GP parameters for spot occultations
    log_sigma_spot = pm.Normal("log_sigma_spot", mu=-10, sigma=5, initval=-10)
    sigma_spot = pm.Deterministic("sigma_spot", pt.exp(log_sigma_spot))
    log_rho_spot = pm.Normal("log_rho_spot", mu=np.log(0.02), sigma=0.5)
    rho_spot = pm.Deterministic("rho_spot", pt.exp(log_rho_spot))
    kernel2 = terms.Matern32Term(rho=rho_spot, sigma=sigma_spot)
    def _mean_fn_spot(gp, orbit, star, mean, y, r, t):
        gp_pred = gp.predict(y, t=t, include_mean=False).eval()
        lc_pred = (pt.sum(star.get_light_curve(
            orbit=orbit, r=r, t=t, texp=2/60/24),
            axis=-1
        ) + mean).eval()
        return pt.as_tensor_variable(lc_pred+gp_pred)
    spot_fn = partial(_mean_fn_spot, gp, orbit, star, mean, y, ror)
gp2 = GaussianProcess(kernel2, t=t[M], diag=yerr**2 + sigma**2,
mean=spot_fn, quiet=False)
pm.Deterministic("gp_preds_spot", gp2.predict(y[M], include_mean=False))
gp2.marginal("obs_spot", observed=y[M])
######################################################################
map_soln = pmx.optimize(start=model.initial_point())
# plot fit
spot_model = np.zeros_like(t)
spot_model[M] = map_soln["gp_preds_spot"]
full_mod = map_soln["light_curves"]+map_soln["gp_preds"]+spot_model
plt.figure()
plt.plot(t, y, ".k", ms=4, label="data")
plt.plot(t, full_mod, lw=1, label="full model")
plt.plot(t, map_soln["light_curves"], lw=1, ls='--', label="transit")
plt.plot(t, map_soln["gp_preds"]+map_soln["mean"], lw=1, ls=':', label="trend")
plt.plot(t, spot_model+map_soln["mean"], lw=1, ls='-.', label="spot")
plt.legend()
plt.xlim(t.min(), t.max())
plt.ylabel("relative flux")
plt.xlabel("time [days]")
plt.legend(fontsize=10)
_ = plt.title("map model")
plt.show()
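A possible alternative to nesting one GP's prediction inside another GP's mean is to treat the trend and the spot signal as two independent GPs whose kernels add, keep the transit as the mean function, and then recover each component's conditional mean afterwards. celerite2's predict has a kernel keyword that appears intended for separating components like this (check the docs for the exact behaviour). The dense NumPy sketch below only illustrates the underlying math under made-up hyperparameters: for K = K_trend + K_spot + diag(yerr^2), component k has conditional mean K_k K^-1 y. The helpers matern32 and component_means are hypothetical, and the O(N^3) algebra is for illustration, not a replacement for celerite2's solver.

```python
import numpy as np

def matern32(t1, t2, sigma, rho):
    """Matern-3/2 covariance: sigma^2 (1 + sqrt(3) r / rho) exp(-sqrt(3) r / rho)."""
    x = np.sqrt(3.0) * np.abs(t1[:, None] - t2[None, :]) / rho
    return sigma**2 * (1.0 + x) * np.exp(-x)

def component_means(t, y, yerr, kernels):
    """Conditional mean of each additive GP component given y (dense, O(N^3)).

    With K = sum_k K_k + diag(yerr^2), component k has mean K_k K^{-1} y.
    """
    K_parts = [k(t, t) for k in kernels]
    K = sum(K_parts) + np.diag(yerr**2)
    alpha = np.linalg.solve(K, y)
    return [Kp @ alpha for Kp in K_parts]

# Illustrative hyperparameters (assumptions): a slow trend and fast "spot" bumps
def trend(a, b):
    return matern32(a, b, sigma=2e-2, rho=1.0)

def spots(a, b):
    return matern32(a, b, sigma=1e-3, rho=0.02)

rng = np.random.default_rng(0)
t = np.sort(rng.uniform(-0.2, 0.2, 200))
yerr = np.full_like(t, 3e-4)
# Fake residuals (data minus a transit mean model would go here)
y = 0.01 * t + 1e-3 * np.exp(-((t - 0.03) / 0.01) ** 2) + yerr * rng.standard_normal(t.size)

mu_trend, mu_spots = component_means(t, y, yerr, [trend, spots])
```

Because the components share one joint likelihood, their conditional means add up to the prediction of the summed kernel, so nothing needs to be evaluated with .eval() inside the model.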
Hi @dfm,
Today I've built celerite2 from source following the recommendations in the install docs. I'm trying to do something simple, like this:
from celerite2.jax import terms
sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
sho.get_coefficients()
but I'm getting the following error:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In [1], line 4
1 from celerite2.jax import terms
3 sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
----> 4 sho.get_coefficients()
File ~/git/celerite2/python/celerite2/jax/terms.py:36, in Term.get_coefficients(self)
35 def get_coefficients(self):
---> 36 raise NotImplementedError("subclasses must implement this method")
NotImplementedError: subclasses must implement this method
At first I thought this could be an accident of the multiple SHOTerm implementations, for example, here:
celerite2/python/celerite2/jax/terms.py
Lines 473 to 478 in e75dd45
and here:
celerite2/python/celerite2/jax/terms.py
Line 481 in e75dd45
but commenting out the first one doesn't solve the problem.
Any ideas? Thanks!
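For context on what a get_coefficients implementation has to produce for an SHO term: a celerite kernel is a sum of real terms a exp(-c tau) and complex terms exp(-c tau) [a cos(d tau) + b sin(d tau)]. The NumPy sketch below writes out the standard SHO coefficients from the celerite paper, independently of celerite2.jax; sho_coefficients and celerite_kernel are hypothetical helpers, and the (ar, cr, ac, bc, cc, dc) ordering mirrors the six-array convention used elsewhere on this page rather than any guaranteed internal layout.

```python
import numpy as np

def sho_coefficients(S0, w0, Q):
    """(ar, cr, ac, bc, cc, dc) for an SHO kernel written in celerite form."""
    empty = np.empty(0)
    if Q < 0.5:
        # Overdamped: a sum of two real exponentials
        f = np.sqrt(1.0 - 4.0 * Q**2)
        ar = 0.5 * S0 * w0 * Q * np.array([1.0 + 1.0 / f, 1.0 - 1.0 / f])
        cr = 0.5 * w0 / Q * np.array([1.0 - f, 1.0 + f])
        return ar, cr, empty, empty, empty, empty
    if Q > 0.5:
        # Underdamped: one damped cosine/sine pair
        f = np.sqrt(4.0 * Q**2 - 1.0)
        a = S0 * w0 * Q
        return (empty, empty, np.array([a]), np.array([a / f]),
                np.array([0.5 * w0 / Q]), np.array([0.5 * w0 * f / Q]))
    raise ValueError("Q == 1/2 (critical damping) needs a separate limit")

def celerite_kernel(tau, ar, cr, ac, bc, cc, dc):
    """k(tau) = sum ar e^{-cr tau} + sum e^{-cc tau} (ac cos(dc tau) + bc sin(dc tau))."""
    tau = np.abs(np.atleast_1d(tau))[:, None]
    k = np.sum(ar * np.exp(-cr * tau), axis=1)
    k += np.sum(np.exp(-cc * tau) * (ac * np.cos(dc * tau) + bc * np.sin(dc * tau)), axis=1)
    return k

# Parameters from the report; k(0) should equal S0 * w0 * Q
print(celerite_kernel(0.0, *sho_coefficients(S0=1.0, w0=3000.0, Q=0.6)))
```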
Hi, I was trying to install celerite2 from source following the instructions in the documentation, but while running the tests I got an error:
python -m pytest -v python/test
=================================================================================== test session starts ===================================================================================
platform darwin -- Python 3.11.6, pytest-7.4.3, pluggy-1.3.0 -- /opt/anaconda3/envs/pymc/bin/python
cachedir: .pytest_cache
rootdir: /Users/ajaysharma/celerite2
configfile: pyproject.toml
plugins: cov-4.1.0
collected 0 items / 1 error
========================================================================================= ERRORS ==========================================================================================
______________________________________________________________________________ ERROR collecting test session ______________________________________________________________________________
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/runner.py:341: in from_call
result: Optional[TResult] = func()
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/runner.py:372: in <lambda>
call = CallInfo.from_call(lambda: list(collector.collect()), "collect")
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/main.py:749: in collect
for x in self._collectfile(path):
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/main.py:588: in _collectfile
ihook = self.gethookproxy(fspath)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/main.py:555: in gethookproxy
my_conftestmodules = pm._getconftestmodules(
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/config/__init__.py:609: in _getconftestmodules
mod = self._importconftest(conftestpath, importmode, rootpath)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/config/__init__.py:657: in _importconftest
self.consider_conftest(mod)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/config/__init__.py:739: in consider_conftest
self.register(conftestmodule, name=conftestmodule.__file__)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/config/__init__.py:491: in register
ret: Optional[str] = super().register(plugin, name)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/pluggy/_manager.py:164: in register
hook._maybe_apply_history(hookimpl)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/pluggy/_hooks.py:559: in _maybe_apply_history
res = self._hookexec(self.name, [method], kwargs, False)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/pluggy/_manager.py:115: in _hookexec
return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
python/test/pymc/conftest.py:13: in pytest_configure
pytensor.config.gcc.cxxflags = "-Wno-c++11-narrowing"
E AttributeError: 'PyTensorConfigParser' object has no attribute 'gcc'
================================================================================= short test summary info =================================================================================
ERROR - AttributeError: 'PyTensorConfigParser' object has no attribute 'gcc'
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
==================================================
Please help me resolve this error. Any help would be appreciated.
Sys info:
macOS 14.1.1 (M1 chip)
Python: latest version (3.11.6, per the pytest header above)
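One hedged way to make the conftest.py more tolerant of config layout changes: newer PyTensor releases appear to expose compiler flags as a flat gcc__cxxflags option instead of a nested gcc section (the naming pattern Theano-PyMC/Aesara moved to; please check your installed version). The helper set_cxxflags below is hypothetical; it sets whichever attribute exists and reports which one it used, and the test exercises it with stand-in objects rather than a real PyTensor config.

```python
from types import SimpleNamespace

def set_cxxflags(config, flags):
    """Set C++ compiler flags on a Theano/PyTensor-style config object.

    Tries the flat ``gcc__cxxflags`` name first, then the older nested
    ``gcc.cxxflags`` layout. Returns the name that was set, or None.
    """
    if hasattr(config, "gcc__cxxflags"):
        config.gcc__cxxflags = flags
        return "gcc__cxxflags"
    gcc = getattr(config, "gcc", None)
    if gcc is not None and hasattr(gcc, "cxxflags"):
        gcc.cxxflags = flags
        return "gcc.cxxflags"
    return None

# Possible use in python/test/pymc/conftest.py (assumption):
# def pytest_configure(config):
#     import pytensor
#     set_cxxflags(pytensor.config, "-Wno-c++11-narrowing")

flat = SimpleNamespace(gcc__cxxflags="")
nested = SimpleNamespace(gcc=SimpleNamespace(cxxflags=""))
```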
I've implemented a first pass at derivative models, but there are some open questions, so I'm going to remove it from #19 and just post it here for future work:
class _DerivativeHelperTerm(terms.Term):
    def __init__(self, term):
        self.term = term
    def get_coefficients(self):
        coeffs = self.term.get_coefficients()
        a, b, c, d = coeffs[2:]
        final_coeffs = [
            -coeffs[0] * coeffs[1],
            coeffs[1],
            b * d - a * c,
            -(a * d + b * c),
            c,
            d,
        ]
        return final_coeffs
class DerivativeLatentTerm(LatentTerm):
    def __init__(self, term, *, value_amplitude, gradient_amplitude):
        if term.__requires_general_addition__:
            raise TypeError(
                "You cannot perform operations on a term that requires general"
                " addition, it must be the outer term in the kernel"
            )
        val_amp = np.atleast_1d(value_amplitude)
        grad_amp = np.atleast_1d(gradient_amplitude)
        if val_amp.ndim != 1 or val_amp.shape != grad_amp.shape:
            raise ValueError("Dimension mismatch between amplitudes")
        amp = np.stack((val_amp, grad_amp), axis=-1)
        M = amp.shape[0]
        ar, cr, ac, bc, cc, dc = term.get_coefficients()
        Jr = len(cr)
        Jc = len(cc)
        J = Jr + 2 * Jc
        left = np.zeros((2, 4 * J))
        right = np.zeros((2, 4 * J))
        if Jr:
            left[0, :Jr] = 1.0
            left[0, 2 * Jr : 3 * Jr] = -1.0
            left[1, Jr : 2 * Jr] = 1.0
            left[1, 3 * Jr : 4 * Jr] = 1.0
            right[0, : 2 * Jr] = 1.0
            right[1, 2 * Jr : 4 * Jr] = 1.0
        if Jc:
            J0 = 4 * Jr
            for J0 in [4 * Jr, 4 * Jr + 4 * Jc]:
                left[0, J0 : J0 + Jc] = 1.0
                left[0, J0 + 2 * Jc : J0 + 3 * Jc] = -1.0
                left[1, J0 + Jc : J0 + 2 * Jc] = 1.0
                left[1, J0 + 3 * Jc : J0 + 4 * Jc] = 1.0
                right[0, J0 : J0 + 2 * Jc] = 1.0
                right[1, J0 + 2 * Jc : J0 + 4 * Jc] = 1.0
        self.left = np.dot(amp, left)[:, :, None]
        self.right = np.dot(amp, right)[:, :, None]
        # self.left = np.empty((M, 4 * J, 1))
        # self.left[:] = np.nan
        # self.left[:] = np.dot(amp, left)[:, :, None]
        # self.right = np.zeros((M, 4 * J, M))
        # right = np.dot(amp, right)
        # for m in range(M):
        #     self.right[m, :, m] = right[m]
        # self.left = np.reshape(self.left, (M, -1, 1))
        # self.right = np.reshape(self.right, (M, -1, 1))
        print(self.left)
        print(self.right)
        # print(self.right)
        # assert 0
        dterm = _DerivativeHelperTerm(term)
        d2term = terms.TermDiff(term)
        super().__init__(
            terms.TermSum(term, dterm, dterm, d2term),
            dimension=2,
            left_latent=self._left_latent,
            right_latent=self._right_latent,
        )
    def _left_latent(self, t, inds):
        return self.left[inds]
    def _right_latent(self, t, inds):
        return self.right[inds]
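The coefficients returned by _DerivativeHelperTerm follow from differentiating each celerite component with respect to tau: d/dtau [a e^{-c tau}] = -a c e^{-c tau} for real terms, and d/dtau e^{-c tau} [a cos(d tau) + b sin(d tau)] = e^{-c tau} [(b d - a c) cos(d tau) - (a d + b c) sin(d tau)] for complex terms. A quick finite-difference check of that identity in plain NumPy, independent of celerite2 (kernel and derivative_coefficients are hypothetical helper names, and the coefficient values are arbitrary):

```python
import numpy as np

def kernel(tau, ar, cr, ac, bc, cc, dc):
    """Evaluate a celerite kernel for tau >= 0 (real + complex components)."""
    tau = np.asarray(tau)[:, None]
    k = np.sum(ar * np.exp(-cr * tau), axis=1)
    k += np.sum(np.exp(-cc * tau) * (ac * np.cos(dc * tau) + bc * np.sin(dc * tau)), axis=1)
    return k

def derivative_coefficients(ar, cr, ac, bc, cc, dc):
    """Coefficients of dk/dtau, mirroring _DerivativeHelperTerm.get_coefficients."""
    return (-ar * cr, cr, bc * dc - ac * cc, -(ac * dc + bc * cc), cc, dc)

coeffs = (np.array([1.5]), np.array([0.7]),           # one real term
          np.array([0.8, 0.3]), np.array([0.2, -0.1]),  # two complex terms
          np.array([0.5, 1.2]), np.array([2.0, 4.5]))
tau = np.linspace(0.1, 3.0, 40)
h = 1e-6
numeric = (kernel(tau + h, *coeffs) - kernel(tau - h, *coeffs)) / (2 * h)  # central difference
analytic = kernel(tau, *derivative_coefficients(*coeffs))
```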
My apologies for this rather silly question. I'm just learning the basics at the moment. (The following example from celerite 1 already clarified things: https://celerite.readthedocs.io/en/stable/tutorials/normalization/)
The Fourier convention of the PSD is a bit confusing. I assume the factor
The point where it gets really confusing is the documentation of
It could also be helpful to mention the integrated autocorrelation time in the documentation of the SHOTerm, which is
Sorry for stumbling over these little things, and thanks for the nice implementation.
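For what it's worth, the normalization used in celerite (as far as I can tell from the celerite 1 normalization tutorial linked above) corresponds to the symmetric Fourier convention S(omega) = (2 pi)^(-1/2) integral k(tau) e^(-i omega tau) dtau, so that k(0) = (2 pi)^(-1/2) integral S(omega) domega over the whole real line. Under that assumption the SHO PSD is S(omega) = sqrt(2/pi) S0 w0^4 / [(omega^2 - w0^2)^2 + w0^2 omega^2 / Q^2] and the variance is k(0) = S0 w0 Q. The check below integrates the PSD numerically to confirm the two agree; sho_psd and variance_from_psd are hypothetical helpers, not the library's get_psd.

```python
import numpy as np

def sho_psd(omega, S0, w0, Q):
    """SHO power spectral density in the sqrt(2/pi) convention (assumed)."""
    return np.sqrt(2 / np.pi) * S0 * w0**4 / ((omega**2 - w0**2) ** 2 + w0**2 * omega**2 / Q**2)

def variance_from_psd(S0, w0, Q, wmax=2000.0, n=400001):
    """k(0) = (2 pi)^(-1/2) * integral of S(omega) over the real line (trapezoid rule)."""
    omega = np.linspace(-wmax, wmax, n)
    s = sho_psd(omega, S0, w0, Q)
    integral = np.sum(0.5 * (s[1:] + s[:-1]) * np.diff(omega))
    return integral / np.sqrt(2 * np.pi)

S0, w0, Q = 1.3, 3.0, 2.0
print(variance_from_psd(S0, w0, Q), S0 * w0 * Q)  # the two numbers should agree closely
```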
This should be possible using semi-separable matrices, but it'll take a little bit of reimplementation in the backend. Hopefully it's not too terrible!
Some background: I want to model a GP for variations in flux as a function of roll-angle. Most of our observations (think HST or CHEOPS) last for a few orbits. Each orbit shows an extremely self-similar (and often not particularly sinusoidal) variation due to e.g. Earth limb, lunar glint, temperature effects, etc. Hence, including a GP on roll-angle vs flux is perfect.
My method thus far has been to sort the data by increasing roll angle and then unsort it afterwards to get the variation as a function of time. However, roll angle is not a continuous observation, and for some targets we have enough data that some measurements are extremely close in roll angle to others (<1e-4 degrees), yet not particularly correlated (as they're ~days apart in time). In these cases, celerite2 simply breaks: it throws a LinAlgError. I'm not sure of the maths, but it's likely because of a large difference between two extremely close points, although it seems to break even when the differences between such points are consistent with the variance.
Here is a working (well, non-working) example:
import numpy as np
import pymc3 as pm
import celerite2.theano
from celerite2.theano import terms as theano_terms
# Ten fake "orbits" with varying roll angle from ~50 to ~250deg
t = np.hstack([np.linspace(50+5*np.random.random(), 250+5*np.random.random(), 500) for c in range(10)])
with pm.Model():
    rollangle_w0 = 0.1  # NB - at higher rollangle_w0, we get a different error - a Theano Assert error
    rollangle_S0 = 120.0
    kern = theano_terms.SHOTerm(S0=rollangle_S0, w0=rollangle_w0, Q=1/np.sqrt(2))  #, mean = phot_mean)
    gp = celerite2.theano.GaussianProcess(kern, np.sort(t), mean=0.0)  # Sorting to
    gp.compute(t=t, diag=np.tile(0.25, len(t))**2)
Relatedly, even when celerite does not break outright, it is consistently forced into extremely short-timescale variations despite a strong prior against such over-fitting...
I think this is all because of an assumption that each x measurement is effectively instantaneous; in our case we have x values (roll angles) that are actually the average across some dx range, which overlaps with the neighbouring values. I'm not sure it would be possible, but is there any way to incorporate a value (or array) of dx which would stop this error/overfitting?
Alternatively, is there some way to use a periodic kernel to enforce similar variation across all points? I've tried adapting the RotationKernel, which works well for low-frequency variation with wavelengths on the order of 100 degrees, but is typically unable to model higher-frequency variation with wavelengths on the order of tens of degrees. But maybe there is a combination of a periodic term and e.g. a spikier Matern32 kernel?
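On the sort/unsort step: celerite2 expects the input coordinate to be sorted, and the snippet above builds the GP with np.sort(t) but then calls compute() with the unsorted t, which may be worth ruling out first. A minimal NumPy sketch of sorting by roll angle with np.argsort and mapping results back to time order with the inverse permutation (variable names are illustrative, and a placeholder sine stands in for whatever GP prediction you run on the sorted data):

```python
import numpy as np

rng = np.random.default_rng(1)
# Ten fake "orbits" of roll angle, stored in time order as in the example above
roll = np.hstack([np.linspace(50 + 5 * rng.random(), 250 + 5 * rng.random(), 500)
                  for _ in range(10)])
flux = np.sin(np.radians(3 * roll)) + 0.01 * rng.standard_normal(roll.size)

order = np.argsort(roll, kind="stable")    # time order -> roll-angle order
inverse = np.empty_like(order)
inverse[order] = np.arange(order.size)     # roll-angle order -> time order

roll_sorted, flux_sorted = roll[order], flux[order]
print("smallest roll-angle gap:", np.diff(roll_sorted).min())

# Fit/predict the GP on (roll_sorted, flux_sorted); a placeholder model stands in here
model_sorted = np.sin(np.radians(3 * roll_sorted))
model_in_time_order = model_sorted[inverse]
```

Printing the smallest gap is a cheap way to see how close the near-duplicate roll angles really are before handing them to the solver.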
Hi,
I'm trying to install from source following the directions here. The reason I'm installing from source rather than the pre-compiled binaries is that I'm hoping to try out the Kronecker functionality. It works fine on my mac, but when I try to install on the cluster I get the error:
ERROR: Command errored out with exit status -11: /gscratch/home/tagordon/jwst/celerite2/env/bin/python -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/gscratch/home/tagordon/jwst/celerite2/setup.py'"'"'; file='"'"'/gscratch/home/tagordon/jwst/celerite2/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(file);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, file, '"'"'exec'"'"'))' develop --no-deps Check the logs for full command output.
This happens regardless of whether I'm installing from the kron branch or not. Any help is appreciated!
Thanks,
Tyler
Hi @dfm,
I'm interested in doing model comparison with LOO-CV and/or WAIC via arviz. My model (w/ JAX+numpyro) relies on celerite2 for computing the log likelihood. The likelihood returned by celerite2 is a single value, but the pointwise predictive accuracy methods expect pointwise logp's. I think that's why when comparing models, I get warnings that say things like:
The point-wise LOO is the same with the sum LOO, please double check the Observed RV in your
model to make sure it returns element-wise logp.
But I'm not sure that a logp can be defined element-wise for a GP. So: is there a way to get element-wise logp's from a GP?
Thanks as always!
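A GP likelihood doesn't factor into independent per-point terms, but two pointwise decompositions are commonly used. The first is the chain-rule split log p(y) = sum_i log p(y_i | y_1, ..., y_{i-1}), which depends on the ordering. The second is the closed-form leave-one-out predictive density (Rasmussen & Williams 2006, Sec. 5.4.2): with alpha = K^-1 y, the LOO mean is mu_i = y_i - alpha_i / [K^-1]_ii and the variance is 1 / [K^-1]_ii. The dense NumPy sketch below (O(N^3), zero-mean GP assumed, so subtract any mean model first; helper names are hypothetical) computes both. Whether ArviZ's LOO/WAIC estimators are statistically appropriate for either decomposition is a separate modelling question.

```python
import numpy as np

def gaussian_logpdf(x, mu, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mu) ** 2 / var)

def sequential_logp(K, y):
    """log p(y_i | y_1, ..., y_{i-1}); the terms sum to the joint log-likelihood."""
    L = np.linalg.cholesky(K)
    z = np.linalg.solve(L, y)          # whitened residuals
    d = np.diag(L)                     # conditional standard deviations
    return -0.5 * (np.log(2 * np.pi * d**2) + z**2)

def loo_logp(K, y):
    """Closed-form leave-one-out predictive log densities for a zero-mean GP."""
    Kinv = np.linalg.inv(K)
    var = 1.0 / np.diag(Kinv)
    mu = y - (Kinv @ y) * var
    return gaussian_logpdf(y, mu, var)

rng = np.random.default_rng(42)
t = np.sort(rng.uniform(0, 10, 40))
K = np.exp(-np.abs(t[:, None] - t[None, :]) / 2.0) + 0.1 * np.eye(t.size)
y = np.linalg.cholesky(K) @ rng.standard_normal(t.size)   # zero-mean GP draw
pointwise_seq = sequential_logp(K, y)
pointwise_loo = loo_logp(K, y)
```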
Please, can you add a C++ example to the documentation? The celerite C++ docs' example looks helpful, but I'm not sure if it is applicable here. Thank you!
Hi @dfm,
I've found a set of SHO kernel hyperparameters which imitate solar granulation by fitting two SHO kernels' PSDs to the measured PSD of the Sun. I'd like to then draw samples from the kernel with the best-fit PSD, like this:
import numpy as np
import astropy.units as u
import matplotlib.pyplot as plt
from celerite2 import terms, GaussianProcess
duration = 100 * u.d
cadence = 15 * u.s
params = np.array(
    [[2.98095799e+03, 1.88495559e+02, 1.00000000e-02],
     [4.97870684e-02, 5.02654825e+04, 1.00000000e-01]]
)
kernel = (
    # Granulation terms
    terms.SHOTerm(S0=params[0, 0], w0=params[0, 1], Q=params[0, 2]) +
    terms.SHOTerm(S0=params[1, 0], w0=params[1, 1], Q=params[1, 2])
)
times = np.arange(0, duration.to(u.s).value, cadence.to(u.s).value, dtype=np.float64) * u.s
x = times.value
gp = GaussianProcess(kernel, t=x)
# Get samples with the kernel's PSD
y = gp.sample()
def power_spectrum(fluxes, d=60):
    """
    Compute the power spectrum of ``fluxes`` in units of [ppm^2 / microHz].
    Parameters
    ----------
    fluxes : `~numpy.ndarray`
        Fluxes with zero mean.
    d : float
        Time between samples [s].
    Returns
    -------
    freq : `~numpy.ndarray`
        Frequencies
    power : `~numpy.ndarray`
        Power at each frequency in units of [ppm^2 / microHz]
    """
    fft = np.fft.rfft(1e6 * fluxes)
    power = (fft * np.conj(fft)).real * 1e-6 / len(fluxes)
    freq = np.fft.rfftfreq(len(fluxes), d)
    return freq, power
# Get the power spectrum of the samples
freq, power = power_spectrum(y-1, d=cadence.to(u.s).value)
freq *= 1e6
# plot the power spectrum of samples, compare with kernel PSD
plt.loglog(freq, power, marker=',', lw=0)
plt.loglog(freq, kernel.get_psd(2*np.pi*freq)/2/np.pi, marker=',', lw=0)
But as you can see in the plot above, it looks like the kernel is just generating white noise, even though the kernel PSD is well behaved. Could this be a precision issue? Maybe I need to set eps to something?
This behavior crept up when switching from celerite to celerite2, and I think it has to do with the scales of the hyperparameters since they were each defined in log-space before and they're not anymore in celerite2.
Thanks!
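One thing worth ruling out before blaming precision is a units mismatch. Both Q values in params are below 1/2, so each SHO term is overdamped, i.e. a sum of two decaying exponentials with no oscillation, and its longest correlation time is 1/c_slow with c_slow = w0/(2Q) (1 - sqrt(1 - 4Q^2)). Since x is in seconds, w0 is being interpreted in rad/s; if that timescale is much shorter than the 15 s cadence, the samples will look white regardless of precision. The w0 values also happen to equal 2 pi x 30 and 2 pi x 8000, which would make sense if they were fit on a frequency axis in microhertz (an inference, not something stated above). A quick check in plain NumPy:

```python
import numpy as np

cadence = 15.0  # seconds between samples
w0 = np.array([1.88495559e02, 5.02654825e04])  # from params above
Q = np.array([1.0e-02, 1.0e-01])

def slowest_timescale(w0, Q):
    """Longest correlation time of an overdamped (Q < 1/2) SHO term, in units of 1/w0."""
    c_slow = w0 / (2 * Q) * (1 - np.sqrt(1 - 4 * Q**2))
    return 1 / c_slow

timescale = slowest_timescale(w0, Q)          # treating w0 as rad/s
print(timescale, timescale < cadence)         # True: decorrelated between samples

# If w0 was really in rad per microhertz-scaled time (assumption), converting to
# rad/s multiplies it by 1e-6 (S0 would presumably need a matching rescaling):
timescale_rescaled = slowest_timescale(w0 * 1e-6, Q)
print(timescale_rescaled, timescale_rescaled > cadence)
```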
When we initialize the doubleSHO, a.k.a. RotationTerm in celerite2, we need (*, sigma, period, Q0, dQ, f). I understand that the idea is to pass five parameters that will be converted to the S1, w1, Q1, S2, w2, Q2 needed for each SHO kernel. So far, so good!
However, what I am trying to do is build this kernel in celerite. See here, under "double_SHOTerm()":
https://github.com/3fon3fonov/exostriker/blob/exostriker-ready/exostriker/lib/RV_mod/GP_kernels.py
My doubt is that, inside the opt.func/MCMC/nested sampling, one needs to continuously update the GP while it samples and call compute():
...
gps.set_parameter_vector([sampled GP params])
gps.compute()
...
However, set_parameter_vector() seems to expect six parameters, these being the parameters of the two SHO terms. Is that right?
So to update the GP model I need six parameters (which, BTW, must always be the individual 2xSHOTerm() parameters in log!), but I sample five: sigma, period, Q0, dQ, and f. If I understand correctly, the way it is done now, I must always initialize the GP and then call compute() inside the loop/sampling run.
I wrote a simple function:
def init_dSHOKernel(params):
    kernel = GP_kernels.double_SHOTerm(
        sigma_dSHO=params[0],
        period_dSHO=params[1],
        Q0_dSHO=params[2],
        dQ_dSHO=params[3],
        f_dSHO=params[4])
    gps = celerite.GP(kernel, mean=0.0)
    return gps
...
# and then during the sampling:
gps = init_dSHOKernel(param_vect)
gps.compute()
This seems to work, but it is quite a compromise, I think. Any comments on this will be highly appreciated. For example, how is this done in celerite2? I will eventually migrate to celerite2, but I would like to have the option to use the original version too, with more kernels.
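Since celerite 1's GP only knows about the six SHO parameters, one option is to keep sampling the five rotation parameters, convert them to the six log-parameters at each step, and call gp.set_parameter_vector(...) instead of rebuilding the GP. The sketch below reproduces the mapping used by celerite2's RotationTerm as I understand it (worth checking against celerite2/terms.py), and it assumes celerite 1's SHOTerm parameter order log_S0, log_Q, log_omega0 for each of the two terms with no other free parameters in the GP. rotation_to_sho and rotation_to_sho_log_params are hypothetical names.

```python
import numpy as np

def rotation_to_sho(sigma, period, Q0, dQ, f):
    """Map RotationTerm-style parameters to two (S0, w0, Q) triples."""
    amp = sigma**2 / (1 + f)
    # Primary mode at the rotation period
    Q1 = 0.5 + Q0 + dQ
    w1 = 4 * np.pi * Q1 / (period * np.sqrt(4 * Q1**2 - 1))
    S1 = amp / (w1 * Q1)
    # Secondary mode at half the period
    Q2 = 0.5 + Q0
    w2 = 8 * np.pi * Q2 / (period * np.sqrt(4 * Q2**2 - 1))
    S2 = f * amp / (w2 * Q2)
    return (S1, w1, Q1), (S2, w2, Q2)

def rotation_to_sho_log_params(sigma, period, Q0, dQ, f):
    """Six log-parameters, (log_S0, log_Q, log_omega0) per SHOTerm (assumed order)."""
    (S1, w1, Q1), (S2, w2, Q2) = rotation_to_sho(sigma, period, Q0, dQ, f)
    return np.log([S1, Q1, w1, S2, Q2, w2])

# Inside the sampler (celerite 1, gp built once from two SHOTerms; assumption):
# gp.set_parameter_vector(rotation_to_sho_log_params(*theta))
# gp.compute(t, yerr)
```

By construction the total variance S1 w1 Q1 + S2 w2 Q2 equals sigma^2, and the two terms oscillate at the rotation period and at half of it.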
I am trying to install celerite2 using pip (version 20.2.3) with Python 3.7 on an Ubuntu 18.04 machine; however, I get the following error when pip is trying to build wheels for celerite2:
ERROR: Failed building wheel for celerite2
Failed to build celerite2
ERROR: Could not build wheels for celerite2 which use PEP 517 and cannot be installed directly
Thanks,
Tom
I've derived the algorithms for matrix multiplication and solves, but I haven't been able to work out the factorization algorithm yet. There don't seem to be numerical issues for the ops that I've derived so far, but I haven't extensively tested it. This would be interesting because it would allow parallel implementation on a GPU.
Hi,
I am using emcee to sample a likelihood function to which I pass a celerite2.GaussianProcess object, and I store the chain in an h5 file. When I start the emcee sampling from scratch there is no issue, but if I restart the sampling from the saved h5 file, the gp.compute() call throws this error:
File ".../python3.6/site-packages/celerite2/core.py", line 313, in compute
    self._t, self._diag, c=self._c, a=self._a, U=self._U, V=self._V
File ".../python3.6/site-packages/celerite2/terms.py", line 157, in get_celerite_matrices
    c.resize(J, refcheck=False)
ValueError: cannot resize this array: it does not own its data
How can I fix it? Thanks
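For context, that ValueError comes from NumPy itself: ndarray.resize only works on arrays that own their memory, and views (slices, or arrays wrapping another object's buffer) do not. That is consistent with some input arriving as a view when resuming from the HDF5 backend, although which input triggers it inside celerite2 isn't clear from the traceback. The sketch below shows the NumPy behaviour and the usual workaround of passing an owned copy (e.g. np.array(x, dtype=float)); owned_copy is a hypothetical helper.

```python
import numpy as np

def owned_copy(x):
    """Return a float64 array that owns its data (safe to resize in place)."""
    return np.array(x, dtype=np.float64, copy=True)

base = np.arange(6.0)
view = base[2:4]            # a view: view.flags.owndata is False

try:
    view.resize(3, refcheck=False)
    resized_view = True
except ValueError as err:   # "cannot resize this array: it does not own its data"
    resized_view = False
    print(err)

fixed = owned_copy(view)
fixed.resize(3, refcheck=False)   # works: the copy owns its memory (new entry is 0)
```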
I'm interested in being able to model different subsets of the data with different kernels, i.e. K = K1 + K2, where K2 could be zero everywhere except in the blocks corresponding to the target data subset. As you mentioned, this isn't trivial with celerite, but it could work in principle by generalizing the kernel to k(tau = |ti - tj|) = a_1(ti) * a_2(tj) * exp(-c*tau) * cos(d*tau) + b_1(ti) * b_2(tj) * exp(-c*tau) * sin(d*tau).
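To make the idea concrete, here is a dense NumPy sketch of K = K1 + K2 where K2 only acts within a chosen subset: with an indicator m(t) for the subset, K2_ij = m_i m_j k2(|t_i - t_j|), which is the a_1(ti) a_2(tj) form above with a_1 = a_2 = m. The exponential kernels, their parameters, and the mask are illustrative assumptions, and this O(N^3) version only shows the covariance structure, not how celerite would factor it.

```python
import numpy as np

def exp_kernel(t, amp, c):
    """amp * exp(-c |t_i - t_j|), a single celerite-style real term."""
    tau = np.abs(t[:, None] - t[None, :])
    return amp * np.exp(-c * tau)

def masked_sum_kernel(t, in_subset, k1, k2):
    """K = K1 + diag(m) K2 diag(m), with m the 0/1 subset indicator."""
    m = in_subset.astype(float)
    return k1(t) + np.outer(m, m) * k2(t)

t = np.linspace(0, 10, 50)
in_subset = (t > 3) & (t < 6)
K = masked_sum_kernel(t, in_subset,
                      lambda x: exp_kernel(x, 1.0, 0.5),   # shared kernel K1
                      lambda x: exp_kernel(x, 0.3, 5.0))   # subset-only kernel K2
```

Because diag(m) K2 diag(m) is a congruence of a positive semi-definite matrix, the sum stays a valid covariance.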
Describe the bug
The exoplanet software, when installed with conda, automatically adds the package theano-pymc, which seems to conflict with my installed version of theano. With theano version 1.0.5 installed and no exoplanet, the command
from celerite2.theano import terms
executes normally. After a conda installation of exoplanet version 0.4.4, theano-pymc version 1.1.5 is automatically installed. Then, the command
from celerite2.theano import terms
ends in an AttributeError (I have attached a picture). This is all done in a freshly created conda environment with minimal packages installed. My advisor has also been able to recreate this problem on a completely separate machine.
To Reproduce
It is not necessary to create a fresh environment, but that has the best hopes for reproduction:
conda create -n myenv python
Then install any packages you need. I prefer using Jupyter Notebook, so I run
pip install notebook
Install celerite2 and theano:
conda install -c conda-forge theano
python -m pip install -U celerite2
Then running the following command in either a jupyter notebook or any short python script should terminate successfully:
from celerite2.theano import terms
Next, installing any version of exoplanet seems to override this. I did it with the latest version:
conda install -c conda-forge exoplanet
Now when I run the command
from celerite2.theano import terms
The code ends in an AttributeError, a picture of which is attached.
Expected behavior
I expect the command
from celerite2.theano import terms
To execute without error.
Your setup (please complete the following information):
Additional context
With this setup, I have installed and successfully run the celerite2 tutorials, provided they do not use the celerite2.theano module.
Additionally, I have run exact copies of the exoplanet tutorials. Those that do not use celerite2 run to completion with accurate results; those that use celerite2 fail.
Hi,
I'm following your tutorial, and in the pymc3 section the code fails with this error:
Exception: ('Compilation failed (return status=254): clang-10: error: unable to execute command: Segmentation fault: 11. clang-10: error: dsymutil command failed due to signal (use -v to see invocation). ', '[Elemwise{true_div,no_inplace}(<TensorType(float64, row)>, <TensorType(float64, matrix)>)]')
I haven't found how to fix it, and it seems to be related to Theano.
I'm using Python 3.7 and OS X Mojave 10.14.6.
Do you have any idea how to fix this?
Thanks.
Here's an example:
import numpy as np
import celerite2
def dot_tril(t1, t2, U, V, c, Y):
    N, J = U.shape
    M = V.shape[0]
    Nrhs = Y.shape[1]
    assert V.shape == (M, J)
    assert c.shape == (J,)
    assert Y.shape == (M, Nrhs)
    Z = np.zeros((N, Nrhs))
    Fn = np.zeros((J, Nrhs))
    n = 0
    m = 0
    while n < N:
        if m < M and t2[m] <= t1[n]:
            if m > 0:
                Fn = np.exp(-c * (t2[m] - t2[m - 1]))[:, None] * Fn
            Fn += np.outer(V[m], Y[m])
            m += 1
        else:
            if m > 0:
                Z[n] = np.dot(U[n], np.exp(-c * (t1[n] - t2[m - 1]))[:, None] * Fn)
            n += 1
    return Z
def dot_triu(t1, t2, U, V, c, Y):
    N, J = U.shape
    M = V.shape[0]
    Nrhs = Y.shape[1]
    assert V.shape == (M, J)
    assert c.shape == (J,)
    assert Y.shape == (M, Nrhs)
    Z = np.zeros((N, Nrhs))
    Fn = np.zeros((J, Nrhs))
    n = N - 1
    m = M - 1
    while n >= 0:
        if m >= 0 and t2[m] > t1[n]:
            if m < M - 1:
                Fn = np.exp(-c * (t2[m + 1] - t2[m]))[:, None] * Fn
            Fn += np.outer(V[m], Y[m])
            m -= 1
        else:
            if m < M - 1:
                Z[n] = np.dot(U[n], np.exp(-c * (t2[m + 1] - t1[n]))[:, None] * Fn)
            n -= 1
    return Z
np.random.seed(42)
t = np.sort(np.random.uniform(0, 3, 1600))
x = np.sort(np.random.uniform(-1, 5, 3049))
term = celerite2.terms.....
a1, U1, V1, P1 = term.get_celerite_matrices(t, np.zeros_like(t))
a2, U2, V2, P2 = term.get_celerite_matrices(x, np.zeros_like(x))
_, cr, _, _, cc, _ = term.get_coefficients()
c = np.concatenate((cr, cc, cc))
y = np.sin(x)[:, None]
inds = np.searchsorted(t, x)
lower = np.arange(len(t))[:, None] >= inds[None, :]
A = np.dot(np.exp(-c * t[:, None]) * U1, (np.exp(c * x[:, None]) * V2).T)
L = np.zeros_like(A)
L[lower] = A[lower]
assert np.allclose(dot_tril(t, x, U1, V2, c, y), np.dot(L, y))
upper = np.arange(len(t))[:, None] < inds[None, :]
A = np.dot(np.exp(c * t[:, None]) * V1, (np.exp(-c * x[:, None]) * U2).T)
U = np.zeros_like(A)
U[upper] = A[upper]
assert np.allclose(dot_triu(t, x, V1, U2, c, y), np.dot(U, y))
Hi there,
I noticed the online documentation has a warning about missing font families:
findfont: Generic family 'sans-serif' not found because none of the following families were found: Liberation Sans
I see this on Safari and Chrome, so it's not a browser thing, it appears to be embedded in the rendered docs?
The warning repeats 127 times, filling the screen and requiring long scrolls to navigate to the next series of code cells after a cell involving a plot.
I'm not sure of the cause. One could imagine the font server coincidentally had downtime the last time the docs were rendered, in which case simply triggering a re-render on RTD could fix the issue. It's probably more complicated than that.
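If the root cause can't be fixed on the docs build machine (for example by installing the missing font), one hedged workaround is to silence that particular message, which matplotlib emits through its matplotlib.font_manager logger, at the top of each notebook or in the docs configuration. This only hides the warning; it doesn't change which font ends up being used.

```python
import logging

# matplotlib reports "findfont: ..." problems through this logger; raising its
# level to ERROR hides those warnings without touching other loggers.
logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)
```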
Here is a minimal code block to reproduce the error
from celerite2.terms import Term
class CosineTerm(Term):
    def __init__(self, omega_j, sigma_j):
        self.omega_j = omega_j
        self.sigma_j = sigma_j
    def get_coefficients(self):
        return 0, 0, self.sigma_j, 0, 0, self.omega_j
kernel = CosineTerm(1, 0.1) + CosineTerm(2, 0.2)
kernel.get_coefficients()
The following error is thrown at line 161 of terms.py
ValueError: zero-dimensional arrays cannot be concatenated
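The error message suggests that the coefficient tuples of a sum are combined with np.concatenate, which rejects plain Python scalars. A hedged sketch of returning 1-D arrays instead: with no real component, ar and cr are empty, and the cosine is a single complex term with c = 0. CosineCoefficients is a standalone, hypothetical stand-in (it does not subclass celerite2.terms.Term) so the concatenation can be checked in isolation; in celerite2 you would keep the Term base class.

```python
import numpy as np

class CosineCoefficients:
    """Hypothetical stand-in: k(tau) = sigma_j * cos(omega_j * tau) as one complex term."""

    def __init__(self, omega_j, sigma_j):
        self.omega_j = omega_j
        self.sigma_j = sigma_j

    def get_coefficients(self):
        empty = np.empty(0)
        return (
            empty,                                # ar: no real terms
            empty,                                # cr
            np.atleast_1d(float(self.sigma_j)),   # ac: cosine amplitude
            np.zeros(1),                          # bc: no sine component
            np.zeros(1),                          # cc: no damping
            np.atleast_1d(float(self.omega_j)),   # dc: angular frequency
        )

def sum_coefficients(*terms):
    """Concatenate coefficient arrays component-wise, as a sum of terms would."""
    parts = [term.get_coefficients() for term in terms]
    return tuple(np.concatenate(group) for group in zip(*parts))

coeffs = sum_coefficients(CosineCoefficients(1, 0.1), CosineCoefficients(2, 0.2))

# Plain scalars reproduce the reported error:
try:
    np.concatenate([0, 0])
    scalars_fail = False
except ValueError:   # "zero-dimensional arrays cannot be concatenated"
    scalars_fail = True
```

Note that an undamped cosine is only positive semi-definite (it has rank 2), so in practice some white noise on the diagonal is likely needed.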
This will require updating the backend to iterate over the batch dimension, but that shouldn't be too terribly hard. Then, we'd need to add a simple batching function.
One question is how to interface batching with the terms interface.
https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#Batching
This is likely a lame question, but I will shoot.
Which celerite/celerite2 kernel could be used as a Damped Random Walk model? This is a model used for analyzing quasar light curves. My understanding is that it is a GP model, but I can't find such a kernel, and at this point I do not feel confident enough to build it in celerite/celerite2. So what are my alternatives?
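For reference, the damped random walk (an Ornstein-Uhlenbeck process) has the exponential covariance k(tau) = sigma^2 exp(-|tau| / tau_DRW), which is exactly one celerite real term a exp(-c |tau|) with a = sigma^2 and c = 1/tau_DRW; in celerite2 that should correspond to terms.RealTerm(a=sigma**2, c=1/tau_drw) (check the docs for your version). The NumPy sketch below makes the mapping explicit and checks it against the covariance of the exactly discretized DRW recursion x_{i+1} = phi x_i + noise on a regular grid; the helper names are hypothetical.

```python
import numpy as np

def drw_to_realterm(sigma, tau_drw):
    """Damped random walk (sigma, tau_DRW) -> celerite real-term (a, c)."""
    return sigma**2, 1.0 / tau_drw

def realterm_cov(t, a, c):
    """Covariance a * exp(-c |t_i - t_j|)."""
    return a * np.exp(-c * np.abs(t[:, None] - t[None, :]))

def ar1_cov(n, dt, sigma, tau_drw):
    """Covariance of the exactly discretized OU/DRW process on a regular grid."""
    phi = np.exp(-dt / tau_drw)
    C = np.empty((n, n))
    C[0, 0] = sigma**2
    for i in range(1, n):
        C[i, :i] = phi * C[i - 1, :i]   # Cov(x_i, x_j) = phi * Cov(x_{i-1}, x_j), j < i
        C[:i, i] = C[i, :i]
        C[i, i] = phi**2 * C[i - 1, i - 1] + sigma**2 * (1 - phi**2)
    return C

a, c = drw_to_realterm(sigma=0.3, tau_drw=50.0)
t = 2.0 * np.arange(40)   # regular sampling every 2 time units
```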
It should be possible to write the JVP ops using existing celerite primitives. This would allow support for higher order differentiation and perhaps it won't cause a significant computational overhead.
For example, the matmul_lower JVP can be implemented as follows:
from jax import lax
from jax.interpreters import ad
def matmul_lower_jvp(arg_values, arg_tangents):
    def make_zero(x, t):
        return lax.zeros_like_array(x) if type(t) is ad.Zero else t
    t, c, U, V, Y = arg_values
    tp, cp, Up, Vp, Yp = (
        make_zero(x, t) for x, t in zip(arg_values, arg_tangents)
    )
    Ut = -(c[None, :] * tp[:, None] + cp[None, :] * t[:, None]) * U + Up
    Vt = (c[None, :] * tp[:, None] + cp[None, :] * t[:, None]) * V + Vp
    Zp = matmul_lower(t, c, U, V, Yp)
    Zp += matmul_lower(t, c, Ut, V, Y)
    Zp += matmul_lower(t, c, U, Vt, Y)
    return matmul_lower_p.bind(t, c, U, V, Y), (Zp, None)
But I haven't figured out the correct transpose yet.
Hi,
Is there an equivalent of celerite.terms.JitterTerm in celerite2?
Thanks
Due to a bug in PyMC (pymc-devs/pymc#6981), setting the option mentioned in the title caused an error in test_marginal for PyMC after updating to v5 (#91).
A PR fixing this (pymc-devs/pymc#6982) was merged into PyMC, but not yet released. The code below can be uncommented when the fix is released.
celerite2/python/test/pymc/conftest.py
Lines 10 to 11 in 7feb2e3
@tagordon, @ericagol: I have a proposal.
I expect that a common use case would be something like Rubin where you won't ever have multiple bands observed simultaneously, therefore there's a big overhead introduced by building the full matrices and then masking. So:
an NxJ matrix Alpha, which you would multiply into U and V, and then square, sum along the 2nd axis, and multiply into a. I think that this would be equivalent to the low rank Kronecker model with missing data.
celerite2 version: 0.1.1.dev4+gcee9a0f
theano: 1.0.11
python: 3.6.7
I keep running into a LinAlgError that I can't figure out. I was expecting to be able to sample a Gaussian process based on the exoplanet example (https://gallery.exoplanet.codes/en/latest/tutorials/quick-tess/), which I modified a little bit, but I couldn't get it to run with either a Matern or an SHOTerm kernel.
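One hedged guess at what may be happening: with Lognormal(mu=0, sd=10) on rho_gp, the sampler can wander to extremely large length scales (or tiny sigma_lc), where the Matern-3/2 covariance over this ~0.11-day baseline becomes nearly rank one and the factorization can fail numerically. The dense NumPy sketch below (not celerite2's solver; matern32_cov is a hypothetical helper) compares condition numbers for a modest and an extreme length scale on 17 evenly spaced times at roughly the same cadence as x in the example below.

```python
import numpy as np

def matern32_cov(t, sigma, rho, jitter):
    """Dense Matern-3/2 covariance with white noise `jitter` added on the diagonal."""
    x = np.sqrt(3.0) * np.abs(t[:, None] - t[None, :]) / rho
    return sigma**2 * (1.0 + x) * np.exp(-x) + jitter**2 * np.eye(t.size)

t = 0.00697191 * (np.arange(17) - 8)   # roughly the cadence of x (days), centred
cond_modest = np.linalg.cond(matern32_cov(t, sigma=1.0, rho=0.01, jitter=1e-4))
cond_extreme = np.linalg.cond(matern32_cov(t, sigma=1.0, rho=1e6, jitter=1e-4))
print(f"rho=0.01: cond={cond_modest:.3g}; rho=1e6: cond={cond_extreme:.3g}")
```

If that is the cause, tighter priors on log(rho_gp) and the noise terms would keep the sampler out of that regime, though that's a judgment call rather than a guaranteed fix.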
Here's the most minimal working example I could make.
import pymc3 as pm
import numpy as np
from celerite2.theano import terms, GaussianProcess
x_JD = np.array([2458899.75295227, 2458899.75992418, 2458899.76689609,
                 2458899.773868  , 2458899.78083991, 2458899.78781182,
                 2458899.79478373, 2458899.80175564, 2458899.80872755,
                 2458899.81569946, 2458899.82267136, 2458899.82964327,
                 2458899.83661518, 2458899.84358709, 2458899.850559  ,
                 2458899.85753091, 2458899.86450282])
x = x_JD - np.median(x_JD)
y = np.array([ 999.91776623,  999.70908473,  999.41758557, 1000.42187202,
              1000.33801407,  999.57517094,  999.72223547, 1000.19105569,
              1000.27274928, 1000.78978838, 1000.08780841, 1000.29321361,
              1000.17255564, 1000.40390643,  999.97807388, 1000.60678241,
               999.35199641])
with pm.Model() as model:
    resid = y
    sigma_lc = pm.Lognormal("sigma_lc", mu=-1, sigma=3)
    rho_gp = pm.Lognormal("rho_gp", mu=0, sd=10)
    sigma_gp = pm.Lognormal("sigma_gp", mu=-1, sigma=3)
    kernel = terms.Matern32Term(sigma=sigma_gp, rho=rho_gp)
    gp = GaussianProcess(kernel, t=x, yerr=sigma_lc)
    gp.marginal("gp", observed=resid)
    pm.Deterministic("gp_pred", gp.predict(resid))
    trace = pm.sample(tune=3000,
                      draws=3000,
                      start=model.test_point,
                      cores=2,
                      chains=2,
                      init="adapt_full",
                      target_accept=0.9)
/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/sampling.py:468: FutureWarning: In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
FutureWarning,
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_full...
/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/quadpotential.py:514: UserWarning: QuadPotentialFullAdapt is an experimental feature
warnings.warn("QuadPotentialFullAdapt is an experimental feature")
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [sigma_gp, rho_gp, sigma_lc]
1.22% [146/12000 00:03<04:47 Sampling 2 chains, 0 divergences]
---------------------------------------------------------------------------
RemoteTraceback Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/compile/function/types.py", line 974, in __call__
if output_subset is None
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/gof/op.py", line 913, in rval
r = p(n, [x[0] for x in i], o)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 120, in perform
func(*args)
celerite2.backprop.LinAlgError: failed to factorize or solve matrix
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 138, in run
self._start_loop()
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 192, in _start_loop
point, stats = self._compute_point()
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 217, in _compute_point
point, stats = self._step_method.step(self._point)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py", line 273, in step
apoint, stats = self.astep(array)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 169, in astep
hmc_step = self._hamiltonian_step(start, p0, step_size)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 185, in _hamiltonian_step
divergence_info, turning = tree.extend(direction)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 276, in extend
self.right, self.depth, floatX(np.asarray(self.step_size))
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 369, in _build_subtree
tree2, diverging, turning = self._build_subtree(tree1.right, depth - 1, epsilon)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 369, in _build_subtree
tree2, diverging, turning = self._build_subtree(tree1.right, depth - 1, epsilon)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 365, in _build_subtree
tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 363, in _build_subtree
return self._single_step(left, epsilon)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 325, in _single_step
right = self.integrator.step(epsilon, left)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/integration.py", line 69, in step
return self._step(epsilon, state)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/integration.py", line 102, in _step
logp = self._logp_dlogp_func(q_new, q_new_grad)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/model.py", line 733, in __call__
output = self._theano_function(array)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/compile/function/types.py", line 989, in __call__
storage_map=getattr(self.fn, "storage_map", None),
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/gof/link.py", line 343, in raise_with_op
raise exc_value.with_traceback(exc_trace)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/compile/function/types.py", line 974, in __call__
if output_subset is None
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/gof/op.py", line 913, in rval
r = p(n, [x[0] for x in i], o)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 120, in perform
func(*args)
celerite2.backprop.LinAlgError: failed to factorize or solve matrix
Apply node that caused the error: _CeleriteOp{name='factor_fwd', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]})
Toposort index: 34
Inputs types: [TensorType(float64, vector), TensorType(float64, vector), TensorType(float64, vector), TensorType(float64, matrix), TensorType(float64, matrix)]
Inputs shapes: [(17,), (2,), (17,), (17, 2), (17, 2)]
Inputs strides: [(8,), (8,), (8,), (16, 8), (16, 8)]
Inputs values: ['not shown', array([5.8092922e-06, 5.8092922e-06]), 'not shown', 'not shown', 'not shown']
Outputs clients: [[_CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, _CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3), Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}(TensorConstant{(1,) of -0.5}, _CeleriteOp{name='factor_fwd', quiet=False}.0, TensorConstant{(1,) of 0.5}, Elemwise{Sqr}[(0, 0)].0), Elemwise{Composite{((-i0) / i1)}}(InplaceDimShuffle{0}.0, _CeleriteOp{name='factor_fwd', quiet=False}.0), Elemwise{TrueDiv}[(0, 0)](Elemwise{Sqr}[(0, 0)].0, _CeleriteOp{name='factor_fwd', quiet=False}.0), Elemwise{Log}[(0, 0)](_CeleriteOp{name='factor_fwd', quiet=False}.0)], [_CeleriteOp{name='solve_lower_fwd', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Join.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, TensorConstant{[[ 999.917..35199641]]}), _CeleriteOp{name='solve_lower_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Join.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, TensorConstant{[[ 999.917..35199641]]}, _CeleriteOp{name='solve_lower_fwd', quiet=False}.0, _CeleriteOp{name='solve_lower_fwd', quiet=False}.1, IncSubtensor{InplaceInc;::, int64}.0), _CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, _CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3)], [_CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, 
_CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3)]]
Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-1-85507f45a7d6>", line 28, in <module>
gp = GaussianProcess(kernel, t=x, yerr=sigma_lc)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/core.py", line 210, in __init__
self.compute(t, **kwargs)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/core.py", line 316, in compute
self._do_compute(quiet)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/celerite2.py", line 84, in _do_compute
self._t, self._c, self._a, self._U, self._V
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/gof/op.py", line 642, in __call__
node = self.make_node(*inputs, **kwargs)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 85, in make_node
for spec in self.spec["outputs"] + self.spec["extra_outputs"]
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 85, in <listcomp>
for spec in self.spec["outputs"] + self.spec["extra_outputs"]
HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.
"""
The above exception was the direct cause of the following exception:
LinAlgError Traceback (most recent call last)
LinAlgError: failed to factorize or solve matrix
Apply node that caused the error: _CeleriteOp{name='factor_fwd', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]})
Toposort index: 34
Inputs types: [TensorType(float64, vector), TensorType(float64, vector), TensorType(float64, vector), TensorType(float64, matrix), TensorType(float64, matrix)]
Inputs shapes: [(17,), (2,), (17,), (17, 2), (17, 2)]
Inputs strides: [(8,), (8,), (8,), (16, 8), (16, 8)]
Inputs values: ['not shown', array([5.8092922e-06, 5.8092922e-06]), 'not shown', 'not shown', 'not shown']
Outputs clients: [[_CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, _CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3), Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}(TensorConstant{(1,) of -0.5}, _CeleriteOp{name='factor_fwd', quiet=False}.0, TensorConstant{(1,) of 0.5}, Elemwise{Sqr}[(0, 0)].0), Elemwise{Composite{((-i0) / i1)}}(InplaceDimShuffle{0}.0, _CeleriteOp{name='factor_fwd', quiet=False}.0), Elemwise{TrueDiv}[(0, 0)](Elemwise{Sqr}[(0, 0)].0, _CeleriteOp{name='factor_fwd', quiet=False}.0), Elemwise{Log}[(0, 0)](_CeleriteOp{name='factor_fwd', quiet=False}.0)], [_CeleriteOp{name='solve_lower_fwd', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Join.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, TensorConstant{[[ 999.917..35199641]]}), _CeleriteOp{name='solve_lower_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Join.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, TensorConstant{[[ 999.917..35199641]]}, _CeleriteOp{name='solve_lower_fwd', quiet=False}.0, _CeleriteOp{name='solve_lower_fwd', quiet=False}.1, IncSubtensor{InplaceInc;::, int64}.0), _CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, _CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3)], [_CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, 
_CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3)]]
Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-1-85507f45a7d6>", line 28, in <module>
gp = GaussianProcess(kernel, t=x, yerr=sigma_lc)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/core.py", line 210, in __init__
self.compute(t, **kwargs)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/core.py", line 316, in compute
self._do_compute(quiet)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/celerite2.py", line 84, in _do_compute
self._t, self._c, self._a, self._U, self._V
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/gof/op.py", line 642, in __call__
node = self.make_node(*inputs, **kwargs)
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 85, in make_node
for spec in self.spec["outputs"] + self.spec["extra_outputs"]
File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 85, in <listcomp>
for spec in self.spec["outputs"] + self.spec["extra_outputs"]
HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<ipython-input-1-85507f45a7d6> in <module>
36 chains=2,
37 init="adapt_full",
---> 38 target_accept=0.9)
~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
555 _print_step_hierarchy(step)
556 try:
--> 557 trace = _mp_sample(**sample_args, **parallel_args)
558 except pickle.PickleError:
559 _log.warning("Could not pickle model, sampling singlethreaded.")
~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
1474 try:
1475 with sampler:
-> 1476 for draw in sampler:
1477 trace = traces[draw.chain - chain]
1478 if trace.supports_sampler_stats and draw.stats is not None:
~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/parallel_sampling.py in __iter__(self)
478
479 while self._active:
--> 480 draw = ProcessAdapter.recv_draw(self._active)
481 proc, is_last, draw, tuning, stats, warns = draw
482 self._total_draws += 1
~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
358 else:
359 error = RuntimeError("Chain %s failed." % proc.chain)
--> 360 raise error from old_error
361 elif msg[0] == "writing_done":
362 proc._readable = True
RuntimeError: Chain 0 failed.
Perhaps it should be implemented within this repo?
I started implementing a basic NumPy version on the kron branch, based on Tyler's Theano implementation.
ping @tagordon
I don't really use these libraries (yet!), so I'm not sure exactly what the use cases would look like. Any help would be appreciated.
nn.Modules work, so it's probably all a mess.

There is an inconsistency between the celerite2 Getting Started tutorial and the Exoplanet tutorials in how GaussianProcess and compute are used. Specifically, the celerite2 tutorial calls compute explicitly, whereas the Exoplanet tutorials that use celerite2 call compute implicitly by passing the t keyword argument.
I don't think this is a real problem, but it would help if the main GaussianProcess docstring stated that the compute method runs automatically when t is supplied.
In addition, the theano.GaussianProcess class has no docstring. It would benefit from the same docstring as the main GaussianProcess class, including the note about compute.
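The suggested docstring note can be sketched on a hypothetical stub class, GaussianProcessSketch, which is not celerite2's class. It only mirrors the behaviour visible in the traceback above, where core.py's __init__ runs self.compute(t, **kwargs). The constructor and compute signatures here are assumptions for illustration.

```python
class GaussianProcessSketch:
    """Hypothetical stub showing the proposed GaussianProcess docstring.

    Args:
        kernel: The celerite term describing the covariance.
        t (array, optional): Input coordinates. If ``t`` is given, ``compute``
            is called automatically from ``__init__`` with ``t`` and any extra
            keyword arguments (for example ``yerr``), so an explicit call to
            ``compute`` is not needed afterwards.
        mean (float, optional): Mean of the process. Defaults to 0.0.
        **kwargs: Forwarded to ``compute`` when ``t`` is provided.
    """

    def __init__(self, kernel, t=None, *, mean=0.0, **kwargs):
        self.kernel = kernel
        self.mean = mean
        self.computed = False
        if t is not None:
            # Same pattern as the traceback: __init__ delegates to compute.
            self.compute(t, **kwargs)

    def compute(self, t, *, yerr=None):
        """Record the inputs (the real method factorizes the covariance)."""
        self.t = list(t)
        self.yerr = yerr
        self.computed = True


# Style used in the celerite2 Getting Started tutorial: explicit compute call.
explicit = GaussianProcessSketch("kernel")
explicit.compute([0.0, 1.0, 2.0], yerr=0.1)

# Style used in the Exoplanet tutorials: pass t (and yerr) to the constructor.
implicit = GaussianProcessSketch("kernel", t=[0.0, 1.0, 2.0], yerr=0.1)
```

Both objects end up in the same state, which is the point the docstring note would make explicit.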
Hi @dfm,
Quick question: it seems that the example function numpyro_model in this tutorial can't be JIT-compiled because of the way celerite2 is written. Is this intentional, or am I misunderstanding something about NumPyro models? Should numpyro_model be jit-able? Thanks!
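The traceback for this question ends in JAX's UnexpectedTracerError, whose message says a tracer "escaped through global state from a previously traced function". A minimal sketch of the general JAX rule behind that message follows; it assumes nothing about celerite2's internals, and OffsetModel is a hypothetical stand-in for an object that caches a value on itself, not a celerite2 or NumPyro class. Whether this is exactly what happens inside celerite2.jax is not established here.

```python
import jax
import jax.numpy as jnp


class OffsetModel:
    """Hypothetical object that caches a value on itself, loosely like a GP
    object caching its mean after compute(). Not a celerite2 class."""

    def __init__(self, offset):
        self.offset = offset  # under jit/grad this is a tracer

    def loss(self, y):
        return jnp.sum((y - self.offset) ** 2)


# Leaky pattern (not run here): an object built inside one traced function,
# stored globally, and reused in another traced function carries a stale
# tracer; JAX then raises an UnexpectedTracerError like the one in this report.
#
#     MODEL = None
#
#     @jax.jit
#     def setup(mu):
#         global MODEL
#         MODEL = OffsetModel(mu)
#         return mu
#
#     setup(1.0)
#     jax.jit(lambda y: MODEL.loss(y))(y)  # uses the escaped tracer
#
# Safe pattern: construct the object inside the function being transformed,
# so every cached value belongs to the current trace.
def loss(mu, y):
    model = OffsetModel(mu)
    return model.loss(y)


y = jnp.array([1.0, 2.0, 3.0])
value, grad = jax.jit(jax.value_and_grad(loss))(1.0, y)
print(value, grad)
```

Here the loss is (0² + 1² + 2²) and its derivative with respect to mu is -2 times the sum of the residuals, so jit and grad compose cleanly because nothing traced outlives its trace.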
Very long traceback below:
---------------------------------------------------------------------------
FilteredStackTrace Traceback (most recent call last)
<ipython-input-5-8f42ec9b0711> in <module>
69 rng_key = random.PRNGKey(34923)
---> 70 mcmc.run(rng_key, t, yerr, y=y)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
501 if self.chain_method == 'sequential':
--> 502 states, last_state = _laxmap(partial_map_fn, map_args)
503 elif self.chain_method == 'parallel':
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
158 x = jit(_get_value_from_index)(xs, i)
--> 159 ys.append(f(x))
160
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
342 collection_size = collection_size if collection_size is None else collection_size // self.thinning
--> 343 collect_vals = fori_collect(lower_idx,
344 upper_idx,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
310 for i in t:
--> 311 vals = jit(_body_fn)(i, vals)
312 t.set_description(progbar_desc(i), refresh=False)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in _body_fn(i, vals)
283 val, collection, start_idx, thinning = vals
--> 284 val = body_fun(val)
285 idx = (i - start_idx) // thinning
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _sample_fn_nojit_args(state, sampler, args, kwargs)
170 # state is a tuple of size 1 - containing HMCState
--> 171 return sampler.sample(state[0], args, kwargs),
172
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in sample(self, state, model_args, model_kwargs)
529 """
--> 530 return self._sample_fn(state, model_args, model_kwargs)
531
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in sample_kernel(hmc_state, model_args, model_kwargs)
322 vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
--> 323 vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size,
324 hmc_state.adapt_state.inverse_mass_matrix,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key, trajectory_length)
291
--> 292 binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
293 inverse_mass_matrix, step_size, rng_key,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy, max_tree_depth)
748 state = (tree, rng_key)
--> 749 tree, _ = while_loop(_cond_fn, _body_fn, state)
750 return tree
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
122 else:
--> 123 return lax.while_loop(cond_fun, body_fun, init_val)
124
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
742 going_right = random.bernoulli(direction_key)
--> 743 tree = _double_tree(tree, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
744 going_right, doubling_key, energy_current, max_delta_energy,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
600
--> 601 new_tree = _iterative_build_subtree(current_tree, vv_update, kinetic_fn,
602 inverse_mass_matrix, step_size,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
687
--> 688 tree, turning, _, _, _ = while_loop(
689 _cond_fn,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
122 else:
--> 123 return lax.while_loop(cond_fun, body_fun, init_val)
124
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
659 z, r, z_grad = _get_leaf(current_tree, going_right)
--> 660 new_leaf = _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size,
661 going_right, energy_current, max_delta_energy)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right, energy_current, max_delta_energy)
568 step_size = jnp.where(going_right, step_size, -step_size)
--> 569 z_new, r_new, potential_energy_new, z_new_grad = vv_update(
570 step_size,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in update_fn(step_size, inverse_mass_matrix, state)
229 z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad) # z(n+1)
--> 230 potential_energy, z_grad = _value_and_grad(potential_fn, z, forward_mode_differentiation)
231 r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad) # r(n+1)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _value_and_grad(f, x, forward_mode_differentiation)
183 else:
--> 184 return value_and_grad(f)(x)
185
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
162 # no param is needed for log_density computation because we already substitute
--> 163 log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
164 return - log_joint
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
58 else:
---> 59 log_prob = site['fn'].log_prob(value)
60
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/distributions/util.py in wrapper(self, *args, **kwargs)
554 def wrapper(self, *args, **kwargs):
--> 555 log_prob = log_prob_fn(self, *args, *kwargs)
556 if self._validate_args:
//anaconda3/envs/pymc/lib/python3.8/site-packages/celerite2/jax/distribution.py in log_prob(self, value)
33 def log_prob(self, value):
---> 34 return self.gp.log_likelihood(value)
35
//anaconda3/envs/pymc/lib/python3.8/site-packages/celerite2/core.py in log_likelihood(self, y, inplace)
426 y = self._process_input(y, require_vector=True, inplace=inplace)
--> 427 return self._norm - 0.5 * self._do_norm(y - self._mean_value)
428
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
5255 return NotImplemented
-> 5256 return binary_op(self, other)
5257 return deferring_binary_op
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in <lambda>(x1, x2)
386 else:
--> 387 fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
388 if lax_doc:
FilteredStackTrace: jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: Incompatible sublevel: DynamicJaxprTrace(level=2/2), (4, 1).
The tracer that caused this error was created on line //anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__).
When the tracer was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/handlers.py:162 (get_trace)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
The function being traced when the tracer leaked was _body_fn at //anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py:655.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
UnexpectedTracerError Traceback (most recent call last)
<ipython-input-5-8f42ec9b0711> in <module>
68 mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=2)
69 rng_key = random.PRNGKey(34923)
---> 70 mcmc.run(rng_key, t, yerr, y=y)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
500 else:
501 if self.chain_method == 'sequential':
--> 502 states, last_state = _laxmap(partial_map_fn, map_args)
503 elif self.chain_method == 'parallel':
504 states, last_state = pmap(partial_map_fn)(map_args)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
157 for i in range(n):
158 x = jit(_get_value_from_index)(xs, i)
--> 159 ys.append(f(x))
160
161 return tree_multimap(lambda *args: jnp.stack(args), *ys)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
341 collection_size = self._collection_params["collection_size"]
342 collection_size = collection_size if collection_size is None else collection_size // self.thinning
--> 343 collect_vals = fori_collect(lower_idx,
344 upper_idx,
345 sample_fn,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
309 with tqdm.trange(upper) as t:
310 for i in t:
--> 311 vals = jit(_body_fn)(i, vals)
312 t.set_description(progbar_desc(i), refresh=False)
313 if diagnostics_fn:
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
137 def reraise_with_filtered_traceback(*args, **kwargs):
138 try:
--> 139 return fun(*args, **kwargs)
140 except Exception as e:
141 if not is_under_reraiser(e):
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
414 return cpp_jitted_f(*args, **kwargs)
415 else:
--> 416 return cpp_jitted_f(context, *args, **kwargs)
417 f_jitted._cpp_jitted_f = cpp_jitted_f
418
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/api.py in cache_miss(_, *args, **kwargs)
295 _check_arg(arg)
296 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 297 out_flat = xla.xla_call(
298 flat_fun,
299 *args_flat,
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
1392
1393 def bind(self, fun, *args, **params):
-> 1394 return call_bind(self, fun, *args, **params)
1395
1396 def process(self, trace, fun, tracers, params):
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1383 tracers = map(top_trace.full_raise, args)
1384 with maybe_new_sublevel(top_trace):
-> 1385 outs = primitive.process(top_trace, fun, tracers, params)
1386 return map(full_lower, apply_todos(env_trace_todo(), outs))
1387
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1395
1396 def process(self, trace, fun, tracers, params):
-> 1397 return trace.process_call(self, fun, tracers, params)
1398
1399 def post_process(self, trace, out_tracers, params):
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
623
624 def process_call(self, primitive, f, tracers, params):
--> 625 return primitive.impl(f, *tracers, **params)
626 process_map = process_call
627
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
584
585 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 586 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
587 *unsafe_map(arg_spec, args))
588 try:
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
258 fun.populate_stores(stores)
259 else:
--> 260 ans = call(fun, *args)
261 cache[key] = (ans, fun.stores)
262
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
660 abstract_args, arg_devices = unzip2(arg_specs)
661 if config.omnistaging_enabled:
--> 662 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
663 if any(isinstance(c, core.Tracer) for c in consts):
664 raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
1218 main.source_info = fun_sourceinfo(fun.f) # type: ignore
1219 main.jaxpr_stack = () # type: ignore
-> 1220 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1221 del fun, main
1222 return jaxpr, out_avals, consts
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1198 trace = DynamicJaxprTrace(main, core.cur_sublevel())
1199 in_tracers = map(trace.new_arg, in_avals)
-> 1200 ans = fun.call_wrapped(*in_tracers)
1201 out_tracers = map(trace.full_raise, ans)
1202 jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
164
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
168 # Some transformations yield from inside context managers, so we have to
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in _body_fn(i, vals)
282 def _body_fn(i, vals):
283 val, collection, start_idx, thinning = vals
--> 284 val = body_fun(val)
285 idx = (i - start_idx) // thinning
286 collection = cond(idx >= 0,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _sample_fn_nojit_args(state, sampler, args, kwargs)
169 def _sample_fn_nojit_args(state, sampler, args, kwargs):
170 # state is a tuple of size 1 - containing HMCState
--> 171 return sampler.sample(state[0], args, kwargs),
172
173
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in sample(self, state, model_args, model_kwargs)
528 :return: Next `state` after running HMC.
529 """
--> 530 return self._sample_fn(state, model_args, model_kwargs)
531
532
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in sample_kernel(hmc_state, model_args, model_kwargs)
321 if hmc_state.r is None else hmc_state.r
322 vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
--> 323 vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size,
324 hmc_state.adapt_state.inverse_mass_matrix,
325 vv_state,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key, trajectory_length)
290 _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
291
--> 292 binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
293 inverse_mass_matrix, step_size, rng_key,
294 max_delta_energy=max_delta_energy,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy, max_tree_depth)
747
748 state = (tree, rng_key)
--> 749 tree, _ = while_loop(_cond_fn, _body_fn, state)
750 return tree
751
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
121 return val
122 else:
--> 123 return lax.while_loop(cond_fun, body_fun, init_val)
124
125
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
137 def reraise_with_filtered_traceback(*args, **kwargs):
138 try:
--> 139 return fun(*args, **kwargs)
140 except Exception as e:
141 if not is_under_reraiser(e):
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in while_loop(cond_fun, body_fun, init_val)
282 # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
283 # necessary, a second time with modified init values.
--> 284 init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
285 new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
286 if changed:
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _create_jaxpr(init_val)
268 init_avals = tuple(_map(_abstractify, init_vals))
269 cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
--> 270 body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
271 if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
272 msg = "cond_fun must return a boolean scalar, but got pytree {}."
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in wrapper(*args, **kwargs)
196 return f(*args, **kwargs)
197 else:
--> 198 return cached(bool(config.x64_enabled), *args, **kwargs)
199
200 wrapper.cache_clear = cached.cache_clear
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in cached(_, *args, **kwargs)
189 @functools.lru_cache(max_size)
190 def cached(_, *args, **kwargs):
--> 191 return f(*args, **kwargs)
192
193 @functools.wraps(f)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _initial_style_jaxpr(fun, in_tree, in_avals)
71 @cache()
72 def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
---> 73 jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
74 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
75 return closed_jaxpr, consts, out_tree
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in wrapper(*args, **kwargs)
196 return f(*args, **kwargs)
197 else:
--> 198 return cached(bool(config.x64_enabled), *args, **kwargs)
199
200 wrapper.cache_clear = cached.cache_clear
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in cached(_, *args, **kwargs)
189 @functools.lru_cache(max_size)
190 def cached(_, *args, **kwargs):
--> 191 return f(*args, **kwargs)
192
193 @functools.wraps(f)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _initial_style_open_jaxpr(fun, in_tree, in_avals)
66 def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
67 wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
---> 68 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
69 return jaxpr, consts, out_tree()
70
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_dynamic(fun, in_avals)
1188 main.source_info = fun_sourceinfo(fun.f) # type: ignore
1189 main.jaxpr_stack = () # type: ignore
-> 1190 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1191 del main, fun
1192 return jaxpr, out_avals, consts
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1198 trace = DynamicJaxprTrace(main, core.cur_sublevel())
1199 in_tracers = map(trace.new_arg, in_avals)
-> 1200 ans = fun.call_wrapped(*in_tracers)
1201 out_tracers = map(trace.full_raise, ans)
1202 jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
164
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
168 # Some transformations yield from inside context managers, so we have to
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
741 key, direction_key, doubling_key = random.split(key, 3)
742 going_right = random.bernoulli(direction_key)
--> 743 tree = _double_tree(tree, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
744 going_right, doubling_key, energy_current, max_delta_energy,
745 r_ckpts, r_sum_ckpts)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
599 key, transition_key = random.split(rng_key)
600
--> 601 new_tree = _iterative_build_subtree(current_tree, vv_update, kinetic_fn,
602 inverse_mass_matrix, step_size,
603 going_right, key, energy_current, max_delta_energy,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
686 basetree = prototype_tree._replace(num_proposals=0)
687
--> 688 tree, turning, _, _, _ = while_loop(
689 _cond_fn,
690 _body_fn,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
121 return val
122 else:
--> 123 return lax.while_loop(cond_fun, body_fun, init_val)
124
125
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
137 def reraise_with_filtered_traceback(*args, **kwargs):
138 try:
--> 139 return fun(*args, **kwargs)
140 except Exception as e:
141 if not is_under_reraiser(e):
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in while_loop(cond_fun, body_fun, init_val)
282 # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
283 # necessary, a second time with modified init values.
--> 284 init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
285 new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
286 if changed:
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _create_jaxpr(init_val)
268 init_avals = tuple(_map(_abstractify, init_vals))
269 cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
--> 270 body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
271 if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
272 msg = "cond_fun must return a boolean scalar, but got pytree {}."
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in wrapper(*args, **kwargs)
196 return f(*args, **kwargs)
197 else:
--> 198 return cached(bool(config.x64_enabled), *args, **kwargs)
199
200 wrapper.cache_clear = cached.cache_clear
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in cached(_, *args, **kwargs)
189 @functools.lru_cache(max_size)
190 def cached(_, *args, **kwargs):
--> 191 return f(*args, **kwargs)
192
193 @functools.wraps(f)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _initial_style_jaxpr(fun, in_tree, in_avals)
71 @cache()
72 def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
---> 73 jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
74 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
75 return closed_jaxpr, consts, out_tree
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in wrapper(*args, **kwargs)
196 return f(*args, **kwargs)
197 else:
--> 198 return cached(bool(config.x64_enabled), *args, **kwargs)
199
200 wrapper.cache_clear = cached.cache_clear
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in cached(_, *args, **kwargs)
189 @functools.lru_cache(max_size)
190 def cached(_, *args, **kwargs):
--> 191 return f(*args, **kwargs)
192
193 @functools.wraps(f)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _initial_style_open_jaxpr(fun, in_tree, in_avals)
66 def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
67 wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
---> 68 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
69 return jaxpr, consts, out_tree()
70
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_dynamic(fun, in_avals)
1188 main.source_info = fun_sourceinfo(fun.f) # type: ignore
1189 main.jaxpr_stack = () # type: ignore
-> 1190 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1191 del main, fun
1192 return jaxpr, out_avals, consts
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1198 trace = DynamicJaxprTrace(main, core.cur_sublevel())
1199 in_tracers = map(trace.new_arg, in_avals)
-> 1200 ans = fun.call_wrapped(*in_tracers)
1201 out_tracers = map(trace.full_raise, ans)
1202 jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
164
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
168 # Some transformations yield from inside context managers, so we have to
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
658 # If we are going to the right, start from the right leaf of the current tree.
659 z, r, z_grad = _get_leaf(current_tree, going_right)
--> 660 new_leaf = _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size,
661 going_right, energy_current, max_delta_energy)
662 new_tree = cond(current_tree.num_proposals == 0,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right, energy_current, max_delta_energy)
567 energy_current, max_delta_energy):
568 step_size = jnp.where(going_right, step_size, -step_size)
--> 569 z_new, r_new, potential_energy_new, z_new_grad = vv_update(
570 step_size,
571 inverse_mass_matrix,
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in update_fn(step_size, inverse_mass_matrix, state)
228 r_grad = _kinetic_grad(kinetic_fn, inverse_mass_matrix, r)
229 z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad) # z(n+1)
--> 230 potential_energy, z_grad = _value_and_grad(potential_fn, z, forward_mode_differentiation)
231 r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad) # r(n+1)
232 return IntegratorState(z, r, potential_energy, z_grad)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _value_and_grad(f, x, forward_mode_differentiation)
182 return f(x), jacfwd(f)(x)
183 else:
--> 184 return value_and_grad(f)(x)
185
186
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
137 def reraise_with_filtered_traceback(*args, **kwargs):
138 try:
--> 139 return fun(*args, **kwargs)
140 except Exception as e:
141 if not is_under_reraiser(e):
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
821 tree_map(partial(_check_input_dtype_grad, holomorphic, allow_int), dyn_args)
822 if not has_aux:
--> 823 ans, vjp_py = _vjp(f_partial, *dyn_args)
824 else:
825 ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/api.py in _vjp(fun, has_aux, *primals)
1894 if not has_aux:
1895 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1896 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1897 out_tree = out_tree()
1898 else:
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
112 def vjp(traceable, primals, has_aux=False):
113 if not has_aux:
--> 114 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
115 else:
116 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
99 _, in_tree = tree_flatten(((primals, primals), {}))
100 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
102 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
103 assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
504 with core.new_main(JaxprTrace) as main:
505 fun = trace_to_subjaxpr(fun, main, instantiate)
--> 506 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
507 assert not env
508 del main, fun, env
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
164
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
168 # Some transformations yield from inside context managers, so we have to
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
161 substituted_model = substitute(model, substitute_fn=partial(_unconstrain_reparam, params))
162 # no param is needed for log_density computation because we already substitute
--> 163 log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
164 return - log_joint
165
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
57 log_prob = site['fn'].log_prob(value, intermediates)
58 else:
---> 59 log_prob = site['fn'].log_prob(value)
60
61 if (scale is not None) and (not is_identically_one(scale)):
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/distributions/util.py in wrapper(self, *args, **kwargs)
553 def validate_sample(log_prob_fn):
554 def wrapper(self, *args, **kwargs):
--> 555 log_prob = log_prob_fn(self, *args, *kwargs)
556 if self._validate_args:
557 value = kwargs['value'] if 'value' in kwargs else args[0]
//anaconda3/envs/pymc/lib/python3.8/site-packages/celerite2/jax/distribution.py in log_prob(self, value)
32 @dist.util.validate_sample
33 def log_prob(self, value):
---> 34 return self.gp.log_likelihood(value)
35
36 def sample(self, key, sample_shape=()):
//anaconda3/envs/pymc/lib/python3.8/site-packages/celerite2/core.py in log_likelihood(self, y, inplace)
425 """
426 y = self._process_input(y, require_vector=True, inplace=inplace)
--> 427 return self._norm - 0.5 * self._do_norm(y - self._mean_value)
428
429 def predict(
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in __sub__(self, other)
521 def __add__(self, other): return self.aval._add(self, other)
522 def __radd__(self, other): return self.aval._radd(self, other)
--> 523 def __sub__(self, other): return self.aval._sub(self, other)
524 def __rsub__(self, other): return self.aval._rsub(self, other)
525 def __mul__(self, other): return self.aval._mul(self, other)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
5254 if not isinstance(other, _scalar_types + _arraylike_types + (core.Tracer,)):
5255 return NotImplemented
-> 5256 return binary_op(self, other)
5257 return deferring_binary_op
5258
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in <lambda>(x1, x2)
385 fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
386 else:
--> 387 fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
388 if lax_doc:
389 doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/lax.py in sub(x, y)
343 def sub(x: Array, y: Array) -> Array:
344 r"""Elementwise subtraction: :math:`x - y`."""
--> 345 return sub_p.bind(x, y)
346
347 def mul(x: Array, y: Array) -> Array:
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **params)
281 or valid_jaxtype(arg) for arg in args), args
282 top_trace = find_top_trace(args)
--> 283 tracers = map(top_trace.full_raise, args)
284 out = top_trace.process_primitive(self, tracers, params)
285 return map(full_lower, out) if self.multiple_results else full_lower(out)
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in safe_map(f, *args)
39 for arg in args[1:]:
40 assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
---> 41 return list(map(f, *args))
42
43 def unzip2(xys):
//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in full_raise(self, val)
402 elif val._trace.level < level:
403 if val._trace.sublevel > sublevel:
--> 404 raise escaped_tracer_error(
405 val, f"Incompatible sublevel: {val._trace}, {(level, sublevel)}")
406 return self.lift(val)
UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: Incompatible sublevel: DynamicJaxprTrace(level=2/2), (4, 1).
The tracer that caused this error was created on line //anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__).
When the tracer was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/handlers.py:162 (get_trace)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
The function being traced when the tracer leaked was _body_fn at //anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py:655.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
\```
</p>
</details>
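The error text above says that transformed functions should not save traced values to global state. The traceback does not show the model code, so the cause here is only a guess: one way to hit this error is to build or `compute` a GP object once, outside the NumPyro model, and then reuse or mutate it inside the traced function. Below is a minimal pure-JAX sketch of the usual remedy, which is to create the stateful object inside the transformed function. `ToyGP` is a hypothetical stand-in that mimics a "compute, then evaluate" object. It is not the celerite2 API.

```python
import math

import jax
import jax.numpy as jnp


class ToyGP:
    """Hypothetical stand-in for a GP that caches state in compute().

    NOT the celerite2 API. It only mimics storing a precomputed quantity
    on the object, which is what lets a tracer outlive its trace.
    """

    def __init__(self, sigma):
        self.sigma = sigma
        self._var = None

    def compute(self, t):
        # When sigma is traced, this cached array is a tracer too.
        self._var = self.sigma**2 * jnp.ones_like(t)

    def log_likelihood(self, y):
        # iid Gaussian log-likelihood with variance sigma**2
        r2 = jnp.sum(y**2 / self._var)
        return -0.5 * r2 - 0.5 * jnp.sum(jnp.log(2.0 * math.pi * self._var))


# Problematic pattern (shown only as a comment): a GP built once at module
# level and mutated inside a transformed function keeps a reference to a
# tracer after the trace ends. Any later use of that cached value (e.g. a
# call that skips compute()) raises UnexpectedTracerError.
#
#   GLOBAL_GP = ToyGP(1.0)
#   def bad_loss(sigma, t, y):
#       GLOBAL_GP.sigma = sigma
#       GLOBAL_GP.compute(t)
#       return GLOBAL_GP.log_likelihood(y)


def loss(sigma, t, y):
    # Build and compute the GP inside the function so every tracer stays local.
    gp = ToyGP(sigma)
    gp.compute(t)
    return gp.log_likelihood(y)


t = jnp.linspace(0.0, 1.0, 5)
y = jnp.array([0.1, -0.2, 0.3, 0.0, 0.5])
value, dvalue = jax.jit(jax.value_and_grad(loss))(1.5, t, y)
```

In a NumPyro model, the analogous change would be to construct the kernel and the `GaussianProcess` (and call `compute`) inside the model function on every call, rather than caching them globally.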
Hi there! I ran into an issue on macOS when trying to use celerite2 in a Jupyter notebook.
Here's the line:
from celerite2.theano import terms, GaussianProcess
And the issue was:
The kernel appears to have died. It will restart automatically.
Restarting my computer and the notebook didn't help. Searching online, I found a tip that might work, but it requires installing the environment with conda (I use Anaconda, but I installed celerite2 with pip). However, I couldn't find conda installation instructions in the tutorial. Would you be able to add instructions for installing with conda?
Thanks a lot!
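When a Jupyter kernel dies silently on import, running the same import in a fresh Python process usually reveals the real error or the signal that killed the process. Here is a minimal diagnostic sketch. It assumes you run it from the same environment the notebook uses. `check_import` is a hypothetical helper, not part of celerite2.

```python
import subprocess
import sys


def check_import(module_name):
    """Hypothetical helper: import a module in a separate interpreter.

    Returns (returncode, stderr). A negative returncode -N means the child
    was killed by signal N (POSIX only), e.g. -11 for a segfault, which is
    the kind of crash Jupyter reports as "The kernel appears to have died".
    """
    proc = subprocess.run(
        [sys.executable, "-c", f"import {module_name}"],
        capture_output=True,
        text=True,
    )
    return proc.returncode, proc.stderr


if __name__ == "__main__":
    code, err = check_import("celerite2.theano")
    print("return code:", code)
    print(err)
```

Because the import happens in a child process, calling `check_import("celerite2.theano")` from inside the notebook reports the exit status and any traceback without taking the kernel down.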