celerite2's People

Contributors

dependabot[bot], dfm, jacksonloper, pre-commit-ci[bot], tovrstra, vandalt

celerite2's Issues

NotImplementedError for differentiation of gp.predict

Hi dfm,

I am using a Gaussian process as an interpolator for some control points, and I want to calculate the gradient of a function that takes the GP's predicted values as input. A simple example is the sum_interp_gp function below, where I simply add up the predicted values. It seems that differentiation of gp.predict has not been implemented. Are there plans to do this?

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
/tmp/ipykernel_18624/620407306.py in <module>
----> 1 cf.grad_sum_interp_gp(params)

    [... skipping hidden 12 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.3-py3.8.egg/ticktack/fitting.py in grad_sum_interp_gp(self, *args)
    144     @partial(jit, static_argnums=(0,))
    145     def grad_sum_interp_gp(self, *args):
--> 146         return grad(self.sum_interp_gp)(*args)
    147 
    148     @partial(jit, static_argnums=(0,))

    [... skipping hidden 26 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.3-py3.8.egg/ticktack/fitting.py in sum_interp_gp(self, *args)
    139     @partial(jit, static_argnums=(0,))
    140     def sum_interp_gp(self, *args):
--> 141         mu = self.interp_gp(self.annual, *args)
    142         return jnp.sum(mu)
    143 

    [... skipping hidden 17 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.3-py3.8.egg/ticktack/fitting.py in interp_gp(self, tval, *args)
    133         gp = celerite2.jax.GaussianProcess(kernel, mean=mean)
    134         gp.compute(self.control_points_time)
--> 135         mu = gp.predict(control_points, t=tval, return_var=False)
    136         mu = (tval > self.start) * mu +  (tval <= self.start) * mean
    137         return mu

~/.local/lib/python3.8/site-packages/celerite2/core.py in predict(self, y, t, return_cov, return_var, include_mean, kernel)
    472         if return_cov:
    473             return cond.mean, cond.covariance
--> 474         return cond.mean
    475 
    476     def condition(self, y, t=None, *, include_mean=True, kernel=None):

~/.local/lib/python3.8/site-packages/celerite2/core.py in mean(self)
    125 
    126         mu = self.gp._zeros_like(self._xs)
--> 127         mu = self._do_dot(alpha, mu)
    128 
    129         if self.include_mean:

~/.local/lib/python3.8/site-packages/celerite2/core.py in _do_dot(self, inp, target)
    106             target = target[:, None]
    107 
--> 108         target = self._do_general_matmul(c, U1, V1, U2, V2, inp, target)
    109 
    110         if is_vector:

~/.local/lib/python3.8/site-packages/celerite2/jax/celerite2.py in _do_general_matmul(self, c, U1, V1, U2, V2, inp, target)
     10 class ConditionalDistribution(BaseConditionalDistribution):
     11     def _do_general_matmul(self, c, U1, V1, U2, V2, inp, target):
---> 12         target += ops.general_matmul_lower(
     13             self._xs, self.gp._t, c, U2, V1, inp
     14         )

~/.local/lib/python3.8/site-packages/celerite2/jax/ops.py in general_matmul_lower(t1, t2, c, U, V, Y)
     61 
     62 def general_matmul_lower(t1, t2, c, U, V, Y):
---> 63     Z, F = general_matmul_lower_p.bind(t1, t2, c, U, V, Y)
     64     return Z
     65 

    [... skipping hidden 1 frame]

~/.local/lib/python3.8/site-packages/jax/interpreters/ad.py in process_primitive(self, primitive, tracers, params)
    280     if not jvp:
    281       msg = f"Differentiation rule for '{primitive}' not implemented"
--> 282       raise NotImplementedError(msg)
    283     primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    284     if primitive.multiple_results:

NotImplementedError: Differentiation rule for 'celerite2_general_matmul_lower' not implemented
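One observation that may help while the differentiation rule is missing: the GP predictive mean, mu = m + K_s K^{-1} (y - m), is linear in the conditioning values y, so its derivative with respect to y costs only one extra linear solve. The sketch below illustrates this outside celerite2 and is not a fix for the missing rule. It swaps in dense O(N^3) linear algebra with a hypothetical exponential kernel (exp_kernel), so it is only practical for a modest number of control points, and it assumes the parameters being differentiated are the control-point values. It uses NumPy so it runs standalone; the same code written with jax.numpy should be differentiable with jax.grad, though that has not been checked here.

```python
import numpy as np


def exp_kernel(t1, t2, amp=1.0, scale=1.0):
    """Hypothetical stand-in kernel: amp^2 * exp(-|t1 - t2| / scale)."""
    return amp**2 * np.exp(-np.abs(t1[:, None] - t2[None, :]) / scale)


def predict_mean(t_ctrl, y_ctrl, t_new, mean=0.0, jitter=1e-10):
    """Dense GP predictive mean at t_new: mean + K_s K^{-1} (y - mean)."""
    K = exp_kernel(t_ctrl, t_ctrl) + jitter * np.eye(len(t_ctrl))
    K_s = exp_kernel(t_new, t_ctrl)  # shape (M, N)
    return mean + K_s @ np.linalg.solve(K, y_ctrl - mean)


def grad_sum_mean_wrt_y(t_ctrl, t_new, jitter=1e-10):
    """Gradient of sum(predict_mean(...)) with respect to y_ctrl.

    sum(mu) = 1^T K_s K^{-1} (y - mean) + const and K is symmetric,
    so the gradient is K^{-1} K_s^T 1: a single extra solve.
    """
    K = exp_kernel(t_ctrl, t_ctrl) + jitter * np.eye(len(t_ctrl))
    K_s = exp_kernel(t_new, t_ctrl)
    return np.linalg.solve(K, K_s.sum(axis=0))


# Usage: compare the analytic gradient with central finite differences.
t_ctrl = np.linspace(0.0, 10.0, 8)
y_ctrl = np.sin(t_ctrl)
t_new = np.linspace(0.5, 9.5, 20)

g = grad_sum_mean_wrt_y(t_ctrl, t_new)
h = 1e-4
g_fd = np.array([
    (predict_mean(t_ctrl, y_ctrl + h * e, t_new).sum()
     - predict_mean(t_ctrl, y_ctrl - h * e, t_new).sum()) / (2 * h)
    for e in np.eye(len(t_ctrl))
])
```

Because the mean is linear in y, the finite-difference and analytic gradients agree to rounding error.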

SHOTerm not working in numpyro model

Hi @dfm,

I'm trying to use numpyro to sample a GP with the SHO kernel as follows:

import jax
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
from celerite2.jax import GaussianProcess, terms

import numpyro.distributions as dist
from numpyro import sample
from numpyro.infer import MCMC, NUTS

prior_sigma = 1.0

def numpyro_model(x, yerr, y=None):

    mean = sample("mean", dist.Normal(0.0, prior_sigma))
    logjitter = sample("logjitter", dist.Normal(-26, 3 * prior_sigma))

    logsigma = sample("logsigma", dist.Normal(-11, 3 * prior_sigma))
    rho = sample("rho", dist.Normal(1.0, 3 * prior_sigma))
    tau = sample("tau", dist.Normal(0.1, prior_sigma))
        
    term = terms.SHOTerm(sigma=jnp.exp(logsigma), rho=rho, tau=tau)
    gp = GaussianProcess(term, mean=mean)
    gp.compute(x, diag = yerr**2 + jnp.exp(logjitter), check_sorted=False)

    sample("obs", gp.numpyro_dist(), obs=y)
    
nuts_kernel = NUTS(numpyro_model, dense_mass=True, target_accept_prob=0.9)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=True,
)
rng_key = jax.random.PRNGKey(34923)
yerr = 1e-8
mcmc.run(rng_key, x, yerr, y=y)

and I'm getting an error with a long traceback that ends:

/usr/local/lib/python3.8/site-packages/celerite2/jax/terms.py in SHOTerm(*args, **kwargs)
    397     over = OverdampedSHOTerm(*args, **kwargs)
    398     under = UnderdampedSHOTerm(*args, **kwargs)
--> 399     if over.Q < 0.5:
    400         return over
    401     return under

    [... skipping hidden 2 frame]

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
Full traceback
---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-139-60f93f4ec4d5> in <module>
    1 yerr = 1e-8
----> 2 mcmc.run(rng_key, x, yerr, y=y)
    3 samples = mcmc.get_samples()

/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
  596         else:
  597             if self.chain_method == "sequential":
--> 598                 states, last_state = _laxmap(partial_map_fn, map_args)
  599             elif self.chain_method == "parallel":
  600                 states, last_state = pmap(partial_map_fn)(map_args)

/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
  158     for i in range(n):
  159         x = jit(_get_value_from_index)(xs, i)
--> 160         ys.append(f(x))
  161 
  162     return tree_map(lambda *args: jnp.stack(args), *ys)

/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
  379         rng_key, init_state, init_params = init
  380         if init_state is None:
--> 381             init_state = self.sampler.init(
  382                 rng_key,
  383                 self.num_warmup,

/usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
  704                 vmap(random.split)(rng_key), 0, 1
  705             )
--> 706         init_params = self._init_state(
  707             rng_key_init_model, model_args, model_kwargs, init_params
  708         )

/usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
  650     def _init_state(self, rng_key, model_args, model_kwargs, init_params):
  651         if self._model is not None:
--> 652             init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
  653                 rng_key,
  654                 self._model,

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
  654         init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
  655     prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 656     (init_params, pe, grad), is_valid = find_valid_initial_params(
  657         rng_key,
  658         substitute(

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
  395     # Handle possible vectorization
  396     if rng_key.ndim == 1:
--> 397         (init_params, pe, z_grad), is_valid = _find_valid_params(
  398             rng_key, exit_early=True
  399         )

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
  388         # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
  389         # even if the init_state is a valid result
--> 390         _, _, (init_params, pe, z_grad), is_valid = while_loop(
  391             cond_fn, body_fn, init_state
  392         )

/usr/local/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
  129         return val
  130     else:
--> 131         return lax.while_loop(cond_fun, body_fun, init_val)
  132 
  133 

  [... skipping hidden 9 frame]

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in body_fn(state)
  365                 z_grad = jacfwd(potential_fn)(params)
  366             else:
--> 367                 pe, z_grad = value_and_grad(potential_fn)(params)
  368             z_grad_flat = ravel_pytree(z_grad)[0]
  369             is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

  [... skipping hidden 8 frame]

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
  247     )
  248     # no param is needed for log_density computation because we already substitute
--> 249     log_joint, model_trace = log_density_(
  250         substituted_model, model_args, model_kwargs, {}
  251     )

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
   60     """
   61     model = substitute(model, data=params)
---> 62     model_trace = trace(model).get_trace(*model_args, **model_kwargs)
   63     log_joint = jnp.zeros(())
   64     for site in model_trace.values():

/usr/local/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
  169         :return: `OrderedDict` containing the execution trace.
  170         """
--> 171         self(*args, **kwargs)
  172         return self.trace
  173 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

<ipython-input-137-39666cc8f7df> in numpyro_model(x, yerr, y)
   10     tau = sample("tau", dist.Normal(0.1, prior_sigma))
   11 
---> 12     term = jTerms.SHOTerm(sigma=jnp.exp(logsigma), rho=rho, tau=tau)
   13     gp = jGP(term, mean=mean)
   14     gp.compute(x, diag = yerr**2 + jnp.exp(logjitter), check_sorted=False)

/usr/local/lib/python3.8/site-packages/celerite2/jax/terms.py in SHOTerm(*args, **kwargs)
  397     over = OverdampedSHOTerm(*args, **kwargs)
  398     under = UnderdampedSHOTerm(*args, **kwargs)
--> 399     if over.Q < 0.5:
  400         return over
  401     return under

  [... skipping hidden 2 frame]

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function. 
The error occurred while tracing the function body_fn at /usr/local/lib/python3.8/site-packages/numpyro/infer/util.py:315 for while_loop. This concrete value was not available in Python because it depends on the value of the argument state[1].

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

It runs just fine if I replace terms.SHOTerm with terms.UnderdampedSHOTerm and constrain the hyperparameters to be in the underdamped regime. Any idea what's going on here?

Thanks!
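The last frame points at the cause: in celerite2.jax, SHOTerm chooses between OverdampedSHOTerm and UnderdampedSHOTerm with a Python "if over.Q < 0.5:", and a Python if needs a concrete boolean, which is not available while NUTS traces the model. Besides the UnderdampedSHOTerm workaround above, a branch-free formulation avoids the problem: evaluate both damping regimes and select one with where. The sketch below is not celerite2's implementation; sho_kernel is a hypothetical helper using the (sigma, rho, tau) parameterization from the post, with w0 = 2*pi/rho, Q = w0*tau/2 and sigma^2 = S0*w0*Q (consistent with rho = tau*pi/Q used elsewhere on this page). It is written against an array module xp, NumPy by default; passing jax.numpy is the intended use under tracing, but that path is untested here.

```python
import numpy as np


def sho_kernel(dt, sigma, rho, tau, xp=np):
    """SHO autocovariance at lag dt for any Q > 0, with no Python branching.

    Hypothetical helper (not celerite2's API): w0 = 2*pi/rho, Q = w0*tau/2,
    sigma**2 = S0*w0*Q, so k(0) = sigma**2.
    """
    w0 = 2.0 * xp.pi / rho
    Q = 0.5 * w0 * tau
    # |1 - 1/(4Q^2)|^(1/2): the abs keeps both branches finite (no NaNs).
    # Under jax.grad, NaNs in the unselected branch of `where` can still
    # leak into gradients, so both branches must stay finite.
    eta = xp.sqrt(xp.abs(1.0 - 1.0 / (4.0 * Q**2)))
    eta = xp.maximum(eta, 1e-8)  # avoid 0/0 at the critical point Q = 1/2
    x = eta * w0 * xp.abs(dt)
    over = xp.cosh(x) + xp.sinh(x) / (2.0 * eta * Q)   # used when Q < 1/2
    under = xp.cos(x) + xp.sin(x) / (2.0 * eta * Q)    # used when Q >= 1/2
    # cosh/sinh can overflow for very long lags; a production version would
    # fold them into the exponential envelope.
    envelope = xp.exp(-w0 * xp.abs(dt) / (2.0 * Q))
    return sigma**2 * envelope * xp.where(Q < 0.5, over, under)
```

Both branches meet at Q = 1/2, where the kernel reduces to sigma^2 * exp(-w0*dt) * (1 + w0*dt), so the selection introduces no jump as the sampler moves across the critical damping value.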

Sampling with a non-constant mean

Hi there,

I'm trying to model a light curve (~10^5 points) using celerite2 and PyMC3 as a SHOTerm plus a mean model composed of a sum of sinusoids (with fairly tight priors on the frequencies and amplitudes). If I'm reading the documentation for GaussianProcess correctly, the mean parameter needs to be either a scalar or an object that can be called with a vector of time values.

Assuming I have that correct, I've implemented this model as:

import numpy as np
import pymc3 as pm
import pymc3_ext as pmx
import aesara_theano_fallback.tensor as tt
from celerite2.theano import terms, GaussianProcess

with pm.Model() as model:
    # A jitter term describing excess white noise
    log_jitter = pm.Normal("log_jitter", mu=np.log(err_array.mean()), sigma=5.0)

    # SHOTerm
    logsigma = pm.Uniform("log_sigma", lower=-15, upper=15)
    sigma = pm.Deterministic("sigma", tt.exp(logsigma))
    logtau = pm.Uniform("log_tau", lower=-5, upper=5)
    tau = pm.Deterministic("tau", tt.exp(logtau))
    
    #We want Q<1/sqrt(2)
    Q = pm.Uniform("Q",lower=0.0,upper=1/np.sqrt(2.0)+0.00001, testval=1/np.sqrt(2.0)) 
    rho = pm.Deterministic("rho", tau * np.pi / Q)

    kernel = terms.SHOTerm(sigma=sigma, rho=rho, tau=tau)
   
    #now the non-constant mean
    #the mean flux of the light curve
    mean_flux = pm.Normal("mean_flux", mu = 1.0, sigma=np.std(flux))

    f, f_err = lists_of_frequencies_and_errorbars
    a, a_err = lists_of_amplitudes_and_errorbars

    #Make them PyMC3 variables
    fs = [pm.Uniform(f"f{i}", lower=f[i] - 3*f_err[i], upper=f[i] + 3*f_err[i]) for i in range(len(f))]
    amps = [pm.Uniform(f"a{i}", lower=a[i] - 3*a_err[i], upper=a[i] + 3*a_err[i]) for i in range(len(f))]
    phases = [pmx.Angle(f"phi{i}") for i in range(len(f))]

    #Making a callable for celerite2
    mean = lambda x: tt.sum([a * tt.sin(2.0*np.pi*f*x + phi) for a,f,phi in zip(amps,fs,phases)],axis=0) + mean_flux
    # And add it to the model so we can track it
    pm.Deterministic("mean", mean(time_array))

    gp = GaussianProcess(
        kernel,
        t=time_array,
        diag=err_array ** 2.0 + tt.exp(2 * log_jitter),
        mean=mean,
        quiet=True,
    )

    # Compute the Gaussian Process likelihood and add it into the PyMC3 model 
    gp.marginal("gp", observed=flux_array)

    # Compute the mean model prediction of just the GP
    pm.Deterministic("pred", gp.predict(flux_array, include_mean=False))

When I run pmx.optimize(), this seems to go fairly well, taking ~45 seconds on my laptop with a 10-frequency mean model (35 parameters in total), but when I try to sample using pmx.sample with more than one core, I get the warning message:

Could not pickle model, sampling singlethreaded.
Sequential sampling (4 chains in 1 job)

I don't get this message when I use a constant mean, so my guess is it has something to do with the callable that gets passed to GaussianProcess. The model takes quite a while to sample, so I'd really like to run this in parallel.

Is there a workaround? Am I perhaps completely misreading the GaussianProcess docstring? Oh and for completeness, running

from aesara_theano_fallback import __version__ as tt_version
from celerite2 import __version__ as c2_version

pm.__version__, pmx.__version__, tt_version, c2_version

yields

'3.11.4', '0.1.0', '0.0.4', '0.2.0'
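A likely explanation, though not confirmed here: multi-process sampling has to pickle the model, and Python's pickle cannot serialize a lambda because it stores functions by looking up their name. A callable instance of a module-level class can be pickled as long as its attributes can. The sketch below shows that pattern with a hypothetical SumOfSinusoids class. It uses NumPy so it runs standalone; inside the PyMC3 model the body would use tt.sin on the PyMC3 variables instead, and whether those variables pickle cleanly in that setup has not been checked.

```python
import pickle

import numpy as np


class SumOfSinusoids:
    """Mean model: offset + sum_k a_k * sin(2*pi*f_k*t + phi_k).

    Hypothetical helper. As a module-level class (not a lambda), its instances
    can be pickled whenever their attributes can. NumPy is used so this runs
    standalone; in a PyMC3 model, tt.sin would replace np.sin.
    """

    def __init__(self, amps, freqs, phases, offset):
        self.amps = list(amps)
        self.freqs = list(freqs)
        self.phases = list(phases)
        self.offset = offset

    def __call__(self, t):
        total = self.offset
        for a, f, phi in zip(self.amps, self.freqs, self.phases):
            total = total + a * np.sin(2.0 * np.pi * f * t + phi)
        return total


mean_model = SumOfSinusoids(amps=[1.0, 0.5], freqs=[0.1, 0.3], phases=[0.0, 1.0], offset=2.0)
restored = pickle.loads(pickle.dumps(mean_model))  # round-trips without error

try:
    pickle.dumps(lambda t: t)  # the pattern used in the post
    lambda_pickles = True
except Exception:  # PicklingError (exact type can vary by Python version)
    lambda_pickles = False
```

In the model above, an instance such as SumOfSinusoids(amps, fs, phases, mean_flux) would take the place of the lambda passed as mean.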

Unable to access pymc3 submodule

Hi folks, I just discovered this updated version of celerite and am excited to add celerite2's PyMC3 GP to the Eureka! package, as I'm finding more and more cases where using a GP would be beneficial. We already have the original celerite GP working (at least kinda), as well as the george GP, for standard Python minimizers and samplers (emcee, scipy.optimize.minimize, dynesty). We also have a PyMC3 version of our fitting code, which lets us use starry (ideal for eclipse mapping), and in general PyMC3's NUTS sampler has given much faster fits (for me, at least). I know that JAX and PyMC4 offer advantages over the deprecated PyMC3 and Theano stack, but over the past year I've already spent more than fifty hours implementing PyMC3 versions of all our astrophysical and systematic models, I need to be able to use starry for the astrophysical model, and I'm looking for a fairly quick way to add GPs. I was about to try PyMC3's built-in GP, but I saw in the old exoplanet docs that exoplanet had a faster GP implementation for PyMC3 sampling, and then noticed that the code had been migrated here to celerite2.

However, by default I am unable to access the pymc3 submodule of celerite2 that is mentioned in the documentation. I've tried installing the version on the main branch on GitHub (since v0.2.1 doesn't have the pymc3 code), but I get an error that 'celerite2' has no attribute 'pymc3':

> pip install celerite2[pymc3]@git+https://github.com/exoplanet-dev/celerite2
> ipython
Python 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:23:19) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.14.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import celerite2
In [2]: celerite2.__version__
Out[2]: '0.3.0rc2.dev11+g74b0705'

In [3]: celerite2.pymc3
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 1
----> 1 celerite2.pymc3

AttributeError: module 'celerite2' has no attribute 'pymc3'

Looking at the package on GitHub I see the pymc3 folder under python/celerite2/pymc3, but it seems it isn't being imported in the __init__.py file. Just adding from celerite2 import pymc3 to the python/celerite2/__init__.py file wasn't enough to solve the problem either, giving me the following error message when installing:

Building wheels for collected packages: celerite2
  Building wheel for celerite2 (pyproject.toml) ... error
  error: subprocess-exited-with-error
  
  × Building wheel for celerite2 (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> [27 lines of output]
      running bdist_wheel
      running build
      running build_py
      copying python/celerite2/__init__.py -> build/lib.macosx-10.9-x86_64-cpython-39/celerite2
      copying python/celerite2/celerite2_version.py -> build/lib.macosx-10.9-x86_64-cpython-39/celerite2
      running egg_info
      writing python/celerite2.egg-info/PKG-INFO
      writing dependency_links to python/celerite2.egg-info/dependency_links.txt
      writing requirements to python/celerite2.egg-info/requires.txt
      writing top-level names to python/celerite2.egg-info/top_level.txt
      reading manifest template 'MANIFEST.in'
      warning: no directories found matching 'c++/vendor/eigen/Eigen'
      adding license file 'LICENSE'
      writing manifest file 'python/celerite2.egg-info/SOURCES.txt'
      running build_ext
      clang -Wno-unused-result -Wsign-compare -Wunreachable-code -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /Users/tjbell1/miniconda3/envs/eureka_starry/include -fPIC -O2 -isystem /Users/tjbell1/miniconda3/envs/eureka_starry/include -I/Users/tjbell1/miniconda3/envs/eureka_starry/include/python3.9 -c flagcheck.cpp -o flagcheck.o -std=c++17
      building 'celerite2.driver' extension
      clang -Wno-unused-result -Wsign-compare -Wunreachable-code -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /Users/tjbell1/miniconda3/envs/eureka_starry/include -fPIC -O2 -isystem /Users/tjbell1/miniconda3/envs/eureka_starry/include -Ic++/include -Ic++/vendor/eigen -Ipython/celerite2 -I/private/var/folders/f3/_h5lxm511fv9sjb5d3xrj0jm0000gq/T/pip-build-env-u9cav0rq/overlay/lib/python3.9/site-packages/pybind11/include -I/Users/tjbell1/miniconda3/envs/eureka_starry/include/python3.9 -c python/celerite2/driver.cpp -o build/temp.macosx-10.9-x86_64-cpython-39/python/celerite2/driver.o -std=c++17 -mmacosx-version-min=10.14 -fvisibility=hidden -g0
      In file included from python/celerite2/driver.cpp:6:
      In file included from python/celerite2/driver.hpp:8:
      In file included from c++/include/celerite2/celerite2.h:4:
      In file included from c++/include/celerite2/core.hpp:4:
      c++/include/celerite2/forward.hpp:4:10: fatal error: 'Eigen/Core' file not found
      #include <Eigen/Core>
               ^~~~~~~~~~~~
      1 error generated.
      error: command '/usr/bin/clang' failed with exit code 1
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building wheel for celerite2
Failed to build celerite2
ERROR: Could not build wheels for celerite2, which is required to install pyproject.toml-based projects

I was able to find the referenced Eigen package at https://eigen.tuxfamily.org/, and I downloaded the latest version and put the entire contents of that package in celerite2's (previously empty) c++/vendor/eigen folder. With that package downloaded and my edit to the __init__.py file, I'm now able to import and use celerite2.pymc3.

I thought I'd document my troubleshooting process here to help anyone else trying to do the same thing or to help the devs figure out what changes need to be made to the documentation and/or package.
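A note on the AttributeError above: in Python, importing a package does not import its subpackages unless the package's __init__.py does so, and a subpackage only becomes an attribute of its parent once something imports it. So, provided the pymc3 subpackage was actually installed (not checked here), writing import celerite2.pymc3 or from celerite2 import pymc3 in user code may be enough, without editing __init__.py. The standalone sketch below demonstrates the general behavior with a throwaway package; demo_pkg and make_demo_package are hypothetical names.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def make_demo_package(root):
    """Create root/demo_pkg/ with an empty subpackage demo_pkg/sub/ (hypothetical names)."""
    pkg = Path(root) / "demo_pkg"
    (pkg / "sub").mkdir(parents=True)
    (pkg / "__init__.py").write_text("")  # the parent does NOT import .sub
    (pkg / "sub" / "__init__.py").write_text("")


root = tempfile.mkdtemp()
make_demo_package(root)
sys.path.insert(0, root)
importlib.invalidate_caches()

demo_pkg = importlib.import_module("demo_pkg")
had_sub_before = hasattr(demo_pkg, "sub")  # False: only the parent was imported

importlib.import_module("demo_pkg.sub")    # same effect as `import demo_pkg.sub`
has_sub_after = hasattr(demo_pkg, "sub")   # True: the import binds it on the parent
```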

Derive linear algorithm for conditional variance

I think that this must be possible using a forward and backward pass. At the very least, this must be possible when the predictive kernel is the same as the base kernel (by comparison with state-space models), but I feel like it must be possible for general semi-separable matrices.
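For reference, the target quantity in standard GP notation (symbols assumed here, not taken from the celerite2 code): K is the N x N covariance of the observed data, including its diagonal noise; k_* is the predictive kernel; and the vector k_*(t) collects the covariances between a test time t and the N observed times. Evaluated naively, each test time needs its own solve against K, which is linear in N with the celerite factorization, so the cost for M test times grows like M times N; a forward and backward pass would avoid one solve per test time.

```latex
\operatorname{Var}\left[f(t) \mid \mathbf{y}\right]
  = k_*(t, t) - \mathbf{k}_*(t)^{\top} K^{-1} \mathbf{k}_*(t),
\qquad
\mathbf{k}_*(t) = \left[k_*(t, t_1), \ldots, k_*(t, t_N)\right]^{\top}
```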

Jax mean models?

Hi Dan,

Just a question: there is a great tutorial on using the NumPyro backend with JAX.

Is there any documentation on using JAX mean models too, e.g. any pitfalls with definitions, gradients, or jit?

All the best,

Ben

Further discussion on choice of priors for GP hyperparameters

Is your feature request related to a problem? Please describe.
Hi, could you discuss the different priors you use for the GP hyperparameters, whether for the SHO kernel in the transit case studies or the rotation kernel in the stellar variability case? They differ in functional form as well as in starting values. It would be useful to have some discussion on this.

ImportError: cannot import name 'driver' from partially initialized module 'celerite2' (most likely due to a circular import)

I am happily using celerite2 with no issues on my local computer, but encountered a problem when importing celerite2 on readthedocs. I found a workaround but am posting here in case anyone has the same issue.

Full Read The Docs traceback
Running Sphinx v7.2.6
1.12.5

Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/config.py", line 358, in eval_config_file
    exec(code, namespace) # NoQA: S102
    ^^^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/checkouts/latest/docs/conf.py", line 24, in <module>
    import jtow
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/checkouts/latest/jtow/__init__.py", line 7, in <module>
    from . import jtow
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/checkouts/latest/jtow/jtow.py", line 53, in <module>
    from tshirt.pipeline import phot_pipeline
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/__init__.py", line 2, in <module>
    from .pipeline import phot_pipeline, spec_pipeline, analysis, prep_images
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/__init__.py", line 1, in <module>
    from . import phot_pipeline
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/phot_pipeline.py", line 51, in <module>
    from .instrument_specific import rowamp_sub
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/instrument_specific/__init__.py", line 3, in <module>
    from . import rowamp_sub
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/tshirt/pipeline/instrument_specific/rowamp_sub.py", line 16, in <module>
    import celerite2
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/celerite2/__init__.py", line 4, in <module>
    from . import terms
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/celerite2/terms.py", line 22, in <module>
    from . import driver
ImportError: cannot import name 'driver' from partially initialized module 'celerite2' (most likely due to a circular import) (/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/celerite2/__init__.py)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/cmd/build.py", line 293, in build_main
    app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/application.py", line 211, in __init__
    self.config = Config.read(self.confdir, confoverrides or {}, self.tags)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/config.py", line 181, in read
    namespace = eval_config_file(filename, tags)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/jtow/envs/latest/lib/python3.11/site-packages/sphinx/config.py", line 371, in eval_config_file
    raise ConfigError(msg % traceback.format_exc()) from exc
sphinx.errors.ConfigError: There is a programmable error in your configuration file:

[... same traceback as the first one above ...]

Configuration error:
There is a programmable error in your configuration file:

[... same traceback as the first one above ...]

More information is available from this failed build:
https://readthedocs.org/api/v2/build/22591584.txt

This was with Python 3.11

GP mean prediction as the mean function in another GP

Hi Dan and celerite team

I have the following versions:

exoplanet.__version__ = '0.6.0'
celerite2.__version__ = '0.3.1'
pymc.__version__ = '5.10.4'

The goal

I want to model in-transit spot occultations with celerite, and the GP should only be conditioned on in-transit data. A separate GP with a different kernel and hyperparameters is used to describe a longer-timescale correlated signal. I assume they have to be separate GPs because they are conditioned on data sets of different lengths.

Expected result

For a single GP (gp1), I followed the exoplanet examples: I created a function that outputs the transit light curve and passed it to the mean keyword of celerite2.pymc.GaussianProcess. For a second GP (gp2), I expected to be able to create a mean function that is the sum of a light curve model and gp1.predict(). However, this doesn't seem to work: it throws AttributeError: '_CeleriteOp' object has no attribute 'rev_op'. If I add .eval() it runs, but the maximum likelihood model isn't a good fit, so I suspect something goes wrong behind the scenes.

Is it possible to use the mean of a GP as a mean function in another GP?

Here's what I'm working with so far:

import numpy as np
import matplotlib.pyplot as plt
from functools import partial

import pymc as pm
import pymc_ext as pmx
import exoplanet as xo
import pytensor.tensor as pt
from celerite2.pymc import GaussianProcess, terms

np.random.seed(123)
period = np.random.uniform(3,10)
t = np.arange(-0.2, 0.2, 2/60/24)

# The light curve calculation requires an orbit
orbit = xo.orbits.KeplerianOrbit(period=period, t0=0, b=0, duration=0.15, ror=0.1)

# Compute a limb-darkened light curve using starry
u = [0.3, 0.2]
light_curve = np.sum(
    xo.LimbDarkLightCurve(u[0], u[1])
    .get_light_curve(orbit=orbit, r=0.1, t=t, texp=2/60/24)
    .eval(),
    axis=-1
)

# Create simulated data
yerr = 3e-4
y = light_curve
M = (t > -0.5*0.15) & (t < 0.5*0.15) # transit mask
y += yerr * np.random.randn(len(y)) # add noise
y += 0.01*t # add linear term
y += 1

# add some spot occultations
locs = [-0.005, 0.03]
widths = [0.008, 0.01]
amps = [0.002, 0.001]

for i in range(len(locs)):
    m = (t > (locs[i]-widths[i])) & (t < (locs[i]+widths[i]))
    y[m] += amps[i] * np.exp(-(t[m]-locs[i])**2/widths[i]**2)


with pm.Model() as model:
    mean = pm.Normal("mean", mu=1, sigma=0.002, initval=1)

    # The time of a reference transit for each planet
    t0 = pm.Normal("t0", mu=0, sigma=0.01, initval=0)

    u = xo.quad_limb_dark("u", initval=[0.3, 0.2])

    log_dur = pm.Normal("log_dur", mu=np.log(0.13), sigma=0.1, initval=np.log(0.13))
    dur = pm.Deterministic("dur", pt.exp(log_dur))

    log_ror = pm.Normal("logr", mu=np.log(0.1), sigma=0.1, initval=np.log(0.1))
    ror = pm.Deterministic("r", pt.exp(log_ror))
    
    b = xo.impact_parameter("b", ror=ror, initval=0.3)

    star = xo.LimbDarkLightCurve(u[0], u[1])

    # Set up a Keplerian orbit for the planets
    orbit = xo.orbits.KeplerianOrbit(period=period, t0=t0, b=b, duration=dur, ror=ror)

    # Compute the model light curve using starry
    def _mean_fn(orbit, mean, r, star, t):
        return pt.sum(
            star.get_light_curve(orbit=orbit, r=r, t=t, texp=2/60/24),
            axis=-1,
        ) + mean
    mean_fn = partial(_mean_fn, orbit, mean, ror, star)
    pm.Deterministic("light_curves", mean_fn(t))

    # GP parameters for the linear trend and white noise
    log_sigma = pm.Normal("log_sigma", mu=np.log(0.5*yerr), sigma=0.1)
    sigma = pm.Deterministic("sigma", pt.exp(log_sigma))

    log_rho_gp = pm.Normal("log_rho_gp", mu=7, sigma=0.5, initval=7)
    rho_gp = pm.Deterministic("rho_gp", pt.exp(log_rho_gp))

    log_sigma_gp = pm.Normal("log_sigma_gp", mu=-4, sigma=0.5, initval=-4)
    sigma_gp = pm.Deterministic("sigma_gp", pt.exp(log_sigma_gp))

    kernel = terms.Matern32Term(rho=rho_gp, sigma=sigma_gp)

    gp = GaussianProcess(kernel, t=t, diag=yerr**2 + sigma**2,
                         mean=mean_fn, quiet=True)
    pm.Deterministic("gp_preds", gp.predict(y, include_mean=False))
    
    gp.marginal("obs", observed=y)

    ######################################################################
    # problematic part
    ###################################################################### 
    # GP parameters for spot occultations
    log_sigma_spot = pm.Normal("log_sigma_spot", mu=-10, sigma=5, initval=-10)
    sigma_spot = pm.Deterministic("sigma_spot", pt.exp(log_sigma_spot))

    log_rho_spot = pm.Normal("log_rho_spot", mu=np.log(0.02), sigma=0.5)
    rho_spot = pm.Deterministic("rho_spot", pt.exp(log_rho_spot))

    kernel2 = terms.Matern32Term(rho=rho_spot, sigma=sigma_spot)

    def _mean_fn_spot(gp, orbit, star, mean, y, r, t):

        gp_pred = gp.predict(y, t=t, include_mean=False).eval()
        lc_pred = (pt.sum(star.get_light_curve(
                orbit=orbit, r=r, t=t, texp=2/60/24),
                axis=-1
                ) + mean).eval()
        return pt.as_tensor_variable(lc_pred+gp_pred)

    spot_fn = partial(_mean_fn_spot, gp, orbit, star, mean, y, ror)

    gp2 = GaussianProcess(kernel2, t=t[M], diag=yerr**2 + sigma**2,
                         mean=spot_fn, quiet=False)
    pm.Deterministic("gp_preds_spot", gp2.predict(y[M], include_mean=False))
    
    gp2.marginal("obs_spot", observed=y[M])
    ###################################################################### 

    map_soln = pmx.optimize(start=model.initial_point())


# plot fit
spot_model = np.zeros_like(t)
spot_model[M] = map_soln["gp_preds_spot"]
full_mod = map_soln["light_curves"]+map_soln["gp_preds"]+spot_model

plt.figure()
plt.plot(t, y, ".k", ms=4, label="data")
plt.plot(t, full_mod, lw=1, label="full model")
plt.plot(t, map_soln["light_curves"], lw=1, ls='--', label="transit")
plt.plot(t, map_soln["gp_preds"]+map_soln["mean"], lw=1, ls=':', label="trend")
plt.plot(t, spot_model+map_soln["mean"], lw=1, ls='-.', label="spot")
plt.legend()

plt.xlim(t.min(), t.max())
plt.ylabel("relative flux")
plt.xlabel("time [days]")
plt.legend(fontsize=10)
_ = plt.title("map model")
plt.show()
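On the `.eval()` workaround: `.eval()` returns a plain NumPy array, so after `pt.as_tensor_variable` gp2's mean is a constant that no longer depends on the model parameters during optimization. The structure being asked about (GP 2 whose mean is GP 1's predictive mean) is written out below with dense NumPy matrices only to make that dependency explicit; it is not the celerite2 or PyMC API. The Matern-3/2 kernel, the hyperparameter values, the toy data, and the helper names `matern32`, `predict_mean` and `log_likelihood` are illustrative assumptions.

```python
import numpy as np


def matern32(t1, t2, sigma, rho):
    """Dense Matern-3/2 covariance between two sets of 1-D inputs."""
    arg = np.sqrt(3.0) * np.abs(t1[:, None] - t2[None, :]) / rho
    return sigma**2 * (1.0 + arg) * np.exp(-arg)


def predict_mean(t_train, y_train, diag, t_test, sigma, rho):
    """Predictive mean at t_test of a zero-mean GP conditioned on (t_train, y_train)."""
    K = matern32(t_train, t_train, sigma, rho) + np.diag(diag)
    K_star = matern32(t_test, t_train, sigma, rho)
    return K_star @ np.linalg.solve(K, y_train)


def log_likelihood(t, residual, diag, sigma, rho):
    """Zero-mean GP log-likelihood of `residual` (data minus mean model), via Cholesky."""
    K = matern32(t, t, sigma, rho) + np.diag(diag)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L, residual)
    return -0.5 * alpha @ alpha - np.sum(np.log(np.diag(L))) - 0.5 * len(t) * np.log(2.0 * np.pi)


# Toy data on the same time grid as the issue; the transit model is left out for brevity.
rng = np.random.default_rng(0)
t = np.arange(-0.2, 0.2, 2 / 60 / 24)
y = 1.0 + 0.01 * t + 3e-4 * rng.standard_normal(t.size)
M = (t > -0.5 * 0.15) & (t < 0.5 * 0.15)  # in-transit mask, as in the issue

# GP 1 (long-timescale trend) is conditioned on all points; its mean is evaluated on t[M].
trend = 1.0 + predict_mean(t, y - 1.0, np.full(t.size, 3e-4**2), t[M], sigma=0.02, rho=1.0)

# GP 2 (spots) sees only in-transit points and uses GP 1's prediction as its mean.
ll_spot = log_likelihood(t[M], y[M] - trend, np.full(M.sum(), 3e-4**2), sigma=1e-3, rho=0.02)
```

In a PyMC model, the counterpart of `trend` would have to stay a symbolic tensor for gradients to reach GP 1's hyperparameters.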

Term.get_coefficients fails for jax implementation when built from source

Hi @dfm,

Today I built celerite2 from source following the recommendations in the install docs. I'm trying to do something simple, like this:

from celerite2.jax import terms

sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
sho.get_coefficients()

but I'm getting the following error

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [1], line 4
      1 from celerite2.jax import terms
      3 sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
----> 4 sho.get_coefficients()

File ~/git/celerite2/python/celerite2/jax/terms.py:36, in Term.get_coefficients(self)
     35 def get_coefficients(self):
---> 36     raise NotImplementedError("subclasses must implement this method")

NotImplementedError: subclasses must implement this method

At first I thought this could be an accident of the multiple SHOTerm implementations, for example, here

def SHOTerm(*args, **kwargs):
    over = OverdampedSHOTerm(*args, **kwargs)
    under = UnderdampedSHOTerm(*args, **kwargs)
    if over.Q < 0.5:
        return over
    return under

and here

class SHOTerm(Term):

but commenting the first one out doesn't solve the problem.

Any ideas? Thanks!
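For reference while debugging, here is a NumPy sketch of what an SHO term's `get_coefficients` is expected to return, in celerite2's `(ar, cr, ac, bc, cc, dc)` ordering: one complex term for Q ≥ 1/2, two real terms for Q < 1/2. The formulas come from expanding the SHO autocovariance. The `eps` floor (which keeps the critically damped case Q = 1/2 finite) is an assumption of this sketch, and `sho_coefficients` and `celerite_kernel` are hypothetical names, not celerite2 internals.

```python
import numpy as np


def sho_coefficients(S0, w0, Q, eps=1e-5):
    """(ar, cr, ac, bc, cc, dc) for an SHO term with power S0, frequency w0, and quality factor Q."""
    empty = np.empty(0)
    if Q < 0.5:
        # Overdamped: the cosh/sinh form splits into two real exponentials.
        f = np.sqrt(max(1.0 - 4.0 * Q**2, eps))
        ar = 0.5 * S0 * w0 * Q * np.array([1.0 + 1.0 / f, 1.0 - 1.0 / f])
        cr = 0.5 * w0 / Q * np.array([1.0 - f, 1.0 + f])
        return ar, cr, empty, empty, empty, empty
    # Underdamped (and, via eps, critically damped): one complex term.
    f = np.sqrt(max(4.0 * Q**2 - 1.0, eps))
    a = S0 * w0 * Q
    return (empty, empty, np.array([a]), np.array([a / f]),
            np.array([0.5 * w0 / Q]), np.array([0.5 * w0 * f / Q]))


def celerite_kernel(tau, coeffs):
    """Evaluate sum a e^{-c|tau|} + sum e^{-c|tau|} (a cos(d|tau|) + b sin(d|tau|))."""
    ar, cr, ac, bc, cc, dc = coeffs
    tau = np.abs(np.atleast_1d(np.asarray(tau, dtype=float)))[:, None]
    k = np.sum(ar * np.exp(-cr * tau), axis=1)
    k += np.sum(np.exp(-cc * tau) * (ac * np.cos(dc * tau) + bc * np.sin(dc * tau)), axis=1)
    return k


# The term from the report: Q = 0.6 > 1/2, so a single complex term.
coeffs = sho_coefficients(S0=1.0, w0=3000.0, Q=0.6)
```

In both regimes the coefficients give k(0) = S0 w0 Q, which is a quick check against whatever the built-from-source package returns once `get_coefficients` works.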

Getting error while installation.

Hi, I was trying to install celerite2 from source following the instructions in the documentation, but when running the tests I got an error:

python -m pytest -v python/test
=================================================================================== test session starts ===================================================================================
platform darwin -- Python 3.11.6, pytest-7.4.3, pluggy-1.3.0 -- /opt/anaconda3/envs/pymc/bin/python
cachedir: .pytest_cache
rootdir: /Users/ajaysharma/celerite2
configfile: pyproject.toml
plugins: cov-4.1.0
collected 0 items / 1 error                                                                                                                                                               

========================================================================================= ERRORS ==========================================================================================
______________________________________________________________________________ ERROR collecting test session ______________________________________________________________________________
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/runner.py:341: in from_call
    result: Optional[TResult] = func()
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/runner.py:372: in <lambda>
    call = CallInfo.from_call(lambda: list(collector.collect()), "collect")
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/main.py:749: in collect
    for x in self._collectfile(path):
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/main.py:588: in _collectfile
    ihook = self.gethookproxy(fspath)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/main.py:555: in gethookproxy
    my_conftestmodules = pm._getconftestmodules(
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/config/__init__.py:609: in _getconftestmodules
    mod = self._importconftest(conftestpath, importmode, rootpath)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/config/__init__.py:657: in _importconftest
    self.consider_conftest(mod)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/config/__init__.py:739: in consider_conftest
    self.register(conftestmodule, name=conftestmodule.__file__)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/_pytest/config/__init__.py:491: in register
    ret: Optional[str] = super().register(plugin, name)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/pluggy/_manager.py:164: in register
    hook._maybe_apply_history(hookimpl)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/pluggy/_hooks.py:559: in _maybe_apply_history
    res = self._hookexec(self.name, [method], kwargs, False)
/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/pluggy/_manager.py:115: in _hookexec
    return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
python/test/pymc/conftest.py:13: in pytest_configure
    pytensor.config.gcc.cxxflags = "-Wno-c++11-narrowing"
E   AttributeError: 'PyTensorConfigParser' object has no attribute 'gcc'
================================================================================= short test summary info =================================================================================
ERROR  - AttributeError: 'PyTensorConfigParser' object has no attribute 'gcc'
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
==================================================

Please help me to resolve this error. Any help would be appreciated.

Sys info. :
Mac OS M1 chip - version 14.1.1
Python - latest version
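A possible stopgap, assuming the `gcc` config section simply does not exist in the installed PyTensor: guard the assignment in `python/test/pymc/conftest.py` so that test collection does not fail. This only skips the flag; whether `-Wno-c++11-narrowing` is still needed with your compiler, or where newer PyTensor versions expect C++ flags, is not established here. `set_gcc_cxxflags` is a hypothetical helper.

```python
def set_gcc_cxxflags(config, flags):
    """Set config.gcc.cxxflags if the config exposes a `gcc` section; report whether it did."""
    gcc_section = getattr(config, "gcc", None)  # the default swallows the AttributeError
    if gcc_section is None:
        return False
    gcc_section.cxxflags = flags
    return True


def pytest_configure(config):
    import pytensor  # imported here so the helper above is usable without PyTensor installed

    set_gcc_cxxflags(pytensor.config, "-Wno-c++11-narrowing")
```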

Derivative models

I've implemented a first pass at derivative models but there are some open questions so I'm going to remove it from #19 and just post it here for future work:

class _DerivativeHelperTerm(terms.Term):
    def __init__(self, term):
        self.term = term

    def get_coefficients(self):
        coeffs = self.term.get_coefficients()
        a, b, c, d = coeffs[2:]
        final_coeffs = [
            -coeffs[0] * coeffs[1],
            coeffs[1],
            b * d - a * c,
            -(a * d + b * c),
            c,
            d,
        ]
        return final_coeffs


class DerivativeLatentTerm(LatentTerm):
    def __init__(self, term, *, value_amplitude, gradient_amplitude):
        if term.__requires_general_addition__:
            raise TypeError(
                "You cannot perform operations on a term that requires general"
                " addition, it must be the outer term in the kernel"
            )

        val_amp = np.atleast_1d(value_amplitude)
        grad_amp = np.atleast_1d(gradient_amplitude)
        if val_amp.ndim != 1 or val_amp.shape != grad_amp.shape:
            raise ValueError("Dimension mismatch between amplitudes")
        amp = np.stack((val_amp, grad_amp), axis=-1)
        M = amp.shape[0]

        ar, cr, ac, bc, cc, dc = term.get_coefficients()

        Jr = len(cr)
        Jc = len(cc)
        J = Jr + 2 * Jc

        left = np.zeros((2, 4 * J))
        right = np.zeros((2, 4 * J))
        if Jr:
            left[0, :Jr] = 1.0
            left[0, 2 * Jr : 3 * Jr] = -1.0
            left[1, Jr : 2 * Jr] = 1.0
            left[1, 3 * Jr : 4 * Jr] = 1.0
            right[0, : 2 * Jr] = 1.0
            right[1, 2 * Jr : 4 * Jr] = 1.0

        if Jc:
            J0 = 4 * Jr
            for J0 in [4 * Jr, 4 * Jr + 4 * Jc]:
                left[0, J0 : J0 + Jc] = 1.0
                left[0, J0 + 2 * Jc : J0 + 3 * Jc] = -1.0
                left[1, J0 + Jc : J0 + 2 * Jc] = 1.0
                left[1, J0 + 3 * Jc : J0 + 4 * Jc] = 1.0
                right[0, J0 : J0 + 2 * Jc] = 1.0
                right[1, J0 + 2 * Jc : J0 + 4 * Jc] = 1.0

        self.left = np.dot(amp, left)[:, :, None]
        self.right = np.dot(amp, right)[:, :, None]

        # self.left = np.empty((M, 4 * J, 1))
        # self.left[:] = np.nan
        # self.left[:] = np.dot(amp, left)[:, :, None]

        # self.right = np.zeros((M, 4 * J, M))
        # right = np.dot(amp, right)
        # for m in range(M):
        #     self.right[m, :, m] = right[m]

        # self.left = np.reshape(self.left, (M, -1, 1))
        # self.right = np.reshape(self.right, (M, -1, 1))

        print(self.left)
        print(self.right)
        # print(self.right)
        # assert 0

        dterm = _DerivativeHelperTerm(term)
        d2term = terms.TermDiff(term)
        super().__init__(
            terms.TermSum(term, dterm, dterm, d2term),
            dimension=2,
            left_latent=self._left_latent,
            right_latent=self._right_latent,
        )

    def _left_latent(self, t, inds):
        return self.left[inds]

    def _right_latent(self, t, inds):
        return self.right[inds]
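As a sanity check on `_DerivativeHelperTerm`: for τ > 0, d/dτ[a e^{-cτ}] = -ac e^{-cτ} and d/dτ[e^{-cτ}(a cos dτ + b sin dτ)] = e^{-cτ}[(bd - ac) cos dτ - (ad + bc) sin dτ], which is what the helper's coefficient list encodes. The NumPy sketch below compares that mapping with a central finite difference; the test coefficients are arbitrary and `derivative_coefficients` / `celerite_kernel` are hypothetical names.

```python
import numpy as np


def celerite_kernel(tau, coeffs):
    """Evaluate sum a e^{-c|tau|} + sum e^{-c|tau|} (a cos(d|tau|) + b sin(d|tau|))."""
    ar, cr, ac, bc, cc, dc = coeffs
    tau = np.abs(np.asarray(tau, dtype=float))[:, None]
    k = np.sum(ar * np.exp(-cr * tau), axis=1)
    k += np.sum(np.exp(-cc * tau) * (ac * np.cos(dc * tau) + bc * np.sin(dc * tau)), axis=1)
    return k


def derivative_coefficients(ar, cr, ac, bc, cc, dc):
    """Coefficients of dk/dtau (tau > 0), mirroring _DerivativeHelperTerm.get_coefficients."""
    return (-ar * cr, cr, bc * dc - ac * cc, -(ac * dc + bc * cc), cc, dc)


# Arbitrary one-real, one-complex term.
coeffs = (np.array([1.5]), np.array([0.7]), np.array([2.0]),
          np.array([0.3]), np.array([0.5]), np.array([1.3]))
tau = np.linspace(0.1, 5.0, 40)
h = 1e-6
finite_diff = (celerite_kernel(tau + h, coeffs) - celerite_kernel(tau - h, coeffs)) / (2 * h)
analytic = celerite_kernel(tau, derivative_coefficients(*coeffs))
```

The check only holds for τ > 0: the derivative of a stationary kernel is odd in τ, so evaluating these coefficients at |τ| is not by itself a covariance.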

Fourier conventions for the PSD in the documentation of the SHOTerm

My apologies for this rather silly question. I'm just learning the basics at the moment. (The following example from celerite 1 already clarified things: https://celerite.readthedocs.io/en/stable/tutorials/normalization/)

The Fourier convention of the PSD is a bit confusing. I assume the factor $1/\sqrt{2\pi}$ is used in the definition of the Fourier transform. It would be helpful to mention this.

The point where it gets really confusing is the documentation of $S_0$, which seems to be inconsistent with the factor $\sqrt{2/\pi}$ in the docstring of the SHOTerm. According to the docstring, $S_0$ is "the power at $\omega=0$". Does the wording "the power at" have a special interpretation, other than $S(0)$?

It could also be helpful to mention the integrated autocorrelation time in the documentation of the SHOTerm, which is $2/(Q \omega_0)$, if I'm not mistaken.

Sorry for stumbling over these little things, and thanks for the nice implementation.
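A numerical check of the convention, assuming the PSD as written in the SHOTerm docstring, S(ω) = √(2/π) S0 ω0⁴ / ((ω² - ω0²)² + ω0²ω²/Q²): integrating S with the unitary 1/√(2π) factor recovers k(0) = S0 ω0 Q, so S(0) = √(2/π) S0 rather than S0 itself, and the two-sided integrated autocorrelation time ∫k(τ)dτ / k(0) comes out as 2/(Q ω0). The parameter values are arbitrary and the helper names are hypothetical.

```python
import numpy as np


def sho_psd(omega, S0, w0, Q):
    """SHO power spectral density including the sqrt(2/pi) prefactor."""
    return np.sqrt(2.0 / np.pi) * S0 * w0**4 / ((omega**2 - w0**2) ** 2 + w0**2 * omega**2 / Q**2)


def sho_kernel(tau, S0, w0, Q):
    """Underdamped (Q > 1/2) SHO autocovariance."""
    eta = np.sqrt(1.0 - 1.0 / (4.0 * Q**2))
    tau = np.abs(tau)
    return S0 * w0 * Q * np.exp(-w0 * tau / (2.0 * Q)) * (
        np.cos(eta * w0 * tau) + np.sin(eta * w0 * tau) / (2.0 * eta * Q)
    )


def trapezoid(y, x):
    # Written out by hand to avoid depending on np.trapz vs np.trapezoid across NumPy versions.
    return 0.5 * np.sum((y[1:] + y[:-1]) * np.diff(x))


S0, w0, Q = 1.0, 2.0, 1.0

# k(0) = (1/sqrt(2 pi)) * integral of S over all omega; S is even, so integrate omega >= 0 and double.
omega = np.linspace(0.0, 200.0, 400_001)
k0_from_psd = 2.0 * trapezoid(sho_psd(omega, S0, w0, Q), omega) / np.sqrt(2.0 * np.pi)

# Two-sided integrated autocorrelation time: integral of k(tau) / k(0) over all tau.
tau = np.linspace(0.0, 60.0, 120_001)
tau_int = 2.0 * trapezoid(sho_kernel(tau, S0, w0, Q), tau) / sho_kernel(0.0, S0, w0, Q)
```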

Avoiding LinAlgErrors for closely sampled x arrays

Some background: I want to model a GP for variations in flux as a function of roll-angle. Most of our observations (think HST or CHEOPS) last for a few orbits. Each orbit shows an extremely self-similar (and often not particularly sinusoidal) variation due to e.g. Earth limb, lunar glint, temperature effects, etc. Hence, including a GP on roll-angle vs flux is perfect.

My method thus far has been to sort the data by increasing roll angle and then unsort afterwards to get the variation as a function of time. However, roll angle is not a continuous observation, and for some targets we have enough data that some measurements are extremely close in roll angle to others (<1e-4 degrees), yet not particularly correlated (as they're ~days apart in time). In these cases celerite2 simply breaks: it throws a LinAlgError. I'm not sure of the maths, but this is likely because of a large difference in flux between two extremely close points, although it seems to break even when the difference between such points is consistent with the variance.

Here is a working (well, non-working) example:

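import numpy as np
import pymc3 as pm
import celerite2.theano
from celerite2.theano import terms as theano_terms

# Ten fake "orbits" with varying roll angle from ~50 to ~250deg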
t=np.hstack([np.linspace(50+5*np.random.random(),250+5*np.random.random(),500) for c in range(10)])

with pm.Model():
    rollangle_w0=0.1 #NB - at higher rollangle_w0, we get a different error - a Theano Assert error
    rollangle_S0=120.0
    kern=theano_terms.SHOTerm(S0=rollangle_S0,  w0=rollangle_w0, Q=1/np.sqrt(2))#, mean = phot_mean)
    gp=celerite2.theano.GaussianProcess(kern, np.sort(t), mean=0.0) #Sorting to 
    gp.compute(t=t, diag=np.tile(0.25,len(t))**2)

Relatedly, even when celerite does not break outright, it is consistently forced into extremely short-timescale variations despite a strong prior against such over-fitting...

I think this is all because of an assumption that each x measurement is effectively instantaneous; in our case the x values (roll angles) are actually averages across some dx range that overlaps with neighbouring values. I'm not sure it would be possible, but is there any way to incorporate a value (or array) of dx that would stop this error/overfitting?

Alternatively, is there some way to use a periodic kernel to enforce similar variation across all points? I've tried adapting the RotationKernel, which works well for low-frequency variation with wavelengths on the order of 100 degrees, but is typically unable to model higher-frequency variation with wavelengths on the order of tens of degrees. But maybe there is a combination of a periodic term and, e.g., a spikier Matern32 kernel?
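A dense NumPy illustration of the underlying linear algebra (not of celerite2's own factorization): two inputs a tiny distance apart give two almost identical rows in the noise-free covariance, which makes it nearly singular, while adding a white-noise diagonal σ² bounds every eigenvalue below by σ². The kernel is the SHO form with Q = 1/√2 and the S0, w0 and diag values from the example above; the input locations and the helper name `sho_matrix` are made up for illustration.

```python
import numpy as np


def sho_matrix(x, S0, w0, Q):
    """Dense SHO covariance (underdamped form, Q > 1/2) for 1-D inputs x."""
    eta = np.sqrt(1.0 - 1.0 / (4.0 * Q**2))
    tau = np.abs(x[:, None] - x[None, :])
    return S0 * w0 * Q * np.exp(-w0 * tau / (2.0 * Q)) * (
        np.cos(eta * w0 * tau) + np.sin(eta * w0 * tau) / (2.0 * eta * Q)
    )


params = dict(S0=120.0, w0=0.1, Q=1 / np.sqrt(2))
x_spread = np.array([0.0, 20.0, 40.0, 60.0])
x_close = np.array([0.0, 20.0, 40.0, 40.0 + 1e-5, 60.0])  # one near-duplicate roll angle

cond_spread = np.linalg.cond(sho_matrix(x_spread, **params))
cond_close = np.linalg.cond(sho_matrix(x_close, **params))

# With the white-noise term from the example (diag = 0.25**2) the matrix is safely positive definite.
noise_var = 0.25**2
K_noisy = sho_matrix(x_close, **params) + noise_var * np.eye(len(x_close))
min_eig = np.linalg.eigvalsh(K_noisy).min()
L = np.linalg.cholesky(K_noisy)
```

So with a diagonal term of that size the dense matrix itself is well posed; this is a property of the matrix and does not by itself say where the reported LinAlgError comes from.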

Installation from source fails on centos 7 with gcc 4.8.5

Hi,

I'm trying to install from source following the directions here. The reason I'm installing from source rather than the pre-compiled binaries is that I'm hoping to try out the Kronecker functionality. It works fine on my mac, but when I try to install on the cluster I get the error:

ERROR: Command errored out with exit status -11: /gscratch/home/tagordon/jwst/celerite2/env/bin/python -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/gscratch/home/tagordon/jwst/celerite2/setup.py'"'"'; file='"'"'/gscratch/home/tagordon/jwst/celerite2/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(file);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, file, '"'"'exec'"'"'))' develop --no-deps Check the logs for full command output.

This happens regardless of whether I'm installing from the kron branch or not. Any help is appreciated!

Thanks,
Tyler

Pointwise predictive accuracy

Hi @dfm,

I'm interested in doing model comparison with LOO-CV and/or WAIC via arviz. My model (w/ JAX+numpyro) relies on celerite2 for computing the log likelihood. The likelihood returned by celerite2 is a single value, but the pointwise predictive accuracy methods expect pointwise logp's. I think that's why when comparing models, I get warnings that say things like:

The point-wise LOO is the same with the sum LOO, please double check the Observed RV in your 
model to make sure it returns element-wise logp.

But I'm not sure that a logp can be defined element-wise for a GP. So: is there a way to get element-wise logp's from a GP?

Thanks as always!
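For a GP, the leave-one-out predictive density p(y_i | y_{-i}) is still Gaussian and can be computed for all i at once from K⁻¹: with α = K⁻¹y, the LOO mean is y_i - α_i / [K⁻¹]_ii and the variance is 1 / [K⁻¹]_ii. A dense NumPy sketch with a brute-force cross-check follows. It is not a celerite2 feature, the exponential kernel and data are made up, and whether these conditional terms are the right input to arviz's PSIS-LOO for a given model is a separate modeling question. Note that they do not sum to the joint log-likelihood.

```python
import numpy as np


def gp_loo_logpdf(K, y):
    """log p(y_i | y_{-i}) for every i, for a zero-mean GP with full covariance K (kernel + noise)."""
    K_inv = np.linalg.inv(K)
    alpha = K_inv @ y
    var = 1.0 / np.diag(K_inv)   # LOO predictive variances
    mu = y - alpha * var         # LOO predictive means
    return -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (y - mu) ** 2 / var


def brute_force_loo(K, y):
    """Same quantity by explicitly conditioning on the other N - 1 points (O(N^4), for checking)."""
    n = len(y)
    out = np.empty(n)
    for i in range(n):
        keep = np.arange(n) != i
        k_i = K[i, keep]
        w = np.linalg.solve(K[np.ix_(keep, keep)], k_i)
        mu, var = w @ y[keep], K[i, i] - k_i @ w
        out[i] = -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (y[i] - mu) ** 2 / var
    return out


# Made-up example: exponential kernel plus white noise, one draw of y from the GP.
rng = np.random.default_rng(1)
t = np.sort(rng.uniform(0.0, 10.0, 30))
K = 2.0 * np.exp(-np.abs(t[:, None] - t[None, :]) / 1.5) + 0.1 * np.eye(30)
y = np.linalg.cholesky(K) @ rng.standard_normal(30)
pointwise = gp_loo_logpdf(K, y)
```

For posterior samples, this vector would be evaluated once per draw of the hyperparameters.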

A documentation example for C++ API

Could you please add a C++ example to the documentation? The example in the celerite C++ docs looks helpful, but I'm not sure whether it applies here. Thank you!

Small amplitude GP

Hi @dfm,

I've found a set of SHO kernel hyperparameters which imitate solar granulation by fitting two SHO kernels' PSDs to the measured PSD of the Sun. I'd like to then draw samples from the kernel with the best-fit PSD, like this:

import numpy as np
import astropy.units as u
import matplotlib.pyplot as plt
from celerite2 import terms, GaussianProcess

duration = 100 * u.d
cadence = 15 * u.s

params = np.array(
    [[2.98095799e+03, 1.88495559e+02, 1.00000000e-02],
     [4.97870684e-02, 5.02654825e+04, 1.00000000e-01]]
)

kernel = (
    # Granulation terms
    terms.SHOTerm(S0=params[0, 0], w0=params[0, 1], Q=params[0, 2]) + 
    terms.SHOTerm(S0=params[1, 0], w0=params[1, 1], Q=params[1, 2])
)

times = np.arange(0, duration.to(u.s).value, cadence.to(u.s).value, dtype=np.float64) * u.s
x = times.value

gp = GaussianProcess(kernel, t=x)

# Get samples with the kernel's PSD

y = gp.sample()

def power_spectrum(fluxes, d=60):
    """
    Compute the power spectrum of ``fluxes`` in units of [ppm^2 / microHz].

    Parameters
    ----------
    fluxes : `~numpy.ndarray`
        Fluxes with zero mean.
    d : float
        Time between samples [s].

    Returns
    -------
    freq : `~numpy.ndarray`
        Frequencies
    power : `~numpy.ndarray`
        Power at each frequency in units of [ppm^2 / microHz]
    """
    fft = np.fft.rfft(1e6 * fluxes)
    power = (fft * np.conj(fft)).real * 1e-6 / len(fluxes)
    freq = np.fft.rfftfreq(len(fluxes), d)
    return freq, power

# Get the power spectrum of the samples
freq, power = power_spectrum(y-1, d=cadence.to(u.s).value)
freq *= 1e6

# plot the power spectrum of samples, compare with kernel PSD
plt.loglog(freq, power, marker=',', lw=0)
plt.loglog(freq, kernel.get_psd(2*np.pi*freq)/2/np.pi, marker=',', lw=0)

[Plot: power spectrum of the samples compared with the kernel PSD]

But as you can see in the plot above, it looks like the kernel is just generating white noise, even though the kernel PSD is well behaved. Could this be a precision issue? Maybe I need to set eps to something?

This behavior crept up when switching from celerite to celerite2, and I think it has to do with the scales of the hyperparameters since they were each defined in log-space before and they're not anymore in celerite2.

Thanks!
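One thing that may be worth checking (an inference from the snippet, not something established in this thread): the kernel's frequencies are in radians per unit of t. The PSD comparison evaluates the kernel at 2π × freq with freq in μHz, which suggests the w0 values are in rad per 10⁶ s (188.5/2π = 30 μHz, 5.03e4/2π = 8000 μHz), while t is in seconds. If so, both w0 exceed the Nyquist angular frequency of a 15 s cadence (π/15 ≈ 0.21 rad/s), and samples would look white. The helper name `nyquist_angular` is hypothetical.

```python
import numpy as np

cadence_s = 15.0                                # sampling interval used above, in seconds
w0 = np.array([1.88495559e02, 5.02654825e04])  # params[:, 1] from the snippet

# The PSD comparison evaluates the kernel at 2*pi*freq with freq in microHz,
# which suggests these w0 are angular frequencies in rad per 1e6 s.
w0_over_2pi_microhz = w0 / (2.0 * np.pi)


def nyquist_angular(cadence_seconds, time_unit_seconds):
    """Nyquist angular frequency, in rad per time unit, for a cadence given in seconds."""
    return np.pi / (cadence_seconds / time_unit_seconds)


# t in seconds (as in the snippet): the kernel reads w0 as rad/s.
above_nyquist_if_seconds = w0 > nyquist_angular(cadence_s, 1.0)
# t in units of 1e6 s, which is what microHz-based w0 would imply.
above_nyquist_if_megaseconds = w0 > nyquist_angular(cadence_s, 1e6)
```

If that reading is right, expressing t in units of 10⁶ s (for example `times.to(u.s).value / 1e6`) would put t and w0 in matching units; the S0 values would then need checking against the same convention.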

RotationTerm translated to celerite (set_parameter_vector() input parameters, question).

When we initialize the double SHO, a.k.a. RotationTerm in celerite2, we need (*, sigma, period, Q0, dQ, f). I understand that the idea is to pass five parameters that will be converted to the S1, w1, Q1, S2, w2, Q2 needed for the two SHO kernels. So far, so good!

However, what I am trying to do is build this kernel in celerite. See here, under "double_SHOTerm()":

https://github.com/3fon3fonov/exostriker/blob/exostriker-ready/exostriker/lib/RV_mod/GP_kernels.py

My concern is that, inside the opt.func/MCMC/nested-sampling loop, one needs to continuously update the GP as it samples and then call compute():

...
gps.set_parameter_vector([sampled GP params])
gps.compute()
...

However, set_parameter_vector() seems to expect six parameters, these being the 2xSHO parameters; is that right?

So to update the GP model I need six parameters (which, BTW, must always be the individual 2xSHOTerm() parameters in log!), but I sample five: sigma, period, Q0, dQ, and f. If I understand correctly, the way it is done now, inside the loop/sampling run I must always initialize the GP and then call compute().

I wrote a simple function


def init_dSHOKernel(params):
    kernel =GP_kernels.double_SHOTerm(
            sigma_dSHO=params[0],
            period_dSHO=params[1],
            Q0_dSHO=params[2],
            dQ_dSHO=params[3], 
            f_dSHO=params[4])


    gps = celerite.GP(kernel, mean=0.0)
    return gps
...
#and then during the sampling:
gps = init_dSHOKernel(param_vect)
gps.compute()

This seems to work, but I think it is quite a compromise. Any comments on this would be highly appreciated. For example, how is this done in celerite2? I will eventually migrate to celerite2, but I would like to have the option to use the original version too, with more kernels.
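One way to avoid rebuilding the GP is to map the five sampled parameters to the six log-parameters directly. The sketch below follows what I understand celerite2's `RotationTerm` to do (primary SHO at `period`, secondary at `period / 2`, total variance sigma²); check it against the installed source and against the exostriker `double_SHOTerm`, which may differ. It assumes celerite (v1) orders each SHOTerm's parameters as `log_S0, log_Q, log_omega0`, which `gp.get_parameter_names()` will confirm. The function names are hypothetical, and Q0 > 0 is required so that Q2 > 1/2.

```python
import numpy as np


def rotation_to_sho_params(sigma, period, Q0, dQ, f):
    """Map (sigma, period, Q0, dQ, f) to two SHO terms (S0, Q, w0)."""
    amp = sigma**2 / (1.0 + f)
    # Primary mode: damped oscillation at angular frequency 2*pi/period.
    Q1 = 0.5 + Q0 + dQ
    w1 = 4.0 * np.pi * Q1 / (period * np.sqrt(4.0 * Q1**2 - 1.0))
    S1 = amp / (w1 * Q1)
    # Secondary mode at half the period, carrying a fraction f of the primary's variance.
    Q2 = 0.5 + Q0
    w2 = 8.0 * np.pi * Q2 / (period * np.sqrt(4.0 * Q2**2 - 1.0))
    S2 = f * amp / (w2 * Q2)
    return (S1, Q1, w1), (S2, Q2, w2)


def celerite1_parameter_vector(sigma, period, Q0, dQ, f):
    """Six log-parameters for celerite SHOTerm + SHOTerm, assuming (log_S0, log_Q, log_omega0) per term."""
    (S1, Q1, w1), (S2, Q2, w2) = rotation_to_sho_params(sigma, period, Q0, dQ, f)
    return np.log([S1, Q1, w1, S2, Q2, w2])
```

Inside the sampler, the GP built once could then be updated with `gp.set_parameter_vector(celerite1_parameter_vector(*theta))` before recomputing as you already do.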

Compilation of pybind11 drivers fails with compiler crash on Ubuntu 18.04 with gcc 7.5.0

I am trying to install celerite2 using pip (version=20.2.3) with python3.7 on an Ubuntu18.04 machine, however I get the following error when pip is trying to build wheels for celerite2:

ERROR: Failed building wheel for celerite2
Failed to build celerite2
ERROR: Could not build wheels for celerite2 which use PEP 517 and cannot be installed directly

Thanks,
Tom

Derive associative scan algorithm for factorization

I've derived the algorithms for matrix multiplication and solves, but I haven't been able to work out the factorization algorithm yet. There don't seem to be numerical issues for the ops that I've derived so far, but I haven't extensively tested it. This would be interesting because it would allow parallel implementation on a GPU.
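For context on the associative-scan idea, here is a NumPy sketch of the part that is a plain linear recurrence: F_n = p_n ⊙ F_{n-1} + G_n with diagonal propagators p_n = exp(-c Δt_n), the same shape of update as `Fn` in the `dot_tril` example further down this page. Pairs (p, G) compose associatively, so a Hillis-Steele scan reaches the same result in O(log N) sweeps that are each parallel over n. It does not address the factorization, which is the open part, and the function names are hypothetical.

```python
import numpy as np


def combine(earlier, later):
    """Compose two affine maps F -> p * F + G (diagonal p), applying `earlier` first."""
    p_e, g_e = earlier
    p_l, g_l = later
    return p_l * p_e, p_l[..., None] * g_e + g_l


def inclusive_scan(p, g):
    """Hillis-Steele inclusive scan over n; each while-iteration is parallel over n."""
    step = 1
    while step < len(p):
        new_p, new_g = combine((p[:-step], g[:-step]), (p[step:], g[step:]))
        p = np.concatenate([p[:step], new_p])
        g = np.concatenate([g[:step], new_g])
        step *= 2
    return p, g


def sequential(p, g):
    """Reference loop: F_0 = G_0, F_n = p_n * F_{n-1} + G_n."""
    F = np.empty_like(g)
    F[0] = g[0]
    for n in range(1, len(g)):
        F[n] = p[n][:, None] * F[n - 1] + g[n]
    return F


rng = np.random.default_rng(0)
N, J, Nrhs = 37, 3, 2
t = np.sort(rng.uniform(0.0, 10.0, N))
c = rng.uniform(0.1, 1.0, J)
p = np.exp(-c[None, :] * np.diff(t, prepend=t[0])[:, None])  # shape (N, J)
g = rng.standard_normal((N, J, Nrhs))                         # e.g. outer(V_n, Y_n)
_, F_scan = inclusive_scan(p, g)
```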

gp does not own its data

Hi,

I am using emcee to maximize a likelihood function to which I pass a celerite2.GaussianProcess object. I store the chain in a h5 file. When I start the emcee sampling from scratch there is no issue, but if I restart the sampling loading the h5 file then the gp.compute() instruction throws this error:

File ".../python3.6/site-packages/celerite2/core.py", line 313, in compute
    self._t, self._diag, c=self._c, a=self._a, U=self._U, V=self._V
File ".../python3.6/site-packages/celerite2/terms.py", line 157, in get_celerite_matrices
    c.resize(J, refcheck=False)
ValueError: cannot resize this array: it does not own its data

How can I fix it? Thanks
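The ValueError itself comes from NumPy: `ndarray.resize` refuses to resize an array that does not own its memory, such as a view or an array wrapping an external buffer. The snippet only reproduces that NumPy-level condition and shows that a copy owns its buffer; why the GP's internal array ends up in that state after restarting from the HDF5 file is not established here.

```python
import numpy as np

base = np.zeros(8)
view = base[:4]  # a view: it shares base's memory, so view.flags.owndata is False

message = ""
try:
    view.resize(6, refcheck=False)
except ValueError as err:
    message = str(err)

owned = np.array(view, copy=True)  # an independent copy owns its buffer
owned.resize(6, refcheck=False)    # allowed; the new trailing entries are zero
```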

add support for variable coefficients in celerite models

I'm interested in being able to model different subsets of the data with different kernels, i.e. K = K1 + K2, where K2 could have zeros everywhere except the blocks corresponding to the target data subset. As you mentioned, this isn't trivial with celerite, but it could work in principle, e.g. by generalizing the kernel to k(tau=|ti-tj|) = a_1(ti) * a_2(tj) * exp(-c*tau) * cos(d*tau) + b_1(ti) * b_2(tj) * exp(-c*tau) * sin(d*tau)
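A dense NumPy illustration of the proposed generalization, only to show what K2 would look like; making this O(N) inside celerite is the nontrivial part and is not attempted here. With indicator amplitude functions, one stationary celerite term (here a = 1, b = 0.2, c = 1, d = 2, chosen so that ac ≥ |bd|) is switched on only for a target subset, which keeps K2 symmetric and positive semidefinite; arbitrary choices with a_1 ≠ a_2 need not be. Function names and values are hypothetical.

```python
import numpy as np


def variable_coefficient_kernel(t, a1, a2, b1, b2, c, d):
    """Dense K_ij = a1(t_i) a2(t_j) e^{-c tau} cos(d tau) + b1(t_i) b2(t_j) e^{-c tau} sin(d tau).

    tau = |t_i - t_j|; a1, a2, b1, b2 are callables of time, c and d are scalars.
    """
    tau = np.abs(t[:, None] - t[None, :])
    decay = np.exp(-c * tau)
    return (np.outer(a1(t), a2(t)) * decay * np.cos(d * tau)
            + np.outer(b1(t), b2(t)) * decay * np.sin(d * tau))


def in_window(x, lo=3.0, hi=6.0):
    """Hypothetical amplitude function: 1 inside the target subset, 0 elsewhere."""
    return ((x > lo) & (x < hi)).astype(float)


t = np.linspace(0.0, 10.0, 60)
# A stationary term switched on only inside the window.
K2 = variable_coefficient_kernel(
    t, in_window, in_window, in_window, lambda x: 0.2 * in_window(x), c=1.0, d=2.0
)
```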

Theano Version Issues Between Exoplanet 0.4.4 and Celerite2

Describe the bug
The exoplanet software, when conda installed, automatically adds the package theano-pymc, which seems to conflict with my installed version of theano. When I have, without exoplanet, theano version 1.0.5 installed on my computer, the command

from celerite2.theano import terms

executes normally. After a conda installation of exoplanet version 0.4.4, theano-pymc version 1.1.5 is automatically installed. Then, the command

from celerite2.theano import terms

ends in an attribute error, I have attached a picture. This is all done on a freshly created conda environment, with minimal packages installed. My advisor has also been able to recreate this problem on a completely separate machine.

To Reproduce

It is not necessary to create a fresh environment, but that has the best hopes for reproduction:

conda create -n myenv python

Then install any files one needs, I prefer using Jupyter Notebook, and so I run

pip install notebook

Install celerite2 and theano:

conda install -c conda-forge theano
python -m pip install -U celerite2

Then running the following command in either a jupyter notebook or any short python script should terminate successfully:

from celerite2.theano import terms

Next, installing any version of exoplanet seems to override this. I did it with the latest version:

conda install -c conda-forge exoplanet

Now when I run the command

from celerite2.theano import terms

The code ends in an Attribute error, a picture of which is attached.

[Screenshot: exoplanet_bug]

Expected behavior
I expect the command

from celerite2.theano import terms

To execute without error.

Your setup (please complete the following information):

  • Version of exoplanet: 0.4.4 (though I have tried multiple versions)
  • Operating system: macOS Catalina 10.15.7
  • Python version & installation method (pip, conda, etc.): Python version 3.7.10 (though I have tried multiple versions)

Additional context
With this setup, I have installed and successfully run the celerite2 tutorials, provided they do not use the celerite2.theano module.

Additionally, I have run exactly copied exoplanet tutorials. Those that do not use celerite2 run to accurate completion. Those that use celerite2 end in failure.

Getting started tutorial fails with Theano

Hi,

I'm following your tutorial, and in the pymc3 section the code fails with this error:

Exception: ('Compilation failed (return status=254): clang-10: error: unable to execute command: Segmentation fault: 11. clang-10: error: dsymutil command failed due to signal (use -v to see invocation). ', '[Elemwise{true_div,no_inplace}(<TensorType(float64, row)>, <TensorType(float64, matrix)>)]')

I've not found how to fix it and it seems it's related to Theano.

I'm using python 3.7 and OS X Mojave 10.14.6

Do you have any idea how to fix this??

Thanks.

Add op for general (rectangular) matrix multiplication

Here's an example:

import numpy as np
import celerite2

def dot_tril(t1, t2, U, V, c, Y):
    N, J = U.shape
    M = V.shape[0]
    Nrhs = Y.shape[1]
    
    assert V.shape == (M, J)
    assert c.shape == (J,)
    assert Y.shape == (M, Nrhs)
    
    Z = np.zeros((N, Nrhs))
    Fn = np.zeros((J, Nrhs))

    n = 0
    m = 0
    while n < N:
        if m < M and t2[m] <= t1[n]:
            if m > 0:
                Fn = np.exp(-c * (t2[m] - t2[m - 1]))[:, None] * Fn
            Fn += np.outer(V[m], Y[m])
            m += 1
        else:
            if m > 0:
                Z[n] = np.dot(U[n], np.exp(-c * (t1[n] - t2[m - 1]))[:, None] * Fn)
            n += 1
            
    return Z


def dot_triu(t1, t2, U, V, c, Y):
    N, J = U.shape
    M = V.shape[0]
    Nrhs = Y.shape[1]
    
    assert V.shape == (M, J)
    assert c.shape == (J,)
    assert Y.shape == (M, Nrhs)
    
    Z = np.zeros((N, Nrhs))
    Fn = np.zeros((J, Nrhs))

    n = N - 1
    m = M - 1
    while n >= 0:
        if m >= 0 and t2[m] > t1[n]:
            if m < M - 1:
                Fn = np.exp(-c * (t2[m + 1] - t2[m]))[:, None] * Fn
            Fn += np.outer(V[m], Y[m])
            m -= 1
        else:
            if m < M - 1:
                Z[n] = np.dot(U[n], np.exp(-c * (t2[m + 1] - t1[n]))[:, None] * Fn)
            n -= 1
            
    return Z

np.random.seed(42)
t = np.sort(np.random.uniform(0, 3, 1600))
x = np.sort(np.random.uniform(-1, 5, 3049))

term = celerite2.terms.....
a1, U1, V1, P1 = term.get_celerite_matrices(t, np.zeros_like(t))
a2, U2, V2, P2 = term.get_celerite_matrices(x, np.zeros_like(x))
_, cr, _, _, cc, _ = term.get_coefficients()
c = np.concatenate((cr, cc, cc))

y = np.sin(x)[:, None]

inds = np.searchsorted(t, x)
lower = np.arange(len(t))[:, None] >= inds[None, :]
A = np.dot(np.exp(-c * t[:, None]) * U1, (np.exp(c * x[:, None]) * V2).T)
L = np.zeros_like(A)
L[lower] = A[lower]
assert np.allclose(dot_tril(t, x, U1, V2, c, y), np.dot(L, y))

upper = np.arange(len(t))[:, None] < inds[None, :]
A = np.dot(np.exp(c * t[:, None]) * V1, (np.exp(-c * x[:, None]) * U2).T)
U = np.zeros_like(A)
U[upper] = A[upper]
assert np.allclose(dot_triu(t, x, V1, U2, c, y), np.dot(U, y))

Online documentation notebook tutorial rendering issue: "findfont: Generic family 'sans-serif' not found..."

Hi there 👋
I noticed the online documentation has a warning about missing font families:
findfont: Generic family 'sans-serif' not found because none of the following families were found: Liberation Sans
I see this on Safari and Chrome, so it's not a browser thing, it appears to be embedded in the rendered docs?

The warning repeats 127 times, filling the screen and requiring long scrolls to navigate to the next series of code cells after a cell involving a plot.

I'm not sure of the cause. One could imagine that the font server coincidentally had downtime the last time the docs were rendered, in which case simply triggering a re-render on RTD could fix the issue. It's probably more complicated than that. 🤷

Add a note to the docs that get_coefficients must return a tuple of arrays

Here is a minimal example that reproduces the error:

from celerite2.terms import Term
class CosineTerm(Term):
    def __init__(self, omega_j, sigma_j):
        self.omega_j = omega_j
        self.sigma_j = sigma_j
    def get_coefficients(self):
        return 0, 0, self.sigma_j, 0, 0, self.omega_j
kernel =  CosineTerm(1,0.1) + CosineTerm(2,0.2)
kernel.get_coefficients()

The following error is raised at line 161 of terms.py:

ValueError: zero-dimensional arrays cannot be concatenated
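The error suggests that summing terms concatenates each of the six coefficient arrays `(ar, cr, ac, bc, cc, dc)` across the terms, which fails when a term returns bare scalars. The NumPy-only sketch below reproduces that behavior and shows a fix, assuming this concatenation model; `cosine_coefficients` and `concat_coefficients` are hypothetical helpers, not celerite2 API. In the `CosineTerm` above, the fix would be for `get_coefficients` to return 1-D arrays like the ones built by `cosine_coefficients`.

```python
import numpy as np


def cosine_coefficients(omega_j, sigma_j):
    """Coefficients (ar, cr, ac, bc, cc, dc) for one undamped cosine, as 1-D arrays."""
    empty = np.empty(0)
    return (
        empty,                           # ar: no real-exponential components
        empty,                           # cr
        np.atleast_1d(float(sigma_j)),   # ac: amplitude of the cosine
        np.zeros(1),                     # bc: no sine component
        np.zeros(1),                     # cc: no damping
        np.atleast_1d(float(omega_j)),   # dc: angular frequency
    )


def concat_coefficients(*coefficient_tuples):
    """Hypothetical stand-in for summing terms: concatenate each coefficient across terms."""
    return tuple(np.concatenate(parts) for parts in zip(*coefficient_tuples))


# Scalars reproduce the reported error ...
try:
    concat_coefficients((0, 0, 0.1, 0, 0, 1.0), (0, 0, 0.2, 0, 0, 2.0))
except ValueError as err:
    scalar_error = str(err)

# ... while 1-D arrays concatenate cleanly
ar, cr, ac, bc, cc, dc = concat_coefficients(
    cosine_coefficients(1.0, 0.1), cosine_coefficients(2.0, 0.2)
)
```

Returning empty arrays for `ar` and `cr` (rather than zeros) avoids adding a spurious real component with zero amplitude.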

Damped Random Walk Model?

This is likely a basic question, but here goes.

Which celerite/celerite2 kernel could be used as a Damped Random Walk (DRW) model? This model is used for analyzing quasar light curves. My understanding is that it is a GP model, but I can't find such a kernel, and at this point I don't feel confident enough to build it in celerite/celerite2. So what are my alternatives?
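For reference, the DRW (an Ornstein-Uhlenbeck process) is usually written with covariance k(tau) = sigma^2 exp(-|tau| / tau_DRW). That has the same form as a single real exponential term a exp(-c |tau|) with a = sigma^2 and c = 1 / tau_DRW, which in celerite2 should correspond to `terms.RealTerm(a=sigma**2, c=1 / tau_drw)`. The NumPy sketch below only checks that parameter mapping; the function names and example values are illustrative. Amplitude conventions vary between papers (some quote a structure-function amplitude rather than sigma), so check which one a given reference uses.

```python
import numpy as np


def drw_covariance(tau, sigma, tau_drw):
    """Damped random walk covariance: sigma^2 * exp(-|tau| / tau_drw)."""
    return sigma**2 * np.exp(-np.abs(tau) / tau_drw)


def real_exponential_covariance(tau, a, c):
    """Single real exponential term: a * exp(-c * |tau|)."""
    return a * np.exp(-c * np.abs(tau))


sigma, tau_drw = 0.3, 50.0              # example values (e.g. mag, days)
lags = np.linspace(-200.0, 200.0, 401)  # time lags, including tau = 0

k_drw = drw_covariance(lags, sigma, tau_drw)
k_real = real_exponential_covariance(lags, a=sigma**2, c=1.0 / tau_drw)
```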

Work out traceable JVP and transpose rules for JAX

It should be possible to write the JVP ops using existing celerite primitives. This would allow support for higher-order differentiation, perhaps without significant computational overhead.

For example, the matmul_lower JVP can be implemented as follows:

def matmul_lower_jvp(arg_values, arg_tangents):
    # Replace symbolic zero tangents with concrete zero arrays
    def make_zero(x, tx):
        return lax.zeros_like_array(x) if type(tx) is ad.Zero else tx

    t, c, U, V, Y = arg_values
    tp, cp, Up, Vp, Yp = (
        make_zero(x, tx) for x, tx in zip(arg_values, arg_tangents)
    )

    # Tangents of U and V that absorb the derivative of exp(-c t) and exp(c t)
    s = c[None, :] * tp[:, None] + cp[None, :] * t[:, None]
    Ut = -s * U + Up
    Vt = s * V + Vp

    # Product rule: differentiate with respect to Y, U and V one at a time
    Zp = matmul_lower(t, c, U, V, Yp)
    Zp += matmul_lower(t, c, Ut, V, Y)
    Zp += matmul_lower(t, c, U, Vt, Y)

    return matmul_lower_p.bind(t, c, U, V, Y), (Zp, None)

But I haven't figured out the correct transpose yet.
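The product-rule decomposition in the JVP above can be sanity-checked on small inputs against a dense reference. The sketch below is NumPy-only and uses central finite differences instead of JAX; `matmul_lower_dense` and `jvp_product_rule` are hypothetical helpers, and it assumes `matmul_lower` applies the strictly lower-triangular part of the matrix with entries sum_j U[n, j] exp(-c_j (t_n - t_m)) V[m, j]. It says nothing about the transpose rule.

```python
import numpy as np


def matmul_lower_dense(t, c, U, V, Y):
    """Dense reference: tril(K, -1) @ Y with K[n, m] = sum_j U[n, j] exp(-c_j (t_n - t_m)) V[m, j]."""
    tau = t[:, None] - t[None, :]                        # (N, N) time differences
    expo = np.exp(-c[None, None, :] * tau[:, :, None])  # (N, N, J)
    K = np.einsum("nj,nmj,mj->nm", U, expo, V)
    return np.tril(K, k=-1) @ Y


def jvp_product_rule(primals, tangents):
    """Tangent of matmul_lower_dense using the same decomposition as matmul_lower_jvp."""
    t, c, U, V, Y = primals
    tp, cp, Up, Vp, Yp = tangents
    s = c[None, :] * tp[:, None] + cp[None, :] * t[:, None]
    Ut = -s * U + Up
    Vt = s * V + Vp
    return (
        matmul_lower_dense(t, c, U, V, Yp)
        + matmul_lower_dense(t, c, Ut, V, Y)
        + matmul_lower_dense(t, c, U, Vt, Y)
    )


rng = np.random.default_rng(0)
N, J, n_rhs = 6, 2, 3
primals = (
    np.sort(rng.uniform(0.0, 1.0, N)),  # t
    rng.uniform(0.5, 1.5, J),           # c
    rng.normal(size=(N, J)),            # U
    rng.normal(size=(N, J)),            # V
    rng.normal(size=(N, n_rhs)),        # Y
)
tangents = tuple(rng.normal(size=np.shape(x)) for x in primals)

# Central finite difference along the tangent direction
eps = 1e-6
plus = matmul_lower_dense(*(x + eps * dx for x, dx in zip(primals, tangents)))
minus = matmul_lower_dense(*(x - eps * dx for x, dx in zip(primals, tangents)))
fd = (plus - minus) / (2 * eps)
analytic = jvp_product_rule(primals, tangents)
```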

Uncomment `pytensor.config.compute_test_value = "raise"` in `conftest.py`

Due to a bug in PyMC (pymc-devs/pymc#6981), setting the option mentioned in the title caused an error in test_marginal for PyMC after updating to v5 (#91).
A PR fixing this (pymc-devs/pymc#6982) was merged into PyMC, but not yet released. The code below can be uncommented when the fix is released.

# TODO: Uncomment when PyMC commit 714b4a0 makes it into a release (probably 5.9.2)
# pytensor.config.compute_test_value = "raise"

Design interface for missing data in Kronecker models

@tagordon, @ericagol: I have a proposal.

I expect that a common use case would be something like Rubin, where you will never have multiple bands observed simultaneously, so building the full matrices and then masking introduces a big overhead. So:

  • For the low-rank terms, I think that the most efficient implementation would be to allow the model to have variable kernel amplitudes. I would do this with an NxJ matrix Alpha that multiplies into U and V; for the diagonal, Alpha would be squared, summed along the second axis, and multiplied into a. I think this would be equivalent to the low-rank Kronecker model with missing data.
  • For the dense version, things will be a bit trickier and I'm not sure what the best interface is. It would be worth working through this carefully, and honestly it might be worth writing a paper. It looks to me like we might be able to come up with a pretty efficient algorithm for this, and we'd probably have a lot of users!
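One reading of the low-rank proposal, sketched for real exponential terms only: with per-observation amplitudes alpha[n, j], the covariance becomes K[n, m] = sum_j alpha[n, j] alpha[m, j] a_j exp(-c_j |t_n - t_m|). That keeps the semiseparable structure if alpha scales the rows of U and V and the diagonal becomes sum_j alpha[n, j]^2 a_j. The NumPy check below compares that construction with a dense matrix; the function names, the real-terms-only restriction, and the band-indicator choice of alpha are illustrative assumptions, not celerite2 API.

```python
import numpy as np


def dense_scaled_kernel(t, a, c, alpha):
    """K[n, m] = sum_j alpha[n, j] alpha[m, j] a[j] exp(-c[j] |t_n - t_m|)."""
    tau = np.abs(t[:, None] - t[None, :])
    return np.einsum("nj,mj,j,nmj->nm", alpha, alpha, a, np.exp(-c * tau[:, :, None]))


def semiseparable_scaled_kernel(t, a, c, alpha):
    """Same matrix assembled from alpha-scaled U, V and a modified diagonal (t must be sorted)."""
    U = alpha * a            # (N, J): rows scaled by each observation's amplitudes
    V = alpha                # (N, J)
    diag = (alpha**2) @ a    # (N,): sum_j alpha[n, j]**2 * a[j]
    tau = t[:, None] - t[None, :]
    lower = np.tril(np.einsum("nj,nmj,mj->nm", U, np.exp(-c * tau[:, :, None]), V), k=-1)
    return np.diag(diag) + lower + lower.T


rng = np.random.default_rng(1)
t = np.sort(rng.uniform(0.0, 10.0, 8))
a = np.array([1.0, 0.5])     # per-term amplitudes
c = np.array([0.3, 2.0])     # per-term decay rates
band = rng.integers(0, 3, size=t.size)  # one band observed at each time
band_amp = np.array([[1.0, 0.2], [0.7, 1.1], [0.4, 0.9]])  # hypothetical (band, term) amplitudes
alpha = band_amp[band]       # (N, J)

K_dense = dense_scaled_kernel(t, a, c, alpha)
K_semi = semiseparable_scaled_kernel(t, a, c, alpha)
```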

LinAlgError: failed to factorize or solve matrix

celerite2 version: 0.1.1.dev4+gcee9a0f
theano: 1.0.11
python: 3.6.7

Description

I keep running into a LinAlgError that I can't figure out. I was expecting to be able to sample a Gaussian process based on the exoplanet example (https://gallery.exoplanet.codes/en/latest/tutorials/quick-tess/), which I modified slightly, but I couldn't get it to run with either a Matern or an SHOTerm kernel.

What I did

Here's the smallest example I could make that reproduces the error.

import pymc3 as pm
import numpy as np
from celerite2.theano import terms, GaussianProcess

x_JD = np.array([2458899.75295227, 2458899.75992418, 2458899.76689609,
                 2458899.773868  , 2458899.78083991, 2458899.78781182,
                 2458899.79478373, 2458899.80175564, 2458899.80872755,
                 2458899.81569946, 2458899.82267136, 2458899.82964327,
                 2458899.83661518, 2458899.84358709, 2458899.850559  ,
                 2458899.85753091, 2458899.86450282])
x = x_JD - np.median(x_JD)
y = np.array([ 999.91776623,  999.70908473,  999.41758557, 1000.42187202,
              1000.33801407,  999.57517094,  999.72223547, 1000.19105569,
              1000.27274928, 1000.78978838, 1000.08780841, 1000.29321361,
              1000.17255564, 1000.40390643,  999.97807388, 1000.60678241,
               999.35199641])

with pm.Model() as model:
    
    resid = y
    
    sigma_lc = pm.Lognormal("sigma_lc", mu=-1, sigma=3)
    rho_gp = pm.Lognormal("rho_gp", mu=0, sigma=10)
    sigma_gp = pm.Lognormal("sigma_gp", mu=-1, sigma=3)
    
    kernel = terms.Matern32Term(sigma=sigma_gp,rho=rho_gp)
    
    gp = GaussianProcess(kernel, t=x, yerr=sigma_lc)
    gp.marginal("gp", observed=resid)
    pm.Deterministic("gp_pred", gp.predict(resid))
        
    trace = pm.sample(tune=3000, 
                      draws=3000, 
                      start=model.test_point, 
                      cores=2, 
                      chains=2, 
                      init="adapt_full", 
                      target_accept=0.9)
Full traceback:
/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/sampling.py:468: FutureWarning: In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  FutureWarning,
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_full...
/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/quadpotential.py:514: UserWarning: QuadPotentialFullAdapt is an experimental feature
  warnings.warn("QuadPotentialFullAdapt is an experimental feature")
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [sigma_gp, rho_gp, sigma_lc]
 1.22% [146/12000 00:03<04:47 Sampling 2 chains, 0 divergences]
---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/compile/function/types.py", line 974, in __call__
    if output_subset is None
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/gof/op.py", line 913, in rval
    r = p(n, [x[0] for x in i], o)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 120, in perform
    func(*args)
celerite2.backprop.LinAlgError: failed to factorize or solve matrix

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 138, in run
    self._start_loop()
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 192, in _start_loop
    point, stats = self._compute_point()
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 217, in _compute_point
    point, stats = self._step_method.step(self._point)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py", line 273, in step
    apoint, stats = self.astep(array)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 169, in astep
    hmc_step = self._hamiltonian_step(start, p0, step_size)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 185, in _hamiltonian_step
    divergence_info, turning = tree.extend(direction)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 276, in extend
    self.right, self.depth, floatX(np.asarray(self.step_size))
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 369, in _build_subtree
    tree2, diverging, turning = self._build_subtree(tree1.right, depth - 1, epsilon)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 369, in _build_subtree
    tree2, diverging, turning = self._build_subtree(tree1.right, depth - 1, epsilon)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 365, in _build_subtree
    tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 363, in _build_subtree
    return self._single_step(left, epsilon)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/nuts.py", line 325, in _single_step
    right = self.integrator.step(epsilon, left)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/integration.py", line 69, in step
    return self._step(epsilon, state)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/integration.py", line 102, in _step
    logp = self._logp_dlogp_func(q_new, q_new_grad)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/model.py", line 733, in __call__
    output = self._theano_function(array)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/compile/function/types.py", line 989, in __call__
    storage_map=getattr(self.fn, "storage_map", None),
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/gof/link.py", line 343, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/compile/function/types.py", line 974, in __call__
    if output_subset is None
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/gof/op.py", line 913, in rval
    r = p(n, [x[0] for x in i], o)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 120, in perform
    func(*args)
celerite2.backprop.LinAlgError: failed to factorize or solve matrix
Apply node that caused the error: _CeleriteOp{name='factor_fwd', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]})
Toposort index: 34
Inputs types: [TensorType(float64, vector), TensorType(float64, vector), TensorType(float64, vector), TensorType(float64, matrix), TensorType(float64, matrix)]
Inputs shapes: [(17,), (2,), (17,), (17, 2), (17, 2)]
Inputs strides: [(8,), (8,), (8,), (16, 8), (16, 8)]
Inputs values: ['not shown', array([5.8092922e-06, 5.8092922e-06]), 'not shown', 'not shown', 'not shown']
Outputs clients: [[_CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, _CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3), Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}(TensorConstant{(1,) of -0.5}, _CeleriteOp{name='factor_fwd', quiet=False}.0, TensorConstant{(1,) of 0.5}, Elemwise{Sqr}[(0, 0)].0), Elemwise{Composite{((-i0) / i1)}}(InplaceDimShuffle{0}.0, _CeleriteOp{name='factor_fwd', quiet=False}.0), Elemwise{TrueDiv}[(0, 0)](Elemwise{Sqr}[(0, 0)].0, _CeleriteOp{name='factor_fwd', quiet=False}.0), Elemwise{Log}[(0, 0)](_CeleriteOp{name='factor_fwd', quiet=False}.0)], [_CeleriteOp{name='solve_lower_fwd', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Join.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, TensorConstant{[[ 999.917..35199641]]}), _CeleriteOp{name='solve_lower_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Join.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, TensorConstant{[[ 999.917..35199641]]}, _CeleriteOp{name='solve_lower_fwd', quiet=False}.0, _CeleriteOp{name='solve_lower_fwd', quiet=False}.1, IncSubtensor{InplaceInc;::, int64}.0), _CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, _CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3)], [_CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, 
_CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3)]]

Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-85507f45a7d6>", line 28, in <module>
    gp = GaussianProcess(kernel, t=x, yerr=sigma_lc)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/core.py", line 210, in __init__
    self.compute(t, **kwargs)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/core.py", line 316, in compute
    self._do_compute(quiet)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/celerite2.py", line 84, in _do_compute
    self._t, self._c, self._a, self._U, self._V
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/gof/op.py", line 642, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 85, in make_node
    for spec in self.spec["outputs"] + self.spec["extra_outputs"]
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 85, in <listcomp>
    for spec in self.spec["outputs"] + self.spec["extra_outputs"]

HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.
"""

The above exception was the direct cause of the following exception:

LinAlgError                               Traceback (most recent call last)
LinAlgError: failed to factorize or solve matrix
Apply node that caused the error: _CeleriteOp{name='factor_fwd', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]})
Toposort index: 34
Inputs types: [TensorType(float64, vector), TensorType(float64, vector), TensorType(float64, vector), TensorType(float64, matrix), TensorType(float64, matrix)]
Inputs shapes: [(17,), (2,), (17,), (17, 2), (17, 2)]
Inputs strides: [(8,), (8,), (8,), (16, 8), (16, 8)]
Inputs values: ['not shown', array([5.8092922e-06, 5.8092922e-06]), 'not shown', 'not shown', 'not shown']
Outputs clients: [[_CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, _CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3), Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}(TensorConstant{(1,) of -0.5}, _CeleriteOp{name='factor_fwd', quiet=False}.0, TensorConstant{(1,) of 0.5}, Elemwise{Sqr}[(0, 0)].0), Elemwise{Composite{((-i0) / i1)}}(InplaceDimShuffle{0}.0, _CeleriteOp{name='factor_fwd', quiet=False}.0), Elemwise{TrueDiv}[(0, 0)](Elemwise{Sqr}[(0, 0)].0, _CeleriteOp{name='factor_fwd', quiet=False}.0), Elemwise{Log}[(0, 0)](_CeleriteOp{name='factor_fwd', quiet=False}.0)], [_CeleriteOp{name='solve_lower_fwd', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Join.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, TensorConstant{[[ 999.917..35199641]]}), _CeleriteOp{name='solve_lower_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Join.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, TensorConstant{[[ 999.917..35199641]]}, _CeleriteOp{name='solve_lower_fwd', quiet=False}.0, _CeleriteOp{name='solve_lower_fwd', quiet=False}.1, IncSubtensor{InplaceInc;::, int64}.0), _CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, _CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3)], [_CeleriteOp{name='factor_rev', quiet=False}(TensorConstant{[-0.055775...05577527]}, Join.0, Alloc.0, Join.0, TensorConstant{[[ 9.99999..2674e-04]]}, 
_CeleriteOp{name='factor_fwd', quiet=False}.0, _CeleriteOp{name='factor_fwd', quiet=False}.1, _CeleriteOp{name='factor_fwd', quiet=False}.2, Elemwise{Composite{((i0 / i1) + ((i2 * i3) / sqr(i1)))}}.0, _CeleriteOp{name='solve_lower_rev', quiet=False}.3)]]

Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-85507f45a7d6>", line 28, in <module>
    gp = GaussianProcess(kernel, t=x, yerr=sigma_lc)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/core.py", line 210, in __init__
    self.compute(t, **kwargs)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/core.py", line 316, in compute
    self._do_compute(quiet)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/celerite2.py", line 84, in _do_compute
    self._t, self._c, self._a, self._U, self._V
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/theano/gof/op.py", line 642, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 85, in make_node
    for spec in self.spec["outputs"] + self.spec["extra_outputs"]
  File "/Users/eas342/anaconda/envs/py36/lib/python3.6/site-packages/celerite2/theano/ops.py", line 85, in <listcomp>
    for spec in self.spec["outputs"] + self.spec["extra_outputs"]

HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-1-85507f45a7d6> in <module>
     36                       chains=2,
     37                       init="adapt_full",
---> 38                       target_accept=0.9)

~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    555         _print_step_hierarchy(step)
    556         try:
--> 557             trace = _mp_sample(**sample_args, **parallel_args)
    558         except pickle.PickleError:
    559             _log.warning("Could not pickle model, sampling singlethreaded.")

~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
   1474         try:
   1475             with sampler:
-> 1476                 for draw in sampler:
   1477                     trace = traces[draw.chain - chain]
   1478                     if trace.supports_sampler_stats and draw.stats is not None:

~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/parallel_sampling.py in __iter__(self)
    478 
    479         while self._active:
--> 480             draw = ProcessAdapter.recv_draw(self._active)
    481             proc, is_last, draw, tuning, stats, warns = draw
    482             self._total_draws += 1

~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
    358             else:
    359                 error = RuntimeError("Chain %s failed." % proc.chain)
--> 360             raise error from old_error
    361         elif msg[0] == "writing_done":
    362             proc._readable = True

RuntimeError: Chain 0 failed.

Help with JAX, PyTorch, and TensorFlow bindings

I don't really use these libraries (yet!) so I'm not sure exactly what the use cases would look like. Any help would be appreciated.

  • JAX: This interface is probably the most mature, but I'd love to add XLA primitives (#1). I think this would involve working out a traceable forward-mode differentiation (JVP) op. More trouble than it's worth? Maybe not.
  • PyTorch: The ops all work and are tested, but I don't really understand how nn.Modules work, so it's probably all a mess.
  • TensorFlow: Lives in this repo. It doesn't follow the recommended build process, but I think I know where to go with this one. In the long run, maybe it can be combined with JAX if we work out the XLA primitive stuff.

GaussianProcess docstring would benefit from enhancements

There is an inconsistency between the celerite2 Getting Started tutorial and the Exoplanet tutorials in the use of GaussianProcess and compute. Specifically, in the celerite2 tutorial compute is called explicitly but in the Exoplanet tutorials that include celerite2, compute is implicitly called through supplying the t keyword.

I don't think this is a particular problem, but it would be good if the main GaussianProcess docstring stated that if t is supplied then the compute method is automatically run.

In addition, the theano.GaussianProcess class doesn't have a docstring, so it would benefit from having the same docstring as the main GaussianProcess class, including the note about compute.

jax.jit decorator breaks example

Hi @dfm,

Quick question: it seems like you can't jit compile the example function numpyro_model in this tutorial because of the way celerite2 is written. I'm wondering if this is intentional or if I'm misunderstanding something about numpyro models. Should numpyro_model be jit-able? Thanks!

Very long traceback below:

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-5-8f42ec9b0711> in <module>
     69 rng_key = random.PRNGKey(34923)
---> 70 mcmc.run(rng_key, t, yerr, y=y)

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    501             if self.chain_method == 'sequential':
--> 502                 states, last_state = _laxmap(partial_map_fn, map_args)
    503             elif self.chain_method == 'parallel':

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
    158         x = jit(_get_value_from_index)(xs, i)
--> 159         ys.append(f(x))
    160 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    342         collection_size = collection_size if collection_size is None else collection_size // self.thinning
--> 343         collect_vals = fori_collect(lower_idx,
    344                                     upper_idx,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    310                 for i in t:
--> 311                     vals = jit(_body_fn)(i, vals)
    312                     t.set_description(progbar_desc(i), refresh=False)

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in _body_fn(i, vals)
    283         val, collection, start_idx, thinning = vals
--> 284         val = body_fun(val)
    285         idx = (i - start_idx) // thinning

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _sample_fn_nojit_args(state, sampler, args, kwargs)
    170     # state is a tuple of size 1 - containing HMCState
--> 171     return sampler.sample(state[0], args, kwargs),
    172 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in sample(self, state, model_args, model_kwargs)
    529         """
--> 530         return self._sample_fn(state, model_args, model_kwargs)
    531 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in sample_kernel(hmc_state, model_args, model_kwargs)
    322         vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
--> 323         vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size,
    324                                                                     hmc_state.adapt_state.inverse_mass_matrix,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key, trajectory_length)
    291 
--> 292         binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
    293                                  inverse_mass_matrix, step_size, rng_key,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy, max_tree_depth)
    748     state = (tree, rng_key)
--> 749     tree, _ = while_loop(_cond_fn, _body_fn, state)
    750     return tree

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
    122     else:
--> 123         return lax.while_loop(cond_fun, body_fun, init_val)
    124 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
    742         going_right = random.bernoulli(direction_key)
--> 743         tree = _double_tree(tree, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
    744                             going_right, doubling_key, energy_current, max_delta_energy,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
    600 
--> 601     new_tree = _iterative_build_subtree(current_tree, vv_update, kinetic_fn,
    602                                         inverse_mass_matrix, step_size,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
    687 
--> 688     tree, turning, _, _, _ = while_loop(
    689         _cond_fn,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
    122     else:
--> 123         return lax.while_loop(cond_fun, body_fun, init_val)
    124 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
    659         z, r, z_grad = _get_leaf(current_tree, going_right)
--> 660         new_leaf = _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size,
    661                                    going_right, energy_current, max_delta_energy)

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right, energy_current, max_delta_energy)
    568     step_size = jnp.where(going_right, step_size, -step_size)
--> 569     z_new, r_new, potential_energy_new, z_new_grad = vv_update(
    570         step_size,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in update_fn(step_size, inverse_mass_matrix, state)
    229         z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad)  # z(n+1)
--> 230         potential_energy, z_grad = _value_and_grad(potential_fn, z, forward_mode_differentiation)
    231         r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad)  # r(n+1)

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _value_and_grad(f, x, forward_mode_differentiation)
    183     else:
--> 184         return value_and_grad(f)(x)
    185 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    162     # no param is needed for log_density computation because we already substitute
--> 163     log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
    164     return - log_joint

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
     58             else:
---> 59                 log_prob = site['fn'].log_prob(value)
     60 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/distributions/util.py in wrapper(self, *args, **kwargs)
    554     def wrapper(self, *args, **kwargs):
--> 555         log_prob = log_prob_fn(self, *args, *kwargs)
    556         if self._validate_args:

//anaconda3/envs/pymc/lib/python3.8/site-packages/celerite2/jax/distribution.py in log_prob(self, value)
     33         def log_prob(self, value):
---> 34             return self.gp.log_likelihood(value)
     35 

//anaconda3/envs/pymc/lib/python3.8/site-packages/celerite2/core.py in log_likelihood(self, y, inplace)
    426         y = self._process_input(y, require_vector=True, inplace=inplace)
--> 427         return self._norm - 0.5 * self._do_norm(y - self._mean_value)
    428 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
   5255       return NotImplemented
-> 5256     return binary_op(self, other)
   5257   return deferring_binary_op

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in <lambda>(x1, x2)
    386   else:
--> 387     fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
    388   if lax_doc:

FilteredStackTrace: jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: Incompatible sublevel: DynamicJaxprTrace(level=2/2), (4, 1).
The tracer that caused this error was created on line //anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__).
When the tracer was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/handlers.py:162 (get_trace)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
The function being traced when the tracer leaked was _body_fn at //anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py:655.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

UnexpectedTracerError                     Traceback (most recent call last)
<ipython-input-5-8f42ec9b0711> in <module>
     68 mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=2)
     69 rng_key = random.PRNGKey(34923)
---> 70 mcmc.run(rng_key, t, yerr, y=y)

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    500         else:
    501             if self.chain_method == 'sequential':
--> 502                 states, last_state = _laxmap(partial_map_fn, map_args)
    503             elif self.chain_method == 'parallel':
    504                 states, last_state = pmap(partial_map_fn)(map_args)

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
    157     for i in range(n):
    158         x = jit(_get_value_from_index)(xs, i)
--> 159         ys.append(f(x))
    160 
    161     return tree_multimap(lambda *args: jnp.stack(args), *ys)

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    341         collection_size = self._collection_params["collection_size"]
    342         collection_size = collection_size if collection_size is None else collection_size // self.thinning
--> 343         collect_vals = fori_collect(lower_idx,
    344                                     upper_idx,
    345                                     sample_fn,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    309             with tqdm.trange(upper) as t:
    310                 for i in t:
--> 311                     vals = jit(_body_fn)(i, vals)
    312                     t.set_description(progbar_desc(i), refresh=False)
    313                     if diagnostics_fn:

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    137   def reraise_with_filtered_traceback(*args, **kwargs):
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:
    141       if not is_under_reraiser(e):

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    414       return cpp_jitted_f(*args, **kwargs)
    415     else:
--> 416       return cpp_jitted_f(context, *args, **kwargs)
    417   f_jitted._cpp_jitted_f = cpp_jitted_f
    418 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/api.py in cache_miss(_, *args, **kwargs)
    295       _check_arg(arg)
    296     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 297     out_flat = xla.xla_call(
    298         flat_fun,
    299         *args_flat,

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1392 
   1393   def bind(self, fun, *args, **params):
-> 1394     return call_bind(self, fun, *args, **params)
   1395 
   1396   def process(self, trace, fun, tracers, params):

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1383   tracers = map(top_trace.full_raise, args)
   1384   with maybe_new_sublevel(top_trace):
-> 1385     outs = primitive.process(top_trace, fun, tracers, params)
   1386   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1387 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1395 
   1396   def process(self, trace, fun, tracers, params):
-> 1397     return trace.process_call(self, fun, tracers, params)
   1398 
   1399   def post_process(self, trace, out_tracers, params):

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    623 
    624   def process_call(self, primitive, f, tracers, params):
--> 625     return primitive.impl(f, *tracers, **params)
    626   process_map = process_call
    627 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    584 
    585 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 586   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    587                                *unsafe_map(arg_spec, args))
    588   try:

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    258       fun.populate_stores(stores)
    259     else:
--> 260       ans = call(fun, *args)
    261       cache[key] = (ans, fun.stores)
    262 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    660   abstract_args, arg_devices = unzip2(arg_specs)
    661   if config.omnistaging_enabled:
--> 662     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
    663     if any(isinstance(c, core.Tracer) for c in consts):
    664       raise core.UnexpectedTracerError("Encountered an unexpected tracer.")

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
   1218     main.source_info = fun_sourceinfo(fun.f)  # type: ignore
   1219     main.jaxpr_stack = ()  # type: ignore
-> 1220     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1221     del fun, main
   1222   return jaxpr, out_avals, consts

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1198     trace = DynamicJaxprTrace(main, core.cur_sublevel())
   1199     in_tracers = map(trace.new_arg, in_avals)
-> 1200     ans = fun.call_wrapped(*in_tracers)
   1201     out_tracers = map(trace.full_raise, ans)
   1202     jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    164 
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:
    168       # Some transformations yield from inside context managers, so we have to

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in _body_fn(i, vals)
    282     def _body_fn(i, vals):
    283         val, collection, start_idx, thinning = vals
--> 284         val = body_fun(val)
    285         idx = (i - start_idx) // thinning
    286         collection = cond(idx >= 0,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _sample_fn_nojit_args(state, sampler, args, kwargs)
    169 def _sample_fn_nojit_args(state, sampler, args, kwargs):
    170     # state is a tuple of size 1 - containing HMCState
--> 171     return sampler.sample(state[0], args, kwargs),
    172 
    173 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in sample(self, state, model_args, model_kwargs)
    528         :return: Next `state` after running HMC.
    529         """
--> 530         return self._sample_fn(state, model_args, model_kwargs)
    531 
    532 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in sample_kernel(hmc_state, model_args, model_kwargs)
    321             if hmc_state.r is None else hmc_state.r
    322         vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
--> 323         vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size,
    324                                                                     hmc_state.adapt_state.inverse_mass_matrix,
    325                                                                     vv_state,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc.py in _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key, trajectory_length)
    290             _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
    291 
--> 292         binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
    293                                  inverse_mass_matrix, step_size, rng_key,
    294                                  max_delta_energy=max_delta_energy,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy, max_tree_depth)
    747 
    748     state = (tree, rng_key)
--> 749     tree, _ = while_loop(_cond_fn, _body_fn, state)
    750     return tree
    751 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
    121         return val
    122     else:
--> 123         return lax.while_loop(cond_fun, body_fun, init_val)
    124 
    125 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    137   def reraise_with_filtered_traceback(*args, **kwargs):
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:
    141       if not is_under_reraiser(e):

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in while_loop(cond_fun, body_fun, init_val)
    282   # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
    283   # necessary, a second time with modified init values.
--> 284   init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
    285   new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
    286   if changed:

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _create_jaxpr(init_val)
    268     init_avals = tuple(_map(_abstractify, init_vals))
    269     cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
--> 270     body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
    271     if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
    272       msg = "cond_fun must return a boolean scalar, but got pytree {}."

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in wrapper(*args, **kwargs)
    196         return f(*args, **kwargs)
    197       else:
--> 198         return cached(bool(config.x64_enabled), *args, **kwargs)
    199 
    200     wrapper.cache_clear = cached.cache_clear

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in cached(_, *args, **kwargs)
    189     @functools.lru_cache(max_size)
    190     def cached(_, *args, **kwargs):
--> 191       return f(*args, **kwargs)
    192 
    193     @functools.wraps(f)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _initial_style_jaxpr(fun, in_tree, in_avals)
     71 @cache()
     72 def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
---> 73   jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
     74   closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
     75   return closed_jaxpr, consts, out_tree

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in wrapper(*args, **kwargs)
    196         return f(*args, **kwargs)
    197       else:
--> 198         return cached(bool(config.x64_enabled), *args, **kwargs)
    199 
    200     wrapper.cache_clear = cached.cache_clear

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in cached(_, *args, **kwargs)
    189     @functools.lru_cache(max_size)
    190     def cached(_, *args, **kwargs):
--> 191       return f(*args, **kwargs)
    192 
    193     @functools.wraps(f)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _initial_style_open_jaxpr(fun, in_tree, in_avals)
     66 def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
     67   wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
---> 68   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
     69   return jaxpr, consts, out_tree()
     70 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_dynamic(fun, in_avals)
   1188     main.source_info = fun_sourceinfo(fun.f)  # type: ignore
   1189     main.jaxpr_stack = ()  # type: ignore
-> 1190     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1191     del main, fun
   1192   return jaxpr, out_avals, consts

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1198     trace = DynamicJaxprTrace(main, core.cur_sublevel())
   1199     in_tracers = map(trace.new_arg, in_avals)
-> 1200     ans = fun.call_wrapped(*in_tracers)
   1201     out_tracers = map(trace.full_raise, ans)
   1202     jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    164 
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:
    168       # Some transformations yield from inside context managers, so we have to

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
    741         key, direction_key, doubling_key = random.split(key, 3)
    742         going_right = random.bernoulli(direction_key)
--> 743         tree = _double_tree(tree, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
    744                             going_right, doubling_key, energy_current, max_delta_energy,
    745                             r_ckpts, r_sum_ckpts)

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
    599     key, transition_key = random.split(rng_key)
    600 
--> 601     new_tree = _iterative_build_subtree(current_tree, vv_update, kinetic_fn,
    602                                         inverse_mass_matrix, step_size,
    603                                         going_right, key, energy_current, max_delta_energy,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
    686     basetree = prototype_tree._replace(num_proposals=0)
    687 
--> 688     tree, turning, _, _, _ = while_loop(
    689         _cond_fn,
    690         _body_fn,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
    121         return val
    122     else:
--> 123         return lax.while_loop(cond_fun, body_fun, init_val)
    124 
    125 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    137   def reraise_with_filtered_traceback(*args, **kwargs):
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:
    141       if not is_under_reraiser(e):

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in while_loop(cond_fun, body_fun, init_val)
    282   # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
    283   # necessary, a second time with modified init values.
--> 284   init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
    285   new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
    286   if changed:

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _create_jaxpr(init_val)
    268     init_avals = tuple(_map(_abstractify, init_vals))
    269     cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
--> 270     body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
    271     if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
    272       msg = "cond_fun must return a boolean scalar, but got pytree {}."

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in wrapper(*args, **kwargs)
    196         return f(*args, **kwargs)
    197       else:
--> 198         return cached(bool(config.x64_enabled), *args, **kwargs)
    199 
    200     wrapper.cache_clear = cached.cache_clear

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in cached(_, *args, **kwargs)
    189     @functools.lru_cache(max_size)
    190     def cached(_, *args, **kwargs):
--> 191       return f(*args, **kwargs)
    192 
    193     @functools.wraps(f)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _initial_style_jaxpr(fun, in_tree, in_avals)
     71 @cache()
     72 def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
---> 73   jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
     74   closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
     75   return closed_jaxpr, consts, out_tree

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in wrapper(*args, **kwargs)
    196         return f(*args, **kwargs)
    197       else:
--> 198         return cached(bool(config.x64_enabled), *args, **kwargs)
    199 
    200     wrapper.cache_clear = cached.cache_clear

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in cached(_, *args, **kwargs)
    189     @functools.lru_cache(max_size)
    190     def cached(_, *args, **kwargs):
--> 191       return f(*args, **kwargs)
    192 
    193     @functools.wraps(f)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/control_flow.py in _initial_style_open_jaxpr(fun, in_tree, in_avals)
     66 def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
     67   wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
---> 68   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
     69   return jaxpr, consts, out_tree()
     70 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_dynamic(fun, in_avals)
   1188     main.source_info = fun_sourceinfo(fun.f)  # type: ignore
   1189     main.jaxpr_stack = ()  # type: ignore
-> 1190     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1191     del main, fun
   1192   return jaxpr, out_avals, consts

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1198     trace = DynamicJaxprTrace(main, core.cur_sublevel())
   1199     in_tracers = map(trace.new_arg, in_avals)
-> 1200     ans = fun.call_wrapped(*in_tracers)
   1201     out_tracers = map(trace.full_raise, ans)
   1202     jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    164 
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:
    168       # Some transformations yield from inside context managers, so we have to

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
    658         # If we are going to the right, start from the right leaf of the current tree.
    659         z, r, z_grad = _get_leaf(current_tree, going_right)
--> 660         new_leaf = _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size,
    661                                    going_right, energy_current, max_delta_energy)
    662         new_tree = cond(current_tree.num_proposals == 0,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right, energy_current, max_delta_energy)
    567                     energy_current, max_delta_energy):
    568     step_size = jnp.where(going_right, step_size, -step_size)
--> 569     z_new, r_new, potential_energy_new, z_new_grad = vv_update(
    570         step_size,
    571         inverse_mass_matrix,

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in update_fn(step_size, inverse_mass_matrix, state)
    228         r_grad = _kinetic_grad(kinetic_fn, inverse_mass_matrix, r)
    229         z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad)  # z(n+1)
--> 230         potential_energy, z_grad = _value_and_grad(potential_fn, z, forward_mode_differentiation)
    231         r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad)  # r(n+1)
    232         return IntegratorState(z, r, potential_energy, z_grad)

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py in _value_and_grad(f, x, forward_mode_differentiation)
    182         return f(x), jacfwd(f)(x)
    183     else:
--> 184         return value_and_grad(f)(x)
    185 
    186 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    137   def reraise_with_filtered_traceback(*args, **kwargs):
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:
    141       if not is_under_reraiser(e):

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    821     tree_map(partial(_check_input_dtype_grad, holomorphic, allow_int), dyn_args)
    822     if not has_aux:
--> 823       ans, vjp_py = _vjp(f_partial, *dyn_args)
    824     else:
    825       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/api.py in _vjp(fun, has_aux, *primals)
   1894   if not has_aux:
   1895     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1896     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1897     out_tree = out_tree()
   1898   else:

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    112 def vjp(traceable, primals, has_aux=False):
    113   if not has_aux:
--> 114     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    115   else:
    116     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     99   _, in_tree = tree_flatten(((primals, primals), {}))
    100   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
    102   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
    103   assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
    504   with core.new_main(JaxprTrace) as main:
    505     fun = trace_to_subjaxpr(fun, main, instantiate)
--> 506     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    507     assert not env
    508     del main, fun, env

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    164 
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:
    168       # Some transformations yield from inside context managers, so we have to

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    161     substituted_model = substitute(model, substitute_fn=partial(_unconstrain_reparam, params))
    162     # no param is needed for log_density computation because we already substitute
--> 163     log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
    164     return - log_joint
    165 

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
     57                 log_prob = site['fn'].log_prob(value, intermediates)
     58             else:
---> 59                 log_prob = site['fn'].log_prob(value)
     60 
     61             if (scale is not None) and (not is_identically_one(scale)):

//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/distributions/util.py in wrapper(self, *args, **kwargs)
    553 def validate_sample(log_prob_fn):
    554     def wrapper(self, *args, **kwargs):
--> 555         log_prob = log_prob_fn(self, *args, *kwargs)
    556         if self._validate_args:
    557             value = kwargs['value'] if 'value' in kwargs else args[0]

//anaconda3/envs/pymc/lib/python3.8/site-packages/celerite2/jax/distribution.py in log_prob(self, value)
     32         @dist.util.validate_sample
     33         def log_prob(self, value):
---> 34             return self.gp.log_likelihood(value)
     35 
     36         def sample(self, key, sample_shape=()):

//anaconda3/envs/pymc/lib/python3.8/site-packages/celerite2/core.py in log_likelihood(self, y, inplace)
    425         """
    426         y = self._process_input(y, require_vector=True, inplace=inplace)
--> 427         return self._norm - 0.5 * self._do_norm(y - self._mean_value)
    428 
    429     def predict(

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in __sub__(self, other)
    521   def __add__(self, other): return self.aval._add(self, other)
    522   def __radd__(self, other): return self.aval._radd(self, other)
--> 523   def __sub__(self, other): return self.aval._sub(self, other)
    524   def __rsub__(self, other): return self.aval._rsub(self, other)
    525   def __mul__(self, other): return self.aval._mul(self, other)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
   5254     if not isinstance(other, _scalar_types + _arraylike_types + (core.Tracer,)):
   5255       return NotImplemented
-> 5256     return binary_op(self, other)
   5257   return deferring_binary_op
   5258 

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in <lambda>(x1, x2)
    385     fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
    386   else:
--> 387     fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
    388   if lax_doc:
    389     doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/lax/lax.py in sub(x, y)
    343 def sub(x: Array, y: Array) -> Array:
    344   r"""Elementwise subtraction: :math:`x - y`."""
--> 345   return sub_p.bind(x, y)
    346 
    347 def mul(x: Array, y: Array) -> Array:

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **params)
    281                               or valid_jaxtype(arg) for arg in args), args
    282     top_trace = find_top_trace(args)
--> 283     tracers = map(top_trace.full_raise, args)
    284     out = top_trace.process_primitive(self, tracers, params)
    285     return map(full_lower, out) if self.multiple_results else full_lower(out)

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/_src/util.py in safe_map(f, *args)
     39   for arg in args[1:]:
     40     assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
---> 41   return list(map(f, *args))
     42 
     43 def unzip2(xys):

//anaconda3/envs/pymc/lib/python3.8/site-packages/jax/core.py in full_raise(self, val)
    402     elif val._trace.level < level:
    403       if val._trace.sublevel > sublevel:
--> 404         raise escaped_tracer_error(
    405             val, f"Incompatible sublevel: {val._trace}, {(level, sublevel)}")
    406       return self.lift(val)

UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: Incompatible sublevel: DynamicJaxprTrace(level=2/2), (4, 1).
The tracer that caused this error was created on line //anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__).
When the tracer was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/handlers.py:162 (get_trace)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
//anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/primitives.py:80 (__call__)
The function being traced when the tracer leaked was _body_fn at //anaconda3/envs/pymc/lib/python3.8/site-packages/numpyro/infer/hmc_util.py:655.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
```
</p>
</details>

Installing celerite2 with conda

Hi there! I ran into an issue on macOS when trying to run celerite2 in my Jupyter notebook.

Here's the line:

from celerite2.theano import terms, GaussianProcess

And the issue was:

The kernel appears to have died. It will restart automatically.

Restarting my computer and the notebook did not fix it. Searching online, I found a tip suggesting that installing the environment with conda might help (I use Anaconda, but I installed celerite2 with pip). However, I couldn't find conda installation instructions in the tutorial. Would it be possible to add them?

Thanks a lot!
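"The kernel appears to have died" means the Python process behind the notebook exited abruptly, for example because compiled code crashed, so the underlying error is never shown. As a hedged debugging sketch (the helper `run_in_fresh_interpreter` is hypothetical and not part of celerite2), the failing import can be run in a separate interpreter, which keeps the notebook kernel alive and reveals whether the import raises an ordinary Python exception or crashes outright.

```python
import subprocess
import sys


def run_in_fresh_interpreter(statement: str) -> int:
    """Run `statement` in a separate Python process and return its exit code.

    0 means success, a positive code usually means an uncaught Python
    exception, and (on POSIX) a negative code -N means the process was
    killed by signal N, e.g. -11 for a segmentation fault.
    """
    result = subprocess.run(
        [sys.executable, "-c", statement],  # same Python as the current kernel
        capture_output=True,
        text=True,
    )
    # Echo whatever the child wrote to stderr so any traceback is visible.
    if result.stderr:
        print(result.stderr, file=sys.stderr)
    return result.returncode


if __name__ == "__main__":
    code = run_in_fresh_interpreter(
        "from celerite2.theano import terms, GaussianProcess"
    )
    print("exit code:", code)
```

A negative exit code would point to a crash in a compiled dependency rather than a missing package, which is useful detail to include in a report like this one.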
