
flowjax's People

Contributors

danielward27, mdmould, tennessee-wallaceh

flowjax's Issues

Better support for bounded/constrained distributions

I am happy to leave it to the user to enforce any constraints they require, rather than having explicitly constrained distributions; however:

  • We need a Log/Exp bijection
  • We should ensure that Distribution.log_prob calculations involving arctanh and log return -jnp.inf where appropriate (for samples out of support), rather than jnp.nan (see the sketch after this list).
  • We should have an example of learning a distribution with bounded support
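As an illustration of the second point, here is a minimal sketch (plain JAX, not flowjax internals; the density is a made-up example supported on x > 0) of the usual jnp.where pattern for returning -jnp.inf outside the support instead of propagating nan:

import jax.numpy as jnp

def log_prob_positive_support(x):
    # Illustrative log-density involving jnp.log(x), supported on x > 0 only.
    in_support = x > 0
    safe_x = jnp.where(in_support, x, 1.0)  # dummy value keeps jnp.log well-defined
    log_density = -jnp.log(safe_x) - 0.5 * jnp.log(2 * jnp.pi) - 0.5 * jnp.log(safe_x) ** 2
    return jnp.where(in_support, log_density, -jnp.inf)  # -inf, not nan, out of support

print(log_prob_positive_support(jnp.array([-1.0, 1.0])))  # [-inf, -0.9189...]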

flowjax not available after install

Thanks for this package; along with the documentation, it looks great, and I'm excited to use it.

However, after installing with pip, I cannot import the package, which gives ModuleNotFoundError: No module named 'flowjax'. The same issue occurs when cloning and installing, then navigating away from the flowjax root directory.

To reproduce (e.g., in Colab):

!pip install flowjax
import flowjax

Any ideas?

ENH: use jaxtyping

If you're building on equinox, you might as well use jaxtyping for more detailed type hints, including shape information.
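For example, a hedged sketch of the kind of shape-aware annotation jaxtyping allows (the function and dimension names are purely illustrative):

from jaxtyping import Array, Float

def transform(x: Float[Array, "dim"], condition: Float[Array, "cond_dim"]) -> Float[Array, "dim"]:
    # Both arguments are annotated as 1-D float arrays with named dimension sizes.
    ...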

Loss function in train utils isn't "jitted"

The loss function defined at line 44 of train_utils doesn't have jit applied to it.

This is then used in the computation of the validation loss, which could significantly slow training, particularly when the validation set is large.

A fix should be as simple as adding an eqx.filter_jit decorator to the loss.
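For illustration, a hedged sketch of the suggested fix (the loss body here is illustrative, not the actual train_utils code):

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_jit  # compile once, so repeated validation evaluations are cheap
def loss_fn(dist, x, condition=None):
    # Negative mean log-likelihood; the real train_utils loss may differ in detail.
    return -jnp.mean(dist.log_prob(x, condition))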

Doctest fails

Sphinx doctest fails for a couple of reasons:

  1. Sphinx seems to stringify types, which is incompatible with equinox (stringified abstract annotations cannot be used).
  2. We set typing.GENERATING_DOCUMENTATION = True in conf.py to avoid expanding ArrayLike in the documentation (jaxtyping imports it as an Array instead); this change in imports leads to errors in isinstance checks (the workaround is sketched below).

Ideally, doctest would be run separately, outside a documentation-generating context. At some point I might switch to MkDocs, in which case another solution will need to be found for testing documentation anyway.
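For reference, the conf.py workaround described in point 2 amounts to something like this sketch:

# conf.py (sketch of the workaround described above)
import typing

typing.GENERATING_DOCUMENTATION = True  # stop jaxtyping expanding ArrayLike in the docs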

Transformation for conditioning variables

Hello Daniel, excellent package!

I would like to do conditional estimation, but apply a learnable transformation to the conditioning variables (the "u" in your example) before they are fed to the flow, and hopefully optimize the transformation as part of the fit.

Can I do it in flowjax without too much surgery?
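Not an authoritative answer, but one hedged sketch of the idea, assuming the flow exposes log_prob(x, condition=...): wrap the flow together with a learnable embedding of u in a single equinox module, so the embedding parameters are optimized jointly with the flow during the fit.

import equinox as eqx

class EmbeddedConditionalFlow(eqx.Module):
    # `flow` is assumed to be a conditional flowjax distribution; `embed` is a
    # learnable transformation (e.g. an MLP) applied to the raw condition u.
    flow: eqx.Module
    embed: eqx.nn.MLP

    def log_prob(self, x, u):
        # The embedded condition, not u itself, is passed to the flow.
        return self.flow.log_prob(x, condition=self.embed(u))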

Add docs section on serialization

Thanks for the great package! It would be nice to have a docs section that explains, or points to an explanation of, how to serialize and deserialize the objects in this package: Bijections, Flows (and their weights), etc.
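In the meantime, a hedged sketch of one approach, assuming the generic equinox serialisation helpers apply to flowjax objects (the flow construction here is illustrative):

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow

flow = masked_autoregressive_flow(key=jr.PRNGKey(0), base_dist=Normal(jnp.zeros(2)))

# Save only the array leaves (the weights); the pytree structure is not stored.
eqx.tree_serialise_leaves("flow.eqx", flow)

# To load, rebuild a flow with identical arguments and fill in the saved leaves.
like = masked_autoregressive_flow(key=jr.PRNGKey(0), base_dist=Normal(jnp.zeros(2)))
loaded = eqx.tree_deserialise_leaves("flow.eqx", like)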

Concatenate transform incompatible with jax tracing

Calls to transform, inverse, etc. for the Concatenate bijection break jax tracing (e.g., via jit), because of the int array split_idxs passed to array_split.

E.g., this fails:

import jax
import jax.numpy as jnp
from flowjax.bijections import Affine, Concatenate

bijections = (Affine(loc = jnp.zeros(1)),) * 2
bijection = Concatenate(bijections)

jax.jit(jax.vmap(bijection.transform))(jnp.ones((1, 2)))

whereas Stack works (because it splits input arrays equally):

import jax
import jax.numpy as jnp
from flowjax.bijections import Affine, Stack

bijections = (Affine(),) * 2
bijection = Stack(bijections)

jax.jit(jax.vmap(bijection.transform))(jnp.ones((1, 2)))

A sufficient fix seems to be to convert split_idxs to a sequence.
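For illustration, a minimal hedged sketch of why that fix should work (the names below are not the actual Concatenate internals): jnp.array_split is happy under tracing when the split points are static Python ints, because the output shapes are then known at trace time.

import jax
import jax.numpy as jnp

split_idxs = [1, 3]  # static Python ints rather than an int array

@jax.jit
def split(x):
    return jnp.array_split(x, split_idxs)

print([part.shape for part in split(jnp.ones(4))])  # [(1,), (2,), (1,)]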

Masked autoregressive flow with mixed transformer types

I am looking into a modification of a regular masked autoregressive flow in which the base distribution is an N-dimensional uniform and the first variable is left untransformed, while the remaining variables are transformed via a rational quadratic spline. I have removed the shuffling in the masked_autoregressive_flow function by removing the _add_default_permute call, and modified _flat_params_to_transformer in the MaskedAutoregressive class to apply an Identity transformer to the first dimension, as follows:

    def _flat_params_to_transformer(self, params: Array, y_dim=1):
        """Reshape to dim X params_per_dim, then vmap."""
        dim = self.shape[-1]
        transformer_params = jnp.reshape(params, (dim, -1))
        transformer_params = transformer_params[y_dim:, :]
        transformer = eqx.filter_vmap(self.transformer_constructor)(transformer_params)
        return Concatenate(
            [Identity((y_dim,)), Vmap(transformer, in_axes=eqx.if_array(0))]
        )

My understanding is that in this way the masked_autoregressive_mlp will still produce a set of spline parameters for the first variable, which then never get used, and that this should be harmless. My experiments seem to produce the expected results, but I am not sure this is the most efficient way to go about it, or whether I am overlooking anything relevant, so I would love to hear your opinion on how to make the best use of your package. Thanks again for all the amazing work!

Order of arguments to Distribution.sample

A bit of a minor one, but to me the Distribution.sample method feels a bit unnatural, as it forces either condition to be passed or sample_shape to be passed as a keyword argument.

How about Distribution.sample(key: jr.PRNGKey, sample_shape: Tuple[int] = (), condition: Optional[Array] = None) instead?

Then the user can simply do Distribution.sample(key, (n_samp,)).

This shouldn't be a breaking change, as any previous code would be using kwargs.
What do you think?

Consider adding orthogonal transformations

Instead of permutations, one could consider learnable linear transformations (see Section 3.2 of the Papamakarios review).
As far as I understand it, in theory this could allow the model to learn during training which conditionals are easier to fit.
Obviously this comes with a cost, so it may not be practical in all cases.

I have had some decent results with simple versions of this, albeit in low dimensions (< 5), so it could be something worth adding.
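For context, a hedged sketch (not flowjax code) of one common way to parameterise a learnable orthogonal transformation: the matrix exponential of a skew-symmetric matrix is orthogonal with determinant one, so as a linear bijection it contributes zero log-abs-det-Jacobian while remaining fully learnable.

import jax
import jax.numpy as jnp

def orthogonal_from_params(params, dim):
    # Fill the strictly upper triangle, then antisymmetrise so skew.T == -skew.
    a = jnp.zeros((dim, dim)).at[jnp.triu_indices(dim, k=1)].set(params)
    skew = a - a.T
    return jax.scipy.linalg.expm(skew)  # orthogonal, det = +1, so log|det J| = 0

dim = 3
params = jax.random.normal(jax.random.PRNGKey(0), (dim * (dim - 1) // 2,))
q = orthogonal_from_params(params, dim)
print(jnp.allclose(q @ q.T, jnp.eye(dim), atol=1e-5))  # True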

Concatenate fails when stacking more than two bijections

There is a bug in the segmenting of inputs given to the transform/inverse methods of flowjax.bijections.Concatenate.

E.g., the following will fail, with the specific error depending on the shapes of the stacked bijections:

import jax.numpy as jnp
from flowjax.bijections import Affine, Concatenate

n = 3
bijections = (Affine(loc = jnp.zeros(1)),) * n
bijection = Concatenate(bijections)

bijection.transform(jnp.ones(n))

Infinite Loss When Using Uniform Base Distribution to Model Gaussian Data with MAF

I am new to your package and am playing with your Bounded Flow example.
I am using MAF with a rational quadratic spline transformer and want to use a standard uniform as the base distribution. If I simulate $x$ from a standard Gaussian and train the flow with MLE, I immediately get an infinite loss, and samples from the flow do not correctly recover the target.

# Imports assumed for the snippets in this issue (module paths may vary across flowjax versions)
import jax.numpy as jnp
import jax.random as jr
import flowjax
from flowjax.bijections import Affine, Chain, Invert, RationalQuadraticSpline, Tanh
from flowjax.distributions import AbstractTransformed
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_data

nvars = 2
key, x_key = jr.split(jr.PRNGKey(0))
x = jr.normal(x_key, shape=(5000, nvars))

key, subkey = jr.split(jr.PRNGKey(0))
base_distr = flowjax.distributions._StandardUniform((nvars,))

# Create the flow
untrained_flow = masked_autoregressive_flow(
    key=subkey,
    base_dist=base_distr,
    transformer=RationalQuadraticSpline(knots=8, interval=4),
)

key, subkey = jr.split(key)
# Train
flow, losses = fit_to_data(
    key=subkey,
    dist=untrained_flow,
    x=x,
    learning_rate=5e-4,
    max_patience=10,
    max_epochs=70,
)

This also happens if I use a Uniform with larger support:

base_distr = flowjax.distributions.Uniform(minval=jnp.ones(nvars)*-3, maxval=jnp.ones(nvars)*3)

I have also tried to build an "unbounded uniform" class for my base distribution, where I pass uniform samples through an inverse tanh in the same way you do in the example:

class UnboundedUniform(AbstractTransformed):
    base_dist: flowjax.distributions._StandardUniform
    bijection: Chain

    def __init__(self, shape):
        eps = 1e-7 
        self.base_dist = flowjax.distributions._StandardUniform(shape)
        affine_transformation = Affine(loc=-jnp.ones(shape) + eps, scale=(1 - eps) * jnp.ones(shape)*2)
        inverse_tanh_transformation = Invert(Tanh(shape=shape))
        self.bijection = Chain([affine_transformation, inverse_tanh_transformation])

base_distr = UnboundedUniform((nvars,))

but I get the same behavior. I would like to know whether this is expected, whether I am doing something wrong/silly, or whether there is a workaround. Thank you for the help and all the good work!

BlockNeuralAutoregressiveFlow not initially normalized correctly

The log_prob evaluations of an untrained instance of BlockNeuralAutoregressiveFlow are not correctly normalized.

E.g.:

import jax
import jax.numpy as jnp
from flowjax.distributions import Normal
from flowjax.flows import BlockNeuralAutoregressiveFlow
import matplotlib.pyplot as plt

flow = BlockNeuralAutoregressiveFlow(
    key=jax.random.PRNGKey(0),
    base_dist=Normal(jnp.zeros(1)),
    cond_dim=None,
    nn_depth=1,
    nn_block_dim=1,
    flow_layers=1,
    invert=True,
    )

x = jnp.linspace(-100, 100, 100_000)
y = jnp.exp(flow.log_prob(x[:, None]))
# plt.plot(x, y); plt.show()
print(jnp.trapz(y, x))

The same is true if the support of the transformed distribution is explicitly bounded (this is unsurprising in hindsight, as the additional log-abs-det-Jacobian from the bounding bijection of course cannot account for any previous bijections), e.g.:

import jax
import jax.numpy as jnp
from flowjax.distributions import Normal, Transformed
from flowjax.bijections import Invert, Chain, Tanh, BlockAutoregressiveNetwork
import matplotlib.pyplot as plt

bijections = [
    Invert(
        BlockAutoregressiveNetwork(
            key=jax.random.PRNGKey(0),
            dim=1,
            cond_dim=None,
            depth=1,
            block_dim=1,
            ),
        ),
    Tanh(shape=(1,)),
    ]

flow = Transformed(Normal(jnp.zeros(1)), Chain(bijections))

x = jnp.linspace(-2, 2, 100_000)
y = jnp.exp(flow.log_prob(x[:, None]))
# plt.plot(x, y); plt.show()
print(jnp.trapz(y, x))

Is this expected behaviour (perhaps related to initialization; see Appendix C of https://arxiv.org/abs/1904.04676)? It does not occur for other flows, e.g. MaskedAutoregressiveFlow. In practice it's not really an issue, because the correct normalization is recovered once the flow is trained on some samples.
