Comments (13)
I guess to do that you'd have to override `sample` and define a custom JVP (or VJP) rule.
from numpyro.
More context about this issue:
The Dirichlet sampler is available in #81. However, the derivative w.r.t. the concentration parameter is computed from the pathwise derivative of the gamma sampler used to generate the Dirichlet samples. It would be better to have a dedicated pathwise derivative implementation for the Dirichlet distribution itself. An algorithm for implementing it can be found in this paper.
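As a sketch of the gamma-based approach described above (not NumPyro's actual code), the snippet below builds Dirichlet samples as normalized gamma draws and lets gradients flow through `jax.random.gamma`'s pathwise derivative. It checks the Monte Carlo gradient of E[x_0] against the closed form, using E[x_0] = alpha_0 / sum(alpha). The helper names, sample count, and concentration values are illustrative choices.

```python
import jax
import jax.numpy as jnp

def dirichlet_via_gammas(key, alpha, n):
    # x = g / sum(g) with g_i ~ Gamma(alpha_i, 1); gradients w.r.t. alpha
    # flow through jax.random.gamma's pathwise derivative
    g = jax.random.gamma(key, alpha, shape=(n,) + alpha.shape)
    return g / g.sum(-1, keepdims=True)

def mc_mean_first(alpha, key, n=20000):
    # Monte Carlo estimate of E[x_0]
    return dirichlet_via_gammas(key, alpha, n)[:, 0].mean()

alpha = jnp.array([1.0, 2.0, 3.0])
key = jax.random.PRNGKey(0)
mc_grad = jax.grad(mc_mean_first)(alpha, key)

# E[x_0] = alpha_0 / S with S = sum(alpha), so the exact gradient is
# (S - alpha_0) / S^2 for alpha_0 and -alpha_0 / S^2 for the other entries
S = alpha.sum()
exact_grad = jnp.array([(S - alpha[0]) / S**2, -alpha[0] / S**2, -alpha[0] / S**2])
```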
The current Dirichlet sampler has a reparameterized gradient, but its variance would be higher than that of the PyTorch version. Since we now mostly rely on the upstream JAX samplers, it is better to keep the current higher-variance version and raise a request with the JAX team (if needed) instead.
Hello,
I'm trying to sample from a 4-dimensional Dirichlet over a batch of thousands of elements. The gradients are very slow compared to PyTorch. Has there been any progress on speeding this up?
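For reference, a batched setup like the one described can be written as a single jit-compiled `value_and_grad` over a `(N, 4)` concentration array with `jax.random.dirichlet`. This is a sketch only: the batch size, the log-parameterization, and the toy objective are assumptions for illustration, and it makes no claim about speed relative to PyTorch.

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss_and_grad(log_concentration, key):
    def loss(log_conc):
        # one 4-dimensional Dirichlet sample per row of the (N, 4) concentration
        x = jax.random.dirichlet(key, jnp.exp(log_conc))
        # toy objective (illustration only): weighted sum of sample components
        return jnp.mean(jnp.sum(x * jnp.arange(4.0), axis=-1))
    return jax.value_and_grad(loss)(log_concentration)

key = jax.random.PRNGKey(0)
log_conc = jnp.zeros((5000, 4))  # concentration of 1 for every component
value, grad = loss_and_grad(log_conc, key)
```

With all concentrations equal to 1, each component has mean 0.25, so `value` should land near 0.25 * (0 + 1 + 2 + 3) = 1.5 up to Monte Carlo error.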
Could you try the TFP Dirichlet distribution? IIRC it has a faster Dirichlet sampler.
I'm getting NaNs there related to Dirichlet sampling, which I don't get with `jax.random.dirichlet`/NumPyro.
I can confirm that sampling is faster in TFP, but I get NaNs during training, and I haven't yet narrowed down whether it's a gradient issue or a sampling issue.
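You might want to create a wrapper that clips Dirichlet samples, something like:
import jax.numpy as jnp
import numpyro.distributions as dist

class RobustDirichlet(dist.Dirichlet):  # base class is an assumption; wrap whichever Dirichlet you use
    def sample(self, key, sample_shape=()):
        samples = super().sample(key, sample_shape)
        # clip into [tiny, 1 - eps] so downstream logs/gradients never see exact 0 or 1
        return jnp.clip(samples, jnp.finfo(samples.dtype).tiny, 1 - jnp.finfo(samples.dtype).eps)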
@adamgayoso it might help to make sure your parameter is bounded away from zero, e.g. `concentration=0.01 + positive_param`.
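A minimal sketch of that suggestion, assuming `positive_param` is obtained by passing an unconstrained parameter through `jax.nn.softplus` (the softplus choice, the helper names, and the toy objective are illustrative assumptions, not from the thread):

```python
import jax
import jax.numpy as jnp

def concentration_from(raw_param, floor=0.01):
    # softplus(raw) > 0 for any real raw_param; adding `floor` keeps the
    # concentration bounded away from zero (0.01 as in the suggestion above)
    return floor + jax.nn.softplus(raw_param)

def objective(raw_param, key):
    x = jax.random.dirichlet(key, concentration_from(raw_param), shape=(1000,))
    # toy objective for illustration only
    return jnp.mean(jnp.sum(x ** 2, axis=-1))

raw = jnp.array([0.0, 0.5, 1.0, 2.0])
key = jax.random.PRNGKey(0)
grad = jax.grad(objective)(raw, key)
```

Even a very negative raw value maps to a concentration of about 0.01 rather than 0, which is the point of the floor.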
Thank you both for the suggestions, although I'm still getting NaNs. I'd need more time to check whether it's the gradients, but do you know if I can combine the `jax.random.dirichlet` gradient function with the TFP gamma sampler?
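One way to attach a gradient to gamma samples produced by some other sampler is implicit reparameterization. This sketch assumes `jax.lax.random_gamma_grad(a, x)`, which computes the derivative of a gamma sample with respect to its shape parameter. Here `external_gamma_sampler` is a hypothetical stand-in: it uses `jax.random.gamma` under `stop_gradient`, and plugging in TFP's gamma sampler instead is an untested assumption. A `jax.custom_jvp` then supplies the gradient.

```python
import jax
import jax.numpy as jnp

def external_gamma_sampler(key, alpha):
    # Stand-in for a sampler whose gradients we don't want to use (e.g. TFP's);
    # stop_gradient ensures nothing is differentiated through it
    return jax.lax.stop_gradient(jax.random.gamma(key, alpha))

@jax.custom_jvp
def gamma_with_implicit_grad(key, alpha):
    return external_gamma_sampler(key, alpha)

@gamma_with_implicit_grad.defjvp
def _gamma_jvp(primals, tangents):
    key, alpha = primals
    _, alpha_dot = tangents  # integer key tangent is ignored
    g = gamma_with_implicit_grad(key, alpha)
    # implicit reparameterization: dg/dalpha evaluated at the given sample g
    return g, jax.lax.random_gamma_grad(alpha, g) * alpha_dot

def dirichlet_from_external_gammas(key, alpha):
    # normalize gamma draws to get a Dirichlet sample
    g = gamma_with_implicit_grad(key, alpha)
    return g / g.sum(-1, keepdims=True)

key = jax.random.PRNGKey(1)
alpha = jnp.array([0.5, 1.5, 3.0])
g_stitched = jax.grad(lambda a: dirichlet_from_external_gammas(key, a)[0])(alpha)
```

When the stand-in sampler is `jax.random.gamma` itself, `g_stitched` should match ordinary autodiff through `jax.random.gamma` with the same key, which is a useful sanity check before swapping in a different sampler.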
import numpyro
from tensorflow_probability.substrates.jax import distributions as tfd

def model():
    # a TFP distribution passed directly to numpyro.sample
    numpyro.sample("x", tfd.Normal(0, 1))
Thank you. Where is that happening?
I see this wrapper: https://num.pyro.ai/en/stable/distributions.html#numpyro.contrib.tfp.distributions.TFPDistribution
But I don't see where a custom gradient function is being set.
@fehiepsi would you be able to check out google/jax#12943?
The sampling wasn't the bottleneck; it was actually the `lax.map` for the gradients on CPU. I also fixed the `lax.cond` in the sampling loop.
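As a rough illustration of why a `lax.map` over the gradient can matter (this does not reproduce the change in google/jax#12943), the sketch below computes per-element gamma gradients with `jax.lax.random_gamma_grad` in two ways. `jax.lax.map` applies the function to one element at a time in a sequential loop, while `jax.vmap` emits a single batched computation. The results agree; only the execution strategy differs. The array sizes and concentration range are arbitrary choices.

```python
import jax
import jax.numpy as jnp

alphas = jnp.linspace(1.0, 5.0, 1000)
samples = jax.random.gamma(jax.random.PRNGKey(0), alphas)

# lax.map: sequential loop, one (alpha, sample) pair per iteration
grads_map = jax.lax.map(lambda args: jax.lax.random_gamma_grad(*args), (alphas, samples))

# vmap: one vectorized computation over the whole batch
grads_vmap = jax.vmap(jax.lax.random_gamma_grad)(alphas, samples)
```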
Related Issues (20)
- `TransformedDistribution` support too broad when using `AffineTransform` transformation?
- Correct control_flow.cond usage
- Factor Analysis/PPCA Tutorial
- Censoring Example
- numpyro.deterministic static on infer.Predictive
- [FR] Support for different supports in component distributions for mixture models
- ImportError: cannot import name 'CAR' from "numpyro.distrubutions.continuous'
- Use biased autocorrelation estimate by default
- mean_accept_prob significantly different after warmup
- HSGP utility functions in the `contrib` module?
- Add Pareto Smoothed Importance Sampling (PSIS) diagnostic method
- contrib.hsgp: support vector-valued kernel hyperparameters
- [FR] Truncated Power Law distribution
- An auto guide's `_unpack_latent` and `_unpack_latent._inverse` don't use produce the same order
- [FR] Utility for joint distributions
- How can I gibbs before HMC/NUTS?
- Large potential energy while using `HMCGibbs` at the initial stage
- Inference Test Failing
- Figure in AR2 example is not reproducible
- Got Problems When Computing Log Likelihoods in a Scan-Based VAR Model