Comments (10)
Yeah, after digging into the autograd documentation and then searching for a similar pattern in jax.lax, I came up with the following template, which works for some simple functions. ;)
```python
from jax.interpreters import ad
from jax.lax import standard_unop, _float

def standard_gamma(x):
    # implement the forward pass here
    return x

def _standard_gamma_jvp_rule(g, ans, x):
    # implement the backward pass (JVP) here
    return g * x

# Register the primitive and its JVP rule; defjvp2 passes the primal
# output `ans` to the rule in addition to the tangent `g`.
standard_gamma_p = standard_unop(_float, 'standard_gamma')
ad.defjvp2(standard_gamma_p, _standard_gamma_jvp_rule)
```
Edit: the above is not correct (it is only applicable to XLA-compatible primitives). It seems that for custom primitives we need to define a jvp_rule (forward-mode), a transpose_rule (reverse-mode), a batch_rule (vmap), etc. I have implemented the sampler, but the rules are still tricky for me.
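For anyone landing on this thread later: current JAX lets you attach a custom gradient without defining a full primitive, via `jax.custom_jvp`. A minimal sketch of that pattern, where the forward and tangent bodies are placeholders rather than a real gamma sampler:

```python
import jax

@jax.custom_jvp
def standard_gamma(alpha):
    # Placeholder forward pass: a real implementation would draw a
    # Gamma(alpha, 1) sample here (e.g. via a rejection sampler).
    return alpha

@standard_gamma.defjvp
def standard_gamma_jvp(primals, tangents):
    (alpha,), (alpha_dot,) = primals, tangents
    sample = standard_gamma(alpha)
    # Placeholder tangent: a real implementation would plug in the
    # gradient of the sampler with respect to alpha.
    return sample, alpha_dot * sample

grad_fn = jax.grad(standard_gamma)
```

With this, `vmap` and `jit` work automatically; only the differentiation rule needs to be supplied by hand.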
@fehiepsi - Just add your name to the distribution that you are working on so that we don't end up doing double work! :)
Yes, I'll make a PR soon to not block us from adding more distributions.
@neerajprad Have you worked on this? I can quickly add these distributions in case they don't overlap with what you are doing. :)
Feel free to work on this. I only took a brief look. I think the blocker is that we need to add a gamma sampler to JAX before we can support distributions like Dirichlet and Beta. What do you think?
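The reduction behind this remark is standard: normalizing independent Gamma(alpha_i, 1) draws gives a Dirichlet(alpha) sample, and Beta(a, b) is the two-component special case, so a gamma sampler unlocks both. A NumPy sketch of the construction (helper names here are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

def dirichlet_from_gamma(alpha):
    # Draw g_i ~ Gamma(alpha_i, 1) and normalize: g / sum(g) ~ Dirichlet(alpha).
    g = rng.gamma(alpha)
    return g / g.sum()

def beta_from_gamma(a, b):
    # Beta(a, b) is the first coordinate of Dirichlet([a, b]).
    return dirichlet_from_gamma(np.array([a, b]))[0]

sample = dirichlet_from_gamma(np.array([1.0, 2.0, 3.0]))
```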
> we need to add a gamma sampler to jax first before we can support distributions like dirichlet and beta.
Agreed! I am mostly interested in seeing performance on an HMM, so I would like to support the Dirichlet distribution. I will try to implement a gamma sampler first (mainly to learn). :)
Hmm, it looks like this is more complicated than I thought. Not because of understanding the algorithm (Fritz already made it clear in his PR to PyTorch), but because of my lack of knowledge of how jax/lax works (e.g. how to take gradients with respect to the shape parameter, how to implement if/else/while logic, ...). I think we'd need help from the JAX devs on this.
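On the if/else/while point: since JAX traces Python control flow away, branching and loops that must survive into the compiled graph go through `lax.cond` and `lax.while_loop`. A small standalone illustration, unrelated to the gamma sampler itself:

```python
from jax import lax

def clipped(x):
    # lax.cond replaces a Python if/else inside traced code:
    # returns x when x > 0, otherwise 0.
    return lax.cond(x > 0, lambda v: v, lambda v: 0.0, x)

def count_halvings(x):
    # lax.while_loop replaces a Python while loop:
    # halve x until it drops below 1, counting the steps.
    def cond_fun(state):
        val, _ = state
        return val >= 1.0

    def body_fun(state):
        val, n = state
        return val / 2.0, n + 1

    _, n = lax.while_loop(cond_fun, body_fun, (x, 0))
    return n
```

Note that `lax.while_loop` is not reverse-mode differentiable, which is part of what makes rejection samplers awkward to differentiate through directly.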
I believe you'll need to define the sampler as a primitive operation and register its gradient separately, rather than trying to make the sampler itself differentiable. It is possible to do that, but the documentation is a bit lacking - google/jax#116. I think all we'll need is digamma to specify the gradient function, which is already implemented in XLA. What do you think? I can take a stab at it next week too.
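For context on where digamma enters: differentiating the Gamma(a, 1) log-density with respect to the concentration gives d/da [(a-1) log x - x - log Γ(a)] = log x - digamma(a). (The gradient of the sampler itself is a separate computation.) A quick consistency check against JAX's own `jax.scipy` functions:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma
from jax.scipy.stats import gamma

# Autodiff through the Gamma log-density with respect to a ...
score = jax.grad(lambda a: gamma.logpdf(2.0, a))(1.5)

# ... matches the closed form log(x) - digamma(a).
expected = jnp.log(2.0) - digamma(1.5)
```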
@fehiepsi - We can ask the jax folks if this is already on their roadmap. In the meantime, feel free to add a PR for the sampler. For HMC at least, we just need the logpdf methods, which can be wrapped easily, and even with just a non-reparametrized sampler we can experiment with other algorithms.
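As a concrete illustration of the "just wrap the logpdf" point: a potential-energy function for HMC can be assembled from `jax.scipy.stats` densities alone, with no reparameterized sampler in sight. A sketch using a made-up toy model (theta ~ N(0, 1), data_i ~ N(theta, 1)):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def potential_energy(theta, data):
    # Negative log joint density of the toy model.
    log_prior = norm.logpdf(theta)
    log_lik = jnp.sum(norm.logpdf(data, loc=theta))
    return -(log_prior + log_lik)

data = jnp.array([0.5, 1.0, 1.5])
# HMC only needs the potential and its gradient, both of which
# JAX provides automatically from the logpdf.
grad_pe = jax.grad(potential_energy)(0.0, data)
```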
No hurry, please take your time.