

fehiepsi commented on August 17, 2024

I guess we can add a flag to control this behavior. Based on the flag, we can switch the order of operations in HMCGibbs.sample.
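The flag idea can be sketched as a toy two-block sampler step. This is not NumPyro code: `combined_step`, `gibbs_first`, `toy_gibbs` and `toy_hmc` are hypothetical names, and the two updates are stand-ins that only record the order in which they ran. The point is that the order changes what each block conditions on.

    from typing import Callable, Dict, List

    State = Dict[str, float]
    Update = Callable[[State], State]


    def combined_step(state: State, gibbs_update: Update, hmc_update: Update,
                      gibbs_first: bool = True) -> State:
        """One sweep of a two-block sampler; `gibbs_first` is a hypothetical flag."""
        if gibbs_first:
            # Order used by HMCGibbs.sample today: Gibbs update, then HMC update.
            state = gibbs_update(state)
            return hmc_update(state)
        # Reversed order: HMC update first, then Gibbs conditions on the new HMC values.
        state = hmc_update(state)
        return gibbs_update(state)


    # Toy updates that record the order in which they were called.
    calls: List[str] = []


    def toy_gibbs(state: State) -> State:
        calls.append("gibbs")
        return {**state, "x": state["y"] + 1.0}  # x is set from the current y


    def toy_hmc(state: State) -> State:
        calls.append("hmc")
        return {**state, "y": state["y"] * 2.0}

Starting from `{"x": 0.0, "y": 1.0}`, `gibbs_first=True` yields `{"x": 2.0, "y": 2.0}`, while `gibbs_first=False` yields `{"x": 3.0, "y": 2.0}`, because the Gibbs step then sees the already-updated `y`.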

from numpyro.

disadone commented on August 17, 2024

Do you think it would be easy? I'd like to try it first by modifying HMCGibbs.


fehiepsi commented on August 17, 2024

Currently, in HMCGibbs.sample, we do the Gibbs update first (your ref link above) and run the HMC sample step after that. That seems to be the behavior you want.

Could you clarify your comments here?

def gibbs_fn(rng_key, gibbs_sites, hmc_sites): # NEED run first
    y = hmc_sites['y'] # NEED: initialized first not sample from model

I guess you don't want to use hmc_sites['y'] from the previous MCMC step? If so, you can set y = something_else.
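A minimal sketch of that suggestion, written in JAX. The `(rng_key, gibbs_sites, hmc_sites)` signature and the dict-of-new-Gibbs-values return match the `gibbs_fn` quoted above. The conditional `x | y ~ Normal(y, 1)`, the Gibbs site name `"x"`, and the constant `Y_OVERRIDE` are assumptions made only for illustration.

    import jax
    import jax.numpy as jnp

    # Hypothetical value used in place of hmc_sites["y"] (e.g. a fixed initial value).
    Y_OVERRIDE = 2.0


    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        """Gibbs update for a toy conditional, assumed here: x | y ~ Normal(y, 1)."""
        del hmc_sites  # deliberately ignore hmc_sites["y"] from the previous MCMC step
        y = Y_OVERRIDE  # y = something_else
        noise = jax.random.normal(rng_key, shape=jnp.shape(gibbs_sites["x"]))
        return {"x": y + noise}

With the same `rng_key`, the result is identical no matter what `hmc_sites["y"]` holds. In NumPyro such a function would be passed along the lines of `HMCGibbs(NUTS(model), gibbs_fn=gibbs_fn, gibbs_sites=["x"])`, where `model` is your own model.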


disadone commented on August 17, 2024

Yes, I do not want to use hmc_sites['y']. I found that the value could be overridden by the init_params value passed to MCMC if I switch the order of the HMC and Gibbs steps, as shown here:

    def sample(self, state, model_args, model_kwargs):
        model_kwargs = {} if model_kwargs is None else model_kwargs
        rng_key, rng_gibbs = random.split(state.rng_key)

        def potential_fn(z_gibbs, z_hmc):
            return self.inner_kernel._potential_fn_gen(
                *model_args, _gibbs_sites=z_gibbs, **model_kwargs
            )(z_hmc)

        z_gibbs = {k: v for k, v in state.z.items() if k not in state.hmc_state.z}
        z_hmc = {k: v for k, v in state.z.items() if k in state.hmc_state.z}
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_gibbs

        # Switched: gibbs_fn is now called before postprocess_fn (originally it was called after)
        z_gibbs = self._gibbs_fn(
            rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc
        )

        z_hmc = self.inner_kernel.postprocess_fn(model_args, model_kwargs_)(z_hmc)

        if self.inner_kernel._forward_mode_differentiation:
            pe = potential_fn(z_gibbs, state.hmc_state.z)
            z_grad = jacfwd(partial(potential_fn, z_gibbs))(state.hmc_state.z)
        else:
            pe, z_grad = value_and_grad(partial(potential_fn, z_gibbs))(
                state.hmc_state.z
            )
        hmc_state = state.hmc_state._replace(z_grad=z_grad, potential_energy=pe)

        model_kwargs_["_gibbs_sites"] = z_gibbs
        hmc_state = self.inner_kernel.sample(hmc_state, model_args, model_kwargs_)

        z = {**z_gibbs, **hmc_state.z}

        return HMCGibbsState(z, hmc_state, rng_key)

I just wonder whether there are any unexpected side effects if I change the sample function like this.
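One part of the `sample` method above is worth isolating. After `gibbs_fn` returns new `z_gibbs`, the potential energy and gradient cached in `hmc_state` are stale, because the potential depends on the Gibbs sites. That is why the method recomputes them with `value_and_grad` (or `jacfwd` in forward mode) before calling `inner_kernel.sample`. The sketch below reproduces just that step in JAX with a toy potential `U = 0.5 * (y - x)**2`. The potential, the site names `"x"`/`"y"`, and the helper `refresh_energy` are assumptions for illustration. In both orderings this step evaluates at `state.hmc_state.z`, so the reordering mainly changes which `z_hmc` values `gibbs_fn` receives.

    from functools import partial

    import jax
    import jax.numpy as jnp


    def potential_fn(z_gibbs, z_hmc):
        # Toy potential (assumed): the HMC energy depends on the current Gibbs value.
        return 0.5 * (z_hmc["y"] - z_gibbs["x"]) ** 2


    def refresh_energy(z_gibbs, z_hmc, forward_mode=False):
        """Recompute (potential energy, gradient w.r.t. z_hmc) after z_gibbs changes."""
        if forward_mode:
            pe = potential_fn(z_gibbs, z_hmc)
            z_grad = jax.jacfwd(partial(potential_fn, z_gibbs))(z_hmc)
        else:
            pe, z_grad = jax.value_and_grad(partial(potential_fn, z_gibbs))(z_hmc)
        return pe, z_grad

For `z_hmc = {"y": 3.0}`, moving the Gibbs value from `x = 1.0` to `x = 4.0` changes the energy from 2.0 to 0.5 and the gradient from 2.0 to -1.0. Keeping the old values would give HMC an inconsistent starting state.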


fehiepsi commented on August 17, 2024

What did you switch? I guess I misunderstood your question. We do the Gibbs update first and the HMC update later, which seems to be what you want.


disadone commented on August 17, 2024


Sorry for the confusion. I mean the order of these statements:

    z_gibbs = self._gibbs_fn(
        rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc
    )  # switch the order of z_gibbs and z_hmc

    z_hmc = self.inner_kernel.postprocess_fn(model_args, model_kwargs_)(z_hmc)

In the original file, without modification, the order is:

    z_hmc = self.inner_kernel.postprocess_fn(model_args, model_kwargs_)(z_hmc)

    z_gibbs = self._gibbs_fn(
        rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc
    )

In the modified file, z_hmc is no longer run through postprocess_fn (and hence the model) before it is passed to self._gibbs_fn.
I added a print at the end of my custom model and found that self.inner_kernel.postprocess_fn can trigger the model and change the value of z_hmc, even though postprocess_fn seems to be meant for postprocessing rather than for triggering sampling...


fehiepsi commented on August 17, 2024

The postprocess_fn is necessary to make sure that HMC samples are in the correct domain for the gibbs_fn to condition on. In most cases, it will transform unconstrained samples into constrained samples without triggering the model. But if your model has stochastic support, it is necessary to run the model to perform the transform correctly.
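A toy illustration of why the model sometimes has to run inside postprocess_fn. This is plain Python, not NumPyro internals. The hypothetical model's constraint on `"scale"` depends on a discrete Gibbs site `"k"`: positive when `k == 0`, in (0, 1) otherwise. The correct unconstrained-to-constrained transform is therefore only known after "running the model" with the current Gibbs values. The names `run_model_for_transform`, `postprocess`, `"scale"` and `"k"` are all assumptions.

    import math
    from typing import Callable, Dict

    Sites = Dict[str, float]


    def run_model_for_transform(gibbs_sites: Sites) -> Callable[[float], float]:
        """Hypothetical model with stochastic support: the support of "scale"
        depends on the discrete Gibbs site "k"."""
        if gibbs_sites["k"] == 0:
            return math.exp                              # scale > 0
        return lambda u: 1.0 / (1.0 + math.exp(-u))      # scale in (0, 1)


    def postprocess(z_hmc_unconstrained: Sites, gibbs_sites: Sites) -> Sites:
        # Map unconstrained HMC values to the constrained domain gibbs_fn conditions on.
        transform = run_model_for_transform(gibbs_sites)
        return {name: transform(u) for name, u in z_hmc_unconstrained.items()}


    z_unconstrained = {"scale": 0.0}
    print(postprocess(z_unconstrained, {"k": 0}))  # {'scale': 1.0}
    print(postprocess(z_unconstrained, {"k": 1}))  # {'scale': 0.5}

Skipping this step, as the reordered `sample` does, would hand `gibbs_fn` the raw value 0.0. That is not even a valid `scale` when `k == 0`.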


disadone commented on August 17, 2024

Thank you, I understand the point!

