Comments (12)

fehiepsi avatar fehiepsi commented on August 17, 2024

Hi @amifalk, those autoguides are not designed to be composed with vmap after construction, because they need initialization (to inspect the model and generate things like prototype_trace). Something like this will work:

def init(...):
    guide = AutoDelta(...)
    return guide.init(...)

init_state = jax.vmap(init)(...)

from numpyro.

amifalk avatar amifalk commented on August 17, 2024

I think there's still an issue here. When svi.init is called, it initializes both the model and the guide for the first time, which should set up the prototype trace. Batched model fitting works with AutoGuides if I vmap the SVI methods in all cases except when there is both a deterministic site and the guide is based on a blocked model.

The suggested approach yields the same error as before (though AutoGuides are not registered as pytrees so they cannot be returned after calling vmap).

def guide_init(rng_seed):
   guide = AutoDelta(block(seed(model, rng_seed=rng_seed), hide=['b']))
   seed(guide, rng_seed=rng_seed)()
   
   return 

keys = random.split(random.PRNGKey(0))
jax.vmap(guide_init)(keys) # this works

def guide_init_deterministic(rng_seed):
   guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=rng_seed), hide=['b']))
   seed(guide, rng_seed=rng_seed)()
   
   return 

keys = random.split(random.PRNGKey(0))
jax.vmap(guide_init_deterministic)(keys) # tracer error


amifalk avatar amifalk commented on August 17, 2024

@fehiepsi I've traced the source to this while loop. If I set _DISABLE_CONTROL_FLOW_PRIM = True, vmapping the svi.init method works. However, vmapping the guide initialization yields a new error in the while loop:

This BatchTracer with object id 140305711093520 was created on line:
  /home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/infer/util.py:357:15 (find_valid_initial_params.<locals>.cond_fn)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError


fehiepsi avatar fehiepsi commented on August 17, 2024

If we use a Python while loop, the condition needs to be a Python value like True or False; having a JAX object there won't work. What is your use case, by the way?
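To illustrate the point about loop conditions, here is a minimal sketch that is independent of NumPyro. The function names `count_python_loop` and `count_lax_loop` are made up for this example. It assumes only that a Python `while` has to convert its condition to a concrete bool, which fails on a traced value, whereas `jax.lax.while_loop` keeps the condition inside the traced computation.

```python
import jax
import jax.numpy as jnp


def count_python_loop(x):
    # A Python `while` calls bool() on `x < 10.0`; under vmap, x is a tracer,
    # so the conversion to a concrete bool fails.
    steps = 0
    while x < 10.0:
        x = x + 1.0
        steps += 1
    return steps


def count_lax_loop(x):
    # Same loop expressed with lax.while_loop, so the condition stays traced.
    def cond_fn(state):
        value, _ = state
        return value < 10.0

    def body_fn(state):
        value, steps = state
        return value + 1.0, steps + 1

    _, steps = jax.lax.while_loop(cond_fn, body_fn, (x, 0))
    return steps


xs = jnp.array([0.0, 5.0, 9.5])

try:
    jax.vmap(count_python_loop)(xs)
    python_loop_failed = False
except jax.errors.ConcretizationTypeError:
    # TracerBoolConversionError is a subclass of ConcretizationTypeError.
    python_loop_failed = True

steps = jax.vmap(count_lax_loop)(xs)
```

This is why switching to the Python fallback loop only moves the error: the condition in `cond_fn` is still a batched JAX value.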


amifalk avatar amifalk commented on August 17, 2024

I have a blocked model with a deterministic site that I'm trying to perform some simulation studies on. I want to see how variations in the structure of the dataset / model hyperparameters affect the performance, and I also want to be able to select the best result over multiple initializations. It's very slow to do this sequentially (for a small grid of hyperparams it took around 40 minutes), but after vmapping/pmapping with GPU I can get the entire grid to run in parallel. In my case it reduced the fitting time to 7 seconds.

Unfortunately, if I try to vmap the blocked model with deterministic sites present, it throws this error, so I have to instead recompute the deterministic sites at the end of model fitting.

In my case, I need to block the model to define an AutoGuide that is compatible with enumeration (blocking out the enumerated sites), but this would likely also be a problem for people using AutoGuideList.


fehiepsi avatar fehiepsi commented on August 17, 2024

I think you can do something like

def run_svi(...):
    svi = ...
    svi_result = svi.run(...)
    return svi_result

svi_results = vmap(run_svi)(...)


amifalk avatar amifalk commented on August 17, 2024

Unfortunately this still seems to throw the same tracer error.

def run_svi(key):
    optimizer = numpyro.optim.Adam(step_size=.01)
    guide = AutoDelta(block(seed(model, rng_seed=0), hide=['b']))
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    return svi.run(key, 100, progress_bar=False)

def run_svi_deterministic(key):
    optimizer = numpyro.optim.Adam(step_size=.01)
    guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=0), hide=['b']))
    svi = SVI(model_w_deterministic, guide, optimizer, loss=Trace_ELBO())

    return svi.run(key, 100, progress_bar=False)

keys = random.split(random.PRNGKey(0), 2)

jax.vmap(run_svi)(keys) # works
jax.vmap(run_svi_deterministic)(keys) # tracer error from the while loop in find_valid_initial_params


amifalk avatar amifalk commented on August 17, 2024

Can we make it so that AutoGuides only collect non-enumerated model sample sites? This wouldn't fix the problem for all blocked models, but it would make collecting deterministic sites possible under batched svi for my use-case.

I think this would only have to be a one-liner change here-ish where we just ignore sample sites in the prototype trace that have site['infer'].get('enumerate') == True. That would also make the syntax for defining AutoGuides for enumerated models much simpler, e.g. just AutoGuide(model), instead of
AutoGuide(block(seed(model, rng_seed=0), hide=["enumerated_site_1", "enumerated_site_2", ...]))


fehiepsi avatar fehiepsi commented on August 17, 2024

Thanks @amifalk! There is indeed leakage here with the seed handler. I haven't been able to figure out why yet. Posting here for reference

import numpyro
import numpyro.distributions as dist
import jax

def model():
    return numpyro.sample('a', dist.Normal(0, 1))    

def run(key):
    return numpyro.infer.util.initialize_model(key, numpyro.handlers.seed(model, rng_seed=0))[0]

with jax.checking_leaks():
    jax.jit(run)(jax.random.PRNGKey(0))


amifalk avatar amifalk commented on August 17, 2024

@fehiepsi With that example, I was able to narrow down the source of the bug further - thanks! The while loop in find_valid_initial_params closes over the seeded model, but it also traces the model during its calls to potential_fn. I think the error comes from the trace seeing an RNG key that originates from the call to seed outside the loop. Here's a minimal example.

import numpyro
import numpyro.distributions as dist
import jax

def model():
    return numpyro.sample('a', dist.Normal(0, 1))    

def run(key):
    seeded = numpyro.handlers.seed(model, rng_seed=0)

    def cond_fn(state):
        i, num = state
        return i < 10 

    def body_fn(state):
        i, num = state
        
        numpyro.handlers.trace(seeded).get_trace() # this references the global rng values in a jitted context
        # equivalently num = numpyro.handlers.trace(seeded).get_trace()['a']['value'] will raise an error        
        return (i + 1, num)

    return jax.lax.while_loop(cond_fn, body_fn, (0, 0))
    
with jax.checking_leaks():
    jax.jit(run)(jax.random.PRNGKey(0))

You can also verify this by replacing potential_fn in numpyro.infer.util.find_valid_initial_params with a placeholder that just returns a constant number.


fehiepsi avatar fehiepsi commented on August 17, 2024

I think we figured it out. Thanks for the examples!

seed(model) is an instance of a seed class which has mutable state. A fix for it is to close the seeded model into a function like

def seeded_model(*args, **kwargs):
    return seed(model, rng_seed=random.PRNGKey(0))(*args, **kwargs)

This way, each time we call the model, a new instance of the seed handler will be created. Could you check if it works for your use case? I'll think of a long-term solution (maybe improving the docstring for this).


amifalk avatar amifalk commented on August 17, 2024

Yes, this fixed it! Not sure if there's any interest in adding to NumPyro, but here's the pattern for batching SVI: https://gist.github.com/amifalk/eb377a243b046105dc00beda79441b22

