danielward27 / flowjax
Home Page: https://danielward27.github.io/flowjax/
License: MIT License
I am happy leaving it to the user to enforce the constraints they require, rather than having explicitly constrained distributions. However, it would be preferable if Distribution.log_prob calculations involving e.g. the Log/Exp bijections, arctanh, and log returned -jnp.inf where appropriate (for samples out of support), rather than jnp.nan.
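For illustration, a minimal sketch (not flowjax's actual implementation) of the usual JAX "double where" pattern, which returns -jnp.inf out of support while keeping gradients NaN-free:

import jax.numpy as jnp

def safe_log(x):
    # Hypothetical helper, not part of flowjax: evaluate log only on
    # in-support values so gradients stay finite, then mask the rest with -inf.
    in_support = x > 0
    safe_x = jnp.where(in_support, x, 1.0)
    return jnp.where(in_support, jnp.log(safe_x), -jnp.inf)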
Hi! I have a private fork of flowjax in which I implement cross-validation for the test/train split of the data, so as not to overfit to the test data (e.g. for fit_to_data / fit_to_variational_target). It works quite well. Would you be open to me contributing this to flowjax / putting up a pull request for it?
Thanks for this package; along with the documentation, it looks great and I'm excited to use it.
However, after installing with pip, I cannot import the package: it gives ModuleNotFoundError: No module named 'flowjax'. The same issue occurs after cloning and installing, then navigating away from the flowjax root directory.
To reproduce (e.g., in Colab):
!pip install flowjax
import flowjax
Any ideas?
If you're building on equinox, you might as well use jaxtyping for more detailed type hints, including shape information.
The loss function defined in train_utils at line 44 doesn't have jit applied to it. It is then used in the computation of the validation loss, which can significantly slow training, particularly when the validation set is large. A fix should be as simple as adding an eqx.filter_jit decorator to the loss.
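For illustration, a minimal sketch of the suggested fix (the actual signature of the loss in train_utils may differ):

import equinox as eqx

@eqx.filter_jit  # jit, with non-array leaves of the distribution treated as static
def loss(dist, x, condition=None):
    # Negative mean log-likelihood, as used in maximum-likelihood training.
    return -dist.log_prob(x, condition).mean()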
Similar to the train_flow helper function, it would be useful to have a fit_flow_vi helper function (naming open to discussion).
For now this can be a simple implementation, with potentially more robust implementations to follow (such as https://arxiv.org/pdf/1802.02538.pdf).
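For concreteness, a hypothetical sketch of what a simple fit_flow_vi could look like (the name and signature are placeholders, not flowjax's API; log_target is assumed to be a vectorised unnormalised log density):

import equinox as eqx
import jax.random as jr
import optax

def fit_flow_vi(key, flow, log_target, steps=1000, n_samples=256, lr=1e-3):
    params, static = eqx.partition(flow, eqx.is_inexact_array)
    opt = optax.adam(lr)
    opt_state = opt.init(params)

    @eqx.filter_jit
    def step(params, opt_state, key):
        def neg_elbo(params, key):
            flow = eqx.combine(params, static)
            samples = flow.sample(key, sample_shape=(n_samples,))
            # Negative ELBO = E_q[log q(z) - log target(z)], estimated by
            # Monte Carlo with reparameterised samples.
            return (flow.log_prob(samples) - log_target(samples)).mean()

        loss, grads = eqx.filter_value_and_grad(neg_elbo)(params, key)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    losses = []
    for _ in range(steps):
        key, subkey = jr.split(key)
        params, opt_state, loss = step(params, opt_state, subkey)
        losses.append(loss)
    return eqx.combine(params, static), losses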
@danielward27 I am working on this on my fork 60-add-vi-helper
I've made a lot of use of your really nice package for several projects that I hope to publish in the not-too-distant future. Would you consider making it citable, e.g., via a Zenodo DOI (https://docs.github.com/en/repositories/archiving-a-github-repository/referencing-and-citing-content, which integrates nicely with GitHub releases)?
Sphinx doctest fails for a couple of reasons. When running doctest, typing.GENERATING_DOCUMENTATION = True is set in conf.py to avoid expanding ArrayLike in the documentation (jaxtyping imports it as an Array instead); this change in imports leads to errors in isinstance checks. Ideally, doctest would be run separately, outside a documentation-generating context. At some point I might switch to MkDocs, in which case another solution will be needed for testing the documentation anyway.
Hello Daniel, excellent package!
I would like to do conditional estimation, but apply a learnable transformation to the conditioning variables (the "u" in your example) before they are fed to the flow, and hopefully optimize the transformation as part of the fit.
Can I do it in flowjax without too much surgery?
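One hypothetical route, sketched under the assumption that flowjax objects, being equinox modules, compose with other equinox modules: wrap the flow together with a learnable embedding network, and apply the embedding to the condition before it reaches the flow.

import equinox as eqx
import jax.random as jr

class EmbeddedConditionFlow(eqx.Module):
    # Hypothetical wrapper, not a flowjax class.
    flow: eqx.Module   # a conditional flow with cond_dim == embed_dim
    embed: eqx.nn.MLP  # learnable transformation of the conditioning variables

    def __init__(self, flow, u_dim, embed_dim, *, key):
        self.flow = flow
        self.embed = eqx.nn.MLP(u_dim, embed_dim, width_size=64, depth=2, key=key)

    def log_prob(self, x, condition):
        # For a single (x, condition) pair; vmap over leading axes for batches.
        return self.flow.log_prob(x, condition=self.embed(condition))

Since the wrapper is itself an equinox module, the embedding's parameters are ordinary leaves and would be optimised jointly with the flow's; the caveat is that the built-in training helpers expect a flowjax Distribution, so a small custom training loop (or subclassing) may be needed.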
Thanks for the great package! It would be nice to have a docs section that explains, or points to an explanation of, how to serialize and deserialize the objects in this package: bijections, flows (and their weights), etc.
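For reference, a minimal sketch using equinox's generic serialisation utilities, which apply to flowjax objects since they are pytrees (the flow built here is just a stand-in architecture):

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow

def make_flow(key):
    # Rebuild the same architecture; its weights are overwritten on load.
    return masked_autoregressive_flow(key=key, base_dist=Normal(jnp.zeros(2)))

flow = make_flow(jr.PRNGKey(0))
eqx.tree_serialise_leaves("flow.eqx", flow)  # save the weights
flow = eqx.tree_deserialise_leaves("flow.eqx", make_flow(jr.PRNGKey(1)))  # load them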
Calls to transform, inverse, etc. for the Concatenate bijection break jax tracing (e.g., via jit), because of the integer array split_idxs passed to array_split.
E.g., this fails:
import jax
import jax.numpy as jnp
from flowjax.bijections import Affine, Concatenate
bijections = (Affine(loc=jnp.zeros(1)),) * 2
bijection = Concatenate(bijections)
jax.jit(jax.vmap(bijection.transform))(jnp.ones((1, 2)))
whereas Stack works (because it splits input arrays equally):
import jax
import jax.numpy as jnp
from flowjax.bijections import Affine, Stack
bijections = (Affine(),) * 2
bijection = Stack(bijections)
jax.jit(jax.vmap(bijection.transform))(jnp.ones((1, 2)))
A sufficient fix seems to be to convert split_idxs to a plain Python sequence.
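A standalone illustration of why a static sequence fixes it (this is jax behaviour, not flowjax-specific): array_split traces fine under jit when the split indices are plain Python ints, whereas a traced integer array in their place breaks tracing.

import jax
import jax.numpy as jnp

split_idxs = (1,)  # static Python sequence, e.g. instead of jnp.array([1])
f = jax.jit(lambda x: jnp.array_split(x, split_idxs))
f(jnp.ones(2))  # works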
I am looking into a modification of a regular masked autoregressive flow where the base distribution is an N-dimensional uniform and the first variable does not get transformed, while the rest of the variables are transformed via a rational quadratic spline. I have removed the shuffling in the masked_autoregressive_flow function by removing _add_default_permute, and modified _flat_params_to_transformer in the MaskedAutoregressive class to apply an Identity transformer to the first dimension in the following way:
# Requires jnp (jax.numpy), eqx (equinox), Array (jaxtyping), and
# Concatenate, Identity, Vmap from flowjax.bijections.
def _flat_params_to_transformer(self, params: Array, y_dim=1):
    """Reshape to dim X params_per_dim, then vmap, leaving the first
    y_dim dimensions untouched via an Identity bijection."""
    dim = self.shape[-1]
    transformer_params = jnp.reshape(params, (dim, -1))
    # Discard the (never used) parameters produced for the first y_dim variables.
    transformer_params = transformer_params[y_dim:, :]
    transformer = eqx.filter_vmap(self.transformer_constructor)(transformer_params)
    return Concatenate(
        [Identity((y_dim,)), Vmap(transformer, in_axes=eqx.if_array(0))]
    )
My understanding is that in this way the masked_autoregressive_mlp will still produce a set of spline parameters for the first variable, which then never gets used, and that this should be harmless. My experiments seem to produce the expected results, but I am not sure this is the most efficient way to go about it, or whether I am disregarding anything relevant, so I would love to hear your opinion on how to make the best use of your package. Thanks again for all the amazing work!
This is a feature proposal to add more distributions to the package.
A bit of a minor one, but to me the Distribution.sample method feels a bit unnatural, as it forces condition to be passed, or sample_shape to be passed as a kwarg. How about Distribution.sample(key: jr.PRNGKey, sample_shape: Tuple[int] = (), condition: Optional[Array] = None) instead? Then the user can simply do Distribution.sample(key, (n_samp,)).
This shouldn't be a breaking change, as any previous code would be using kwargs.
What do you think?
Instead of permutations, one could consider learnable linear transformations (see section 3.2 of the Papamakarios review). As far as I understand it, in theory this could allow the model to learn, during training, which conditionals are easier to fit.
Obviously this comes with a cost, so may not be practical in all cases.
I have had some decent results with simple versions of this, albeit in low dimensions (< 5), so it could be something worth adding.
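For concreteness, a hypothetical standalone sketch (not flowjax's API) of the usual LU parameterisation of such a learnable linear transformation: with W = L @ U, L unit-lower-triangular and U upper-triangular, the map is invertible whenever diag(U) is nonzero and log|det W| is cheap to compute.

import jax.numpy as jnp

def lu_linear_forward(params, x):
    # params["lower"] and params["upper"] are learnable (dim, dim) arrays.
    dim = x.shape[-1]
    lower = jnp.tril(params["lower"], k=-1) + jnp.eye(dim)  # unit diagonal
    upper = jnp.triu(params["upper"])
    y = lower @ (upper @ x)
    log_det = jnp.sum(jnp.log(jnp.abs(jnp.diag(upper))))
    return y, log_det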
There is a bug in the segmenting of inputs given to the transform/inverse methods of flowjax.bijections.Concatenate.
E.g., the following will fail, with the specific error depending on the shapes of the stacked bijections:
import jax.numpy as jnp
from flowjax.bijections import Affine, Concatenate
n = 3
bijections = (Affine(loc=jnp.zeros(1)),) * n
bijection = Concatenate(bijections)
bijection.transform(jnp.ones(n))
I am new to your package and am playing with your Bounded Flow example.
I am using MAF with a rational quadratic spline transformer and want to use a standard uniform as the base distribution. If I simulate:
# Imports assumed for this snippet:
import jax.numpy as jnp
import jax.random as jr
import flowjax
from flowjax.bijections import RationalQuadraticSpline
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_data

nvars = 2
key, x_key = jr.split(jr.PRNGKey(0))
x = jr.normal(x_key, shape=(5000, nvars))
key, subkey = jr.split(jr.PRNGKey(0))
base_distr = flowjax.distributions._StandardUniform((nvars,))

# Create the flow
untrained_flow = masked_autoregressive_flow(
    key=subkey,
    base_dist=base_distr,
    transformer=RationalQuadraticSpline(knots=8, interval=4),
)
key, subkey = jr.split(key)

# Train
flow, losses = fit_to_data(
    key=subkey,
    dist=untrained_flow,
    x=x,
    learning_rate=5e-4,
    max_patience=10,
    max_epochs=70,
)
This happens as well if I use a Uniform with a larger support:
base_distr = flowjax.distributions.Uniform(minval=jnp.ones(nvars)*-3, maxval=jnp.ones(nvars)*3)
I have also tried to build an "unbounded uniform" class for my base distribution, where I pass uniform samples through an inverse tanh the same way you do in the example:
# Additional imports assumed:
from flowjax.bijections import Affine, Chain, Invert, Tanh
from flowjax.distributions import AbstractTransformed

class UnboundedUniform(AbstractTransformed):
    base_dist: flowjax.distributions._StandardUniform
    bijection: Chain

    def __init__(self, shape):
        eps = 1e-7
        self.base_dist = flowjax.distributions._StandardUniform(shape)
        # Rescale (0, 1) samples to (-1 + eps, 1 - eps), then apply arctanh.
        affine_transformation = Affine(
            loc=-jnp.ones(shape) + eps, scale=(1 - eps) * jnp.ones(shape) * 2
        )
        inverse_tanh_transformation = Invert(Tanh(shape=shape))
        self.bijection = Chain([affine_transformation, inverse_tanh_transformation])

base_distr = UnboundedUniform((nvars,))
but I get the same behavior. I would like to know whether this is expected and I am doing something wrong/silly, or whether there is any workaround. Thank you for the help and all the good work!
It could be interesting to implement the approach outlined in the following papers:
Essentially the idea is to avoid accumulating gradients from the score term of the ELBO loss, which have expectation 0.
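A minimal sketch of that idea, the "sticking the landing" estimator (Roeder et al., 2017), using hypothetical names rather than flowjax internals: the samples are reparameterised, so gradients flow through them, while log q is evaluated with gradient-stopped parameters, which removes the zero-expectation score term from the gradient.

import equinox as eqx
import jax

def neg_elbo_stl(params, static, key, log_target, n_samples=256):
    flow = eqx.combine(params, static)
    samples = flow.sample(key, sample_shape=(n_samples,))  # reparameterised
    # Stop gradients through the parameters of log q only; the path
    # derivative through the samples is kept.
    flow_stopped = eqx.combine(jax.lax.stop_gradient(params), static)
    return (flow_stopped.log_prob(samples) - log_target(samples)).mean()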
The log_prob evaluations of an untrained instance of BlockNeuralAutoregressiveFlow are not correctly normalized.
E.g.:
import jax
import jax.numpy as jnp
from flowjax.distributions import Normal
from flowjax.flows import BlockNeuralAutoregressiveFlow
import matplotlib.pyplot as plt
flow = BlockNeuralAutoregressiveFlow(
    key=jax.random.PRNGKey(0),
    base_dist=Normal(jnp.zeros(1)),
    cond_dim=None,
    nn_depth=1,
    nn_block_dim=1,
    flow_layers=1,
    invert=True,
)
x = jnp.linspace(-100, 100, 100_000)
y = jnp.exp(flow.log_prob(x[:, None]))
# plt.plot(x, y); plt.show()
print(jnp.trapz(y, x))
The same is true if the support of the transformed distribution is explicitly bounded (this is unsurprising in hindsight, as the additional log-abs-det-Jacobian from the bounding bijection of course cannot account for any previous bijections), e.g.:
import jax
import jax.numpy as jnp
from flowjax.distributions import Normal, Transformed
from flowjax.bijections import Invert, Chain, Tanh, BlockAutoregressiveNetwork
import matplotlib.pyplot as plt
bijections = [
    Invert(
        BlockAutoregressiveNetwork(
            key=jax.random.PRNGKey(0),
            dim=1,
            cond_dim=None,
            depth=1,
            block_dim=1,
        ),
    ),
    Tanh(shape=(1,)),
]
flow = Transformed(Normal(jnp.zeros(1)), Chain(bijections))
x = jnp.linspace(-2, 2, 100_000)
y = jnp.exp(flow.log_prob(x[:, None]))
# plt.plot(x, y); plt.show()
print(jnp.trapz(y, x))
Is this expected behaviour (perhaps related to initialization; Appendix C of https://arxiv.org/abs/1904.04676)? It does not occur for other flows, e.g., MaskedAutoregressiveFlow. In practice it's not really an issue, because the correct normalization is recovered once the flow is trained on some samples.