
Comments (13)

fehiepsi commented on August 17, 2024

Hi @nikisix and @kylejcaron, really sorry for the breakage! I think a good fix is to add an exclude_deterministic=True argument to Predictive, which would roll the behavior back to what it was before the 0.14 release. I'm not too worried that new users will want to use deterministic sites in Predictive. What do you think, @martinjankowiak?

from numpyro.

fehiepsi commented on August 17, 2024

Sorry for the breakage! Could you try using the dev branch of lightweight_mmm? If it works, I will ping a dev there to make a release.


AkiroSR commented on August 17, 2024

I think it's related to numpyro: the problematic function is numpyro.deterministic, and everything else works.
I'll have a look, but I reckon it's related to the Meridian release.


fehiepsi commented on August 17, 2024

Do you mean that pip install --upgrade git+https://github.com/google/lightweight_mmm.git does not resolve the issue?


nikisix commented on August 17, 2024

@fehiepsi I saw your fix in lightweight_mmm (Change-Id: I7c0658b0a13506c319fd3e6e00cdf2791d64e26f).

I believe the long-term fix here is twofold:

  1. Return deterministic sites in posterior_samples (MCMC stores deterministic sites alongside its samples, and they are accessed via mcmc.get_samples()).
  2. Have Predictive always pop deterministic sites.

If these are infeasible for deeper reasons, then at least document the pop trick here, since the current behavior is a bit counterintuitive: https://num.pyro.ai/en/v0.2.0/utilities.html
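The pop trick can be generalized so it does not rely on hard-coding a site name like "eta". Below is a minimal sketch, assuming the trace is the dict returned by numpyro's `trace(...).get_trace(...)`, which maps each site name to a site dict carrying a `"type"` key. In numpyro it would come from something like `numpyro.handlers.trace(numpyro.handlers.seed(model, rng_seed=0)).get_trace(X, y)`. The helper name `drop_deterministic_sites` and the toy values are hypothetical.

```python
def drop_deterministic_sites(posterior_samples, model_trace):
    """Return a copy of posterior_samples with all deterministic sites removed.

    Assumption: model_trace maps site names to site dicts that have a "type" key,
    as numpyro's trace(...).get_trace(...) does.
    """
    # Collect the names of every site recorded as deterministic in the trace.
    deterministic_names = {
        name
        for name, site in model_trace.items()
        if site.get("type") == "deterministic"
    }
    # Build a new dict so the caller's posterior_samples is left untouched.
    return {
        name: value
        for name, value in posterior_samples.items()
        if name not in deterministic_names
    }


# Toy stand-ins (hypothetical values) shaped like the model in this thread.
toy_trace = {
    "alpha": {"type": "sample"},
    "beta": {"type": "sample"},
    "sigma": {"type": "sample"},
    "data": {"type": "plate"},
    "eta": {"type": "deterministic"},
    "obs": {"type": "sample"},
}
toy_samples = {
    "alpha": [5.0, 5.1],
    "beta": [1.2, 1.1],
    "sigma": [1.0, 0.9],
    "eta": [[6.2, 3.8], [6.1, 4.0]],
}

cleaned = drop_deterministic_sites(toy_samples, toy_trace)
print(sorted(cleaned))  # ['alpha', 'beta', 'sigma']
```

Unlike a single hard-coded pop, this keeps working if more deterministic sites are added to the model later.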


kylejcaron commented on August 17, 2024

I'm running into the same issue, here's a reproducible example:

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from jax import random


X = np.random.normal(0, 1, size=1000)
y = 5 + 1.2*X + np.random.normal(size=1000)

def model(X, y=None):
    alpha = numpyro.sample("alpha", dist.Normal(0, 10))
    beta = numpyro.sample("beta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    with numpyro.plate("data", len(X)):
        eta = numpyro.deterministic("eta", alpha + beta*X)
        obs = numpyro.sample("obs", dist.Normal(eta, sigma), obs=y)

# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), X=X, y=y)

# Make predictions where X is a different shape
posterior_samples = mcmc.get_samples()
# posterior_samples.pop("eta")  # this fixes the issue
pred_func = Predictive(model, posterior_samples=posterior_samples)
preds = pred_func(random.PRNGKey(1), X=X[:200], y=None)  # raises the ValueError below
Traceback:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
    [... skipping hidden 1 frame]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/util.py:290, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    289 else:
--> 290   return cached(config.trace_context(), *args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/util.py:283, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    281 @functools.lru_cache(max_size)
    282 def cached(_, *args, **kwargs):
--> 283   return f(*args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/lax/lax.py:155, in _broadcast_shapes_cached(*shapes)
    153 @cache()
    154 def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
--> 155   return _broadcast_shapes_uncached(*shapes)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
    170 if result_shape is None:
--> 171   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    172 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(200,), (1000,)]

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[1], line 26
     24 # Make predictions where X is a different shape
     25 pred_func = Predictive(model, posterior_samples=mcmc.get_samples())
---> 26 preds = pred_func(random.PRNGKey(1), X=X[:200], y=None)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:1011, in Predictive.__call__(self, rng_key, *args, **kwargs)
   1001 """
   1002 Returns dict of samples from the predictive distribution. By default, only sample sites not
   1003 contained in `posterior_samples` are returned. This can be modified by changing the
   (...)
   1008 :param kwargs: model kwargs.
   1009 """
   1010 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1011     return self._call_with_params(rng_key, self.params, args, kwargs)
   1012 elif self.batch_ndims == 1:  # batch over parameters
   1013     batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:988, in Predictive._call_with_params(self, rng_key, params, args, kwargs)
    977     posterior_samples = _predictive(
    978         guide_rng_key,
    979         guide,
   (...)
    985         model_kwargs=kwargs,
    986     )
    987 model = substitute(self.model, self.params)
--> 988 return _predictive(
    989     rng_key,
    990     model,
    991     posterior_samples,
    992     self._batch_shape,
    993     return_sites=self.return_sites,
    994     infer_discrete=self.infer_discrete,
    995     parallel=self.parallel,
    996     model_args=args,
    997     model_kwargs=kwargs,
    998 )

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:825, in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)
    823 rng_key = rng_key.reshape(batch_shape + key_shape)
    824 chunk_size = num_samples if parallel else 1
--> 825 return soft_vmap(
    826     single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
    827 )

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/util.py:419, in soft_vmap(fn, xs, batch_ndims, chunk_size)
    413     xs = tree_map(
    414         lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
    415         xs,
    416     )
    417     fn = vmap(fn)
--> 419 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    420 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    421 ys = tree_map(
    422     lambda y: jnp.reshape(
    423         y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
    424     )[:batch_size],
    425     ys,
    426 )

    [... skipping hidden 12 frame]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:798, in _predictive.<locals>.single_prediction(val)
    789     pred_samples = _sample_posterior(
    790         config_enumerate(condition(model, samples)),
    791         first_available_dim,
   (...)
    795         **model_kwargs,
    796     )
    797 else:
--> 798     model_trace = trace(
    799         seed(substitute(masked_model, samples), rng_key)
    800     ).get_trace(*model_args, **model_kwargs)
    801     pred_samples = {name: site["value"] for name, site in model_trace.items()}
    803 if return_sites is not None:

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (2 times)]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[1], line 17, in model(X, y)
     15 with numpyro.plate("data", len(X)):
     16     eta = numpyro.deterministic("eta", alpha + beta*X)
---> 17     obs = numpyro.sample("obs", dist.Normal(eta, sigma), obs=y)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    207 initial_msg = {
    208     "type": "sample",
    209     "name": name,
   (...)
    218     "infer": {} if infer is None else infer,
    219 }
    221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
    223 return msg["value"]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
     45 pointer = 0
     46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47     handler.process_message(msg)
     48     # When a Messenger sets the "stop" field of a message,
     49     # it prevents any Messengers above it on the stack from being applied.
     50     if msg.get("stop"):

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:546, in plate.process_message(self, msg)
    544 overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
    545 trailing_shape = expected_shape[overlap_idx:]
--> 546 broadcast_shape = lax.broadcast_shapes(
    547     trailing_shape, tuple(dist_batch_shape)
    548 )
    549 batch_shape = expected_shape[:overlap_idx] + broadcast_shape
    550 msg["fn"] = msg["fn"].expand(batch_shape)

    [... skipping hidden 1 frame]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
    169 result_shape = _try_broadcast_shapes(shape_list)
    170 if result_shape is None:
--> 171   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    172 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(200,), (1000,)]

I get that substituting posterior samples for a deterministic site makes the model expect that site's training-time shape, but it is awkward that the typical prediction workflow requires extra work whenever deterministic sites are involved.
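The error message can be reproduced with plain NumPy, whose broadcasting rules should match jax.lax.broadcast_shapes for these 1-D shapes. A minimal sketch, assuming the substituted "eta" keeps its training shape (1000,) while the "data" plate expects 200 rows when predicting on X[:200]:

```python
import numpy as np

n_train, n_new = 1000, 200

# Shape "eta" had during MCMC (one value per training row); this is what gets
# substituted back in when "eta" is left in posterior_samples.
substituted_eta_shape = (n_train,)
# Shape the "data" plate expects when the model is called with X[:200].
plate_shape = (n_new,)

try:
    np.broadcast_shapes(plate_shape, substituted_eta_shape)
    broadcast_ok = True
except ValueError:
    # Same failure as "Incompatible shapes for broadcasting: shapes=[(200,), (1000,)]".
    broadcast_ok = False

# If "eta" is recomputed from X[:200] instead of substituted, the shapes agree.
recomputed_eta_shape = (n_new,)
print(broadcast_ok, np.broadcast_shapes(plate_shape, recomputed_eta_shape))  # False (200,)
```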

I wonder if something like this is possible? https://github.com/pyro-ppl/numpyro/blob/2f1bccdba2fc7b0a6ec235ca1bd5ce2417a0635c/numpyro/infer/mcmc.py#L714C61-L714C62


martinjankowiak commented on August 17, 2024

Something like that sounds reasonable. The change in behavior was probably a mistake.


kylejcaron commented on August 17, 2024

@fehiepsi @martinjankowiak should AutoGuide.sample_posterior() be changed as well? It seems harder to fix, since many autoguides implement their own sample_posterior.

For example, the following workflow has the same problem:

from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

guide = AutoNormal(model)
svi = SVI(model, guide, optim=numpyro.optim.Adam(0.01), loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10000, X=X, y=y)

params = guide.sample_posterior(random.PRNGKey(0), params=svi_result.params)
pred_func = Predictive(model, params=params, num_samples=100)
preds = pred_func(random.PRNGKey(1), X=X[:250], y=None)

The workaround seems to be to pass the guide and the SVI params to Predictive instead, but I imagine some users rely on the pattern above:

n_preds = 250
pred_func = Predictive(model, guide=guide, params=svi_result.params, num_samples=100)
preds = pred_func(random.PRNGKey(1), X[:n_preds])['eta']


kylejcaron commented on August 17, 2024

I think this pattern could be used with an exclude_deterministic arg in AutoGuide's


fehiepsi commented on August 17, 2024

@kylejcaron I think we can fix this in Predictive. The breakage happens because the substitute handler allows deterministic sites to be substituted. We can create a subclass of substitute that skips deterministic sites and use it in Predictive.


kylejcaron commented on August 17, 2024


Got it, that makes sense to me. It seems like it would just involve replacing the substitute call in this line and L987, but let me know if I'm missing anything.

I'm happy to take a stab at this. Any name recommendations for the new effect handler?


fehiepsi commented on August 17, 2024

The substitute logic is at this line. You can change

substitute(model, data)

to something like

substitute(model, substitute_fn=lambda msg: data[msg["name"]] if msg["name"] in data and msg["type"] != "deterministic" else None)
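The one-liner can be pulled into a small factory so the rule is easier to read and test on its own. A minimal sketch; `make_substitute_fn` is a hypothetical name, and it assumes (as I understand numpyro's substitute handler) that a substitute_fn returning None leaves the site's own value in place. In numpyro it would be used as `numpyro.handlers.substitute(model, substitute_fn=make_substitute_fn(samples))`.

```python
def make_substitute_fn(data):
    """Return a substitute_fn that replaces sites found in `data`, except deterministic ones.

    Hypothetical helper; assumes numpyro's substitute handler skips a site when
    substitute_fn returns None.
    """

    def substitute_fn(msg):
        # Substitute only non-deterministic sites that have a stored value.
        if msg["name"] in data and msg["type"] != "deterministic":
            return data[msg["name"]]
        # None means "do not substitute": the site keeps its own value,
        # so a deterministic site is recomputed from the new inputs.
        return None

    return substitute_fn


# Toy messages carrying only the keys this function reads.
samples = {"alpha": 5.0, "eta": [6.2] * 1000}
fn = make_substitute_fn(samples)

print(fn({"name": "alpha", "type": "sample"}))       # 5.0
print(fn({"name": "eta", "type": "deterministic"}))  # None
print(fn({"name": "obs", "type": "sample"}))         # None
```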


kylejcaron commented on August 17, 2024


Nice idea with the substitute_fn; I just added a PR!
