martinjankowiak commented on July 17, 2024

I guess to do that you'd need to override sample and define a custom JVP or JVP rule.
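A minimal sketch of that idea in plain JAX, assuming the sampler draws normalized Gamma(alpha_i, 1) variates: jax.custom_jvp attaches a hand-written tangent rule to the sampling function. The names dirichlet_sample and _dirichlet_sample_jvp are hypothetical, this is not NumPyro's implementation, and the rule shown is just the gamma-based pathwise derivative (jax.lax.random_gamma_grad plus the chain rule through the normalization). A Dirichlet-specific rule would replace the body of the JVP function.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def dirichlet_sample(alpha, key):
    # Primal: normalize independent Gamma(alpha_i, 1) draws onto the simplex.
    g = jax.random.gamma(key, alpha)
    return g / jnp.sum(g, axis=-1, keepdims=True)


@dirichlet_sample.defjvp
def _dirichlet_sample_jvp(primals, tangents):
    alpha, key = primals
    alpha_dot, _ = tangents  # the PRNG key is not differentiated
    g = jax.random.gamma(key, alpha)  # same key, so the same draws as the primal
    # Derivative of each gamma draw with respect to its own alpha.
    g_dot = jax.lax.random_gamma_grad(alpha, g) * alpha_dot
    s = jnp.sum(g, axis=-1, keepdims=True)
    x = g / s
    # Chain rule through x = g / sum(g): dx = (dg - x * sum(dg)) / sum(g).
    x_dot = (g_dot - x * jnp.sum(g_dot, axis=-1, keepdims=True)) / s
    return x, x_dot


key = jax.random.PRNGKey(0)
alpha = jnp.array([0.5, 1.0, 2.0, 4.0])
grad_first = jax.grad(lambda a: dirichlet_sample(a, key)[0])(alpha)
```

Because the JVP recomputes the draws from the same key, the primal and tangent stay consistent; any custom rule plugged in here must preserve that.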

from numpyro.

fehiepsi commented on July 17, 2024

More context about this issue:

A Dirichlet sampler is available in #81. However, the derivative with respect to the concentration parameter is computed from the pathwise derivative of the gamma sampler used to generate the Dirichlet samples. It would be better to have a pathwise derivative implementation for the Dirichlet distribution itself. An algorithm for implementing it can be found in this paper.
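For reference, the gamma-based gradient described here can be written out as follows. This is the standard normalized-gamma representation with an implicit-reparameterization derivative for each gamma draw (F and p denote the Gamma(α_k, 1) CDF and density); it is stated as background, not as the algorithm from the linked paper.

```latex
G_k \sim \mathrm{Gamma}(\alpha_k, 1), \qquad S = \sum_j G_j, \qquad X_i = \frac{G_i}{S}

\frac{\partial G_k}{\partial \alpha_k} = -\,\frac{\partial F(G_k; \alpha_k) / \partial \alpha_k}{p(G_k; \alpha_k)}

\frac{\partial X_i}{\partial \alpha_k} = \frac{\partial G_k}{\partial \alpha_k} \cdot \frac{\delta_{ik} - X_i}{S}
```

Summing the last line over i gives zero, so the tangent stays on the simplex, as it must.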

fehiepsi commented on July 17, 2024

The current Dirichlet sampler has a reparameterized gradient, but its variance would be higher than that of the PyTorch version. Given that we now mostly rely on the upstream JAX samplers, it is better to keep the current higher-variance version and raise a request with the JAX team (if needed) instead.

adamgayoso commented on July 17, 2024

Hello,

I'm trying to sample from a 4-dimensional Dirichlet, but over a batch of thousands of elements. Computing the gradients is very slow compared to PyTorch. Has there been any progress on speeding this up?
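A minimal sketch for reproducing this kind of timing on the JAX side, assuming a batch of 5,000 four-dimensional Dirichlets and a toy loss (both hypothetical choices): jit-compile the gradient once, then time a second call with block_until_ready so that JAX's asynchronous dispatch does not hide the cost.

```python
import time

import jax
import jax.numpy as jnp


def loss(concentration, key):
    # Toy objective: mean of the first simplex coordinate over the batch.
    x = jax.random.dirichlet(key, concentration)
    return jnp.mean(x[..., 0])


grad_fn = jax.jit(jax.grad(loss))  # gradient with respect to concentration

key = jax.random.PRNGKey(0)
concentration = jnp.full((5000, 4), 2.0)  # thousands of 4-dimensional Dirichlets

grad_fn(concentration, key).block_until_ready()  # compile once before timing
start = time.perf_counter()
grads = grad_fn(concentration, key).block_until_ready()
print(f"gradient time: {time.perf_counter() - start:.4f} s, shape {grads.shape}")
```

Comparing the compiled gradient time against the sampling time alone helps show whether the sampler or the gradient computation is the bottleneck.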

fehiepsi commented on July 17, 2024

Could you try the TFP Dirichlet distribution? IIRC it has a faster Dirichlet sampler.

adamgayoso commented on July 17, 2024

I'm getting NaNs there related to Dirichlet sampling, which I don't get with jax.random.dirichlet/NumPyro!

adamgayoso commented on July 17, 2024

I can confirm the sampling is faster in TFP, but I get NaNs during training, which I haven't yet narrowed down to a gradient issue or a sampling issue.
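One way to narrow this down is to check the forward samples and the gradients separately. The helper locate_nans and the toy loss_fn below are hypothetical; jax.random.dirichlet stands in for whichever sampler is in use, since any function with the same (key, concentration) call shape can be passed as sample_fn.

```python
import jax
import jax.numpy as jnp


def locate_nans(sample_fn, loss_fn, concentration, key):
    # Hypothetical helper: check the forward samples and the gradients separately.
    samples = sample_fn(key, concentration)
    grads = jax.grad(lambda c: loss_fn(sample_fn(key, c)))(concentration)
    return {
        "nan_in_samples": bool(jnp.any(jnp.isnan(samples))),
        "nan_in_grads": bool(jnp.any(jnp.isnan(grads))),
    }


def loss_fn(x):
    # Toy objective; log(x) gives -inf/NaN if a coordinate underflows to exactly 0.
    return jnp.sum(jnp.log(x))


report = locate_nans(
    jax.random.dirichlet, loss_fn, jnp.full((1000, 4), 2.0), jax.random.PRNGKey(0)
)
print(report)
```

If only the gradients are flagged, jax.config.update("jax_debug_nans", True) makes JAX raise an error at the first operation that produces a NaN, which can pinpoint the source further.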

fehiepsi commented on July 17, 2024

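You might want to create a wrapper that clips Dirichlet samples. Something like:

import jax.numpy as jnp
import numpyro.distributions as dist

class RobustDirichlet(dist.Dirichlet):  # base class assumed; swap in the Dirichlet you actually use
    def sample(self, key, sample_shape=()):
        samples = super().sample(key, sample_shape)
        eps = jnp.finfo(samples.dtype).eps  # keep samples off the simplex boundary (exact 0 or 1)
        return jnp.clip(samples, eps, 1 - eps)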

martinjankowiak commented on July 17, 2024

@adamgayoso it might help to make sure your parameter is bounded away from zero, e.g. concentration=0.01 + positive_param.
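The suggestion above as a minimal sketch: the names bounded_concentration and raw, and the use of softplus as the positivity transform, are assumptions; only the 0.01 floor comes from the comment.

```python
import jax
import jax.numpy as jnp


def bounded_concentration(raw, floor=0.01):
    # softplus maps any real number to (0, inf); adding the floor keeps alpha >= floor.
    return floor + jax.nn.softplus(raw)


raw = jnp.array([-50.0, -1.0, 0.0, 3.0])  # unconstrained (hypothetical) parameters
alpha = bounded_concentration(raw)
x = jax.random.dirichlet(jax.random.PRNGKey(0), alpha)
```

Even a very negative raw value then yields a concentration of about 0.01 rather than one that collapses toward zero.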

adamgayoso commented on July 17, 2024

Thank you both for the suggestions, although I'm still getting NaNs. I'd need more time to check whether it's the gradients, but do you know if I can combine the jax.random.dirichlet gradient function with the TFP gamma sampler?

martinjankowiak commented on July 17, 2024

import numpyro
from tensorflow_probability.substrates.jax import distributions as tfd

def model():
    # the TFP distribution is passed straight to numpyro.sample
    numpyro.sample("x", tfd.Normal(0, 1))

adamgayoso commented on July 17, 2024

Thank you. Where is that happening?

I see this wrapper: https://num.pyro.ai/en/stable/distributions.html#numpyro.contrib.tfp.distributions.TFPDistribution

But I don't see where a custom gradient function is being set.

adamgayoso commented on July 17, 2024

@fehiepsi would you be able to check out google/jax#12943?

The sampling wasn't the bottleneck; it was actually the lax.map used for the gradients on CPU. I also fixed the lax.cond in the sampling loop.
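As background for this comment: jax.lax.map runs its function over the leading axis one slice at a time (it is built on lax.scan), while jax.vmap turns the same per-element call into a single batched computation. The sketch below, which assumes a batch of 20,000 Gamma(2, 1) draws, computes the same per-element gamma gradients with jax.lax.random_gamma_grad both ways and times them. It only illustrates the difference and does not reproduce the change proposed in google/jax#12943.

```python
import time

import jax
import jax.numpy as jnp

alphas = jnp.full((20_000,), 2.0)
samples = jax.random.gamma(jax.random.PRNGKey(0), alphas)

# Sequential: lax.map applies random_gamma_grad to one (alpha, sample) pair per loop step.
grads_map = jax.jit(
    lambda a, x: jax.lax.map(lambda pair: jax.lax.random_gamma_grad(*pair), (a, x))
)
# Vectorized: vmap turns the same elementwise call into one batched operation.
grads_vmap = jax.jit(jax.vmap(jax.lax.random_gamma_grad))

for name, fn in [("lax.map", grads_map), ("vmap", grads_vmap)]:
    fn(alphas, samples).block_until_ready()  # compile before timing
    start = time.perf_counter()
    fn(alphas, samples).block_until_ready()
    print(f"{name}: {time.perf_counter() - start:.4f} s")
```

Both versions compute the same values; only the way the work is scheduled differs, which is why the choice matters for speed on CPU.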
