fehiepsi commented on August 16, 2024

Yeah, after delving into the autograd documentation and then searching for a similar pattern in jax.lax, I came up with the following template, which works for some simple functions. ;)

from jax.interpreters import ad
from jax.lax import standard_unop, _float

def standard_gamma(x):
    # implement forward
    return x

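def _standard_gamma_jvp_rule(g, ans, x):
    # JVP (forward-mode derivative) rule: g is the input tangent,
    # ans the primal output, x the primal input (placeholder body)
    return g * x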

standard_gamma_p = standard_unop(_float, 'standard_gamma')
ad.defjvp2(standard_gamma_p, _standard_gamma_jvp_rule)

Edit: the above is not correct; it only applies to XLA-compatible primitives. It seems that for custom primitives we need to define a jvp_rule (forward mode), a transpose_rule (reverse mode), a batch_rule (vmap), and so on. I have implemented the sampler, but the rules are still tricky for me.
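For comparison, more recent JAX releases expose `jax.custom_jvp`, which avoids writing a primitive by hand: you supply only the JVP rule, and JAX transposes that (linear) rule for reverse mode and batches it for `vmap`. A minimal sketch, using a hypothetical `my_softplus` function rather than the gamma sampler:

```python
import jax
import jax.numpy as jnp


# Hypothetical example function whose derivative we supply by hand
# instead of letting JAX differentiate through its implementation.
@jax.custom_jvp
def my_softplus(x):
    return jnp.log1p(jnp.exp(x))


@my_softplus.defjvp
def my_softplus_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = my_softplus(x)
    # d/dx log(1 + e^x) = sigmoid(x); the tangent output is linear in x_dot.
    return y, jax.nn.sigmoid(x) * x_dot


# The single JVP rule serves forward mode, reverse mode and vmap.
_, fwd = jax.jvp(my_softplus, (0.0,), (1.0,))
rev = jax.grad(my_softplus)(0.0)
batched = jax.vmap(jax.grad(my_softplus))(jnp.array([-1.0, 0.0, 1.0]))
```

Only the forward-mode rule is written here; no transpose or batching rule is needed, which is the main difference from registering a raw primitive.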

from numpyro.

neerajprad commented on August 16, 2024

@fehiepsi - Just add your name to the distribution that you are working on so that we don't end up doing double work! :)

fehiepsi commented on August 16, 2024

Yes, I'll make a PR soon so that we aren't blocked from adding more distributions.

fehiepsi commented on August 16, 2024

@neerajprad Have you worked on this? I can quickly add these distributions if that doesn't overlap with what you are doing. :)

neerajprad commented on August 16, 2024

Feel free to work on this. I only took a brief look. I think the blocker is that we need to add a gamma sampler to jax first before we can support distributions like dirichlet and beta. What do you think?
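The reason a gamma sampler unblocks Dirichlet and Beta is that both can be built from independent gamma draws. A minimal sketch, assuming a JAX version that ships `jax.random.gamma` (current releases also provide `jax.random.dirichlet` and `jax.random.beta` directly); the helper names are hypothetical:

```python
import jax
import jax.numpy as jnp


def sample_dirichlet(key, concentration, num_samples):
    # Draw independent Gamma(alpha_i, 1) variates, then normalise each row.
    g = jax.random.gamma(key, concentration,
                         shape=(num_samples,) + concentration.shape)
    return g / jnp.sum(g, axis=-1, keepdims=True)


def sample_beta(key, a, b, num_samples):
    # Beta(a, b) = X / (X + Y) with X ~ Gamma(a, 1) and Y ~ Gamma(b, 1).
    key_x, key_y = jax.random.split(key)
    x = jax.random.gamma(key_x, a, shape=(num_samples,))
    y = jax.random.gamma(key_y, b, shape=(num_samples,))
    return x / (x + y)


key = jax.random.PRNGKey(0)
probs = sample_dirichlet(key, jnp.array([1.0, 2.0, 3.0]), num_samples=5)
p = sample_beta(key, 2.0, 5.0, num_samples=5)
```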

fehiepsi commented on August 16, 2024

> we need to add a gamma sampler to jax first before we can support distributions like dirichlet and beta.

Agreed! I am mostly interested in seeing the performance on HMMs, so I would like to support the Dirichlet distribution. I will try to implement a gamma sampler first (mainly to learn). :)

fehiepsi commented on August 16, 2024

Hmm, it looks like this is more complicated than I thought. Not because of the algorithms (Fritz already laid them out clearly in his PR to PyTorch), but because of my limited knowledge of how jax/lax works (e.g. how to take gradients with respect to the shape parameter, how to implement if/else/while logic, ...). I think we'd need help from the JAX devs on this.
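The if/else/while part can be expressed with `jnp.where` and `lax.while_loop`. The sketch below swaps in the Marsaglia-Tsang rejection sampler (a standard gamma algorithm, not necessarily the exact one in the PyTorch PR), with the usual `u ** (1 / alpha)` boost for `alpha < 1`. The function names are hypothetical, and it handles one scalar `alpha` per call (batch with `vmap`). Note that `lax.while_loop` is not reverse-mode differentiable, which is one reason the gradient has to be registered separately.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _marsaglia_tsang(key, alpha):
    # Rejection sampler for Gamma(alpha, 1); valid for alpha >= 1.
    d = alpha - 1.0 / 3.0
    c = 1.0 / jnp.sqrt(9.0 * d)

    def cond_fn(state):
        _, _, accepted = state
        return jnp.logical_not(accepted)

    def body_fn(state):
        key, _, _ = state
        key, key_x, key_u = jax.random.split(key, 3)
        x = jax.random.normal(key_x, dtype=jnp.float32)
        v = (1.0 + c * x) ** 3
        u = jax.random.uniform(key_u, dtype=jnp.float32)
        # Proposals with v <= 0 are always rejected; substitute 1.0 so the
        # log stays finite (that candidate value is discarded anyway).
        safe_v = jnp.where(v > 0, v, 1.0)
        log_ratio = 0.5 * x**2 + d - d * safe_v + d * jnp.log(safe_v)
        accepted = (v > 0) & (jnp.log(u) < log_ratio)
        return key, d * safe_v, accepted

    init = (key, jnp.zeros((), jnp.float32), jnp.array(False))
    _, sample, _ = lax.while_loop(cond_fn, body_fn, init)
    return sample


def gamma_sample(key, alpha):
    # One Gamma(alpha, 1) sample for any scalar alpha > 0.
    alpha = jnp.asarray(alpha, jnp.float32)
    key_mt, key_boost = jax.random.split(key)
    # if/else via jnp.where: for alpha < 1, draw Gamma(alpha + 1) and
    # rescale by u ** (1 / alpha).
    boosted = alpha < 1.0
    sample = _marsaglia_tsang(key_mt, jnp.where(boosted, alpha + 1.0, alpha))
    u = jax.random.uniform(key_boost, dtype=jnp.float32)
    return jnp.where(boosted, sample * u ** (1.0 / alpha), sample)


keys = jax.random.split(jax.random.PRNGKey(0), 20000)
samples = jax.vmap(lambda k: gamma_sample(k, 2.5))(keys)
```

Under `vmap`, the loop keeps iterating until every lane has accepted; lanes that are already done keep their value.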

neerajprad commented on August 16, 2024

I believe you'll need to define the sampler as a primitive operation and register its gradient separately, rather than trying to make the sampler itself differentiable. That is possible, but the documentation is a bit lacking - google/jax#116. I think all we'll need to specify the gradient function is digamma, which is already implemented in XLA. What do you think? I can take a stab at it next week too.
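The "compute the value one way, register the gradient separately" pattern can be sketched with `jax.custom_vjp` (a newer public API, used here in place of a hand-written primitive). For illustration only, the wrapped function is log Γ(α), the normalising term in the gamma and Dirichlet densities, whose derivative is digamma; the name `log_gamma` is hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln


@jax.custom_vjp
def log_gamma(alpha):
    return gammaln(alpha)


def _log_gamma_fwd(alpha):
    # Forward pass: return the value plus the residual the backward pass needs.
    return gammaln(alpha), alpha


def _log_gamma_bwd(alpha, g):
    # d/d(alpha) log Gamma(alpha) = digamma(alpha); one cotangent per input.
    return (g * digamma(alpha),)


log_gamma.defvjp(_log_gamma_fwd, _log_gamma_bwd)

grad_at_2 = jax.grad(log_gamma)(2.0)  # equals digamma(2.0)
```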

neerajprad commented on August 16, 2024

@fehiepsi - We can ask the jax folks whether this is already on their roadmap. In the meantime, feel free to open a PR for the sampler. For HMC at least, we only need the logpdf methods, which are easy to wrap, and even with a non-reparametrized sampler we can experiment with other algorithms.
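To illustrate why HMC only needs the log densities: they can be written with `gammaln` and differentiated by `jax.grad` directly, without any sampler gradient. A minimal sketch (rate fixed to 1 for the gamma; function names are hypothetical):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln


def gamma_logpdf(x, alpha):
    # log density of Gamma(alpha, rate=1) at x > 0
    return (alpha - 1.0) * jnp.log(x) - x - gammaln(alpha)


def dirichlet_logpdf(x, alpha):
    # log density of Dirichlet(alpha) at a point x on the simplex
    return (jnp.sum((alpha - 1.0) * jnp.log(x))
            + gammaln(jnp.sum(alpha)) - jnp.sum(gammaln(alpha)))


x = jnp.array([0.2, 0.3, 0.5])
alpha = jnp.array([1.5, 2.0, 3.0])
grad_x = jax.grad(dirichlet_logpdf)(x, alpha)  # (alpha - 1) / x
grad_alpha = jax.grad(gamma_logpdf, argnums=1)(2.0, 3.0)
# analytic value of grad_alpha: log(x) - digamma(alpha)
expected_grad_alpha = jnp.log(2.0) - digamma(3.0)
```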

neerajprad commented on August 16, 2024

No hurry, please take your time.