This repository contains a minimal JAX implementation of some concepts related to stochastic interpolants, based on this paper by Michael S. Albergo, Nicholas M. Boffi, and Eric Vanden-Eijnden.
Disclaimer: This implementation is meant to be didactic. For a more complete version (in PyTorch), see the repository published by the authors of the paper here.
Before installing this project, and after creating and activating your virtual environment, you must install JAX yourself, because the CPU and GPU backends require different installation commands. See here for instructions.
For the small examples, pip install "jax[cpu]" will suffice (the quotes stop shells such as zsh from interpreting the brackets).
For the bigger demos, a GPU is helpful.
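Before running the bigger demos, it can help to confirm which backend your JAX installation actually picked up. The short check below uses only JAX's public jax.default_backend() and jax.devices() functions; it is a convenience sketch, not part of this package.

```python
import jax

# Backend JAX selected by default, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
# All devices JAX can see on that backend.
devices = jax.devices()

print(f"JAX backend: {backend}")
for d in devices:
    # Each device reports the platform it belongs to.
    print(f"  {d.platform}: {d}")
```

If this prints cpu on a machine with a GPU, the CPU-only JAX build is most likely installed, and JAX should be reinstalled following the GPU instructions.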
Then, from the root of the repository, run
pip install .
This command installs all requirements (Flax, Optax, etc.).
The package can then be imported with
from stochint import *
The demos are in demos/.
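The README does not document the contents of stochint, so the following is an independent sketch of the central object, not the package's code. It builds the linear interpolant x_t = (1 - t) x0 + t x1 + gamma(t) z with z ~ N(0, I), using the assumed noise schedule gamma(t) = sqrt(a t (1 - t)), and a least-squares loss that regresses a model b(t, x) onto the time derivative of x_t; the minimizer of that loss is E[d/dt x_t | x_t = x]. All function names (gamma, gamma_dot, interpolant, interpolant_dot, velocity_loss) are hypothetical.

```python
import jax
import jax.numpy as jnp


def gamma(t, a=1.0):
    # Assumed noise schedule gamma(t) = sqrt(a * t * (1 - t)); zero at t = 0 and t = 1.
    return jnp.sqrt(a * t * (1.0 - t))


def gamma_dot(t, a=1.0):
    # Analytic derivative of gamma. It diverges at t = 0 and t = 1,
    # so velocity_loss samples t away from the endpoints.
    return a * (1.0 - 2.0 * t) / (2.0 * jnp.sqrt(a * t * (1.0 - t)))


def interpolant(t, x0, x1, z):
    # x_t = (1 - t) x0 + t x1 + gamma(t) z: equals x0 at t = 0 and x1 at t = 1.
    return (1.0 - t) * x0 + t * x1 + gamma(t) * z


def interpolant_dot(t, x0, x1, z):
    # d/dt x_t for fixed (x0, x1, z).
    return x1 - x0 + gamma_dot(t) * z


def velocity_loss(b, key, x0, x1, eps=1e-3):
    # Monte Carlo estimate of E |b(t, x_t) - d/dt x_t|^2.
    # x0, x1: arrays of shape (batch, dim) sampled from the two densities.
    key_t, key_z = jax.random.split(key)
    t = jax.random.uniform(key_t, (x0.shape[0], 1), minval=eps, maxval=1.0 - eps)
    z = jax.random.normal(key_z, x0.shape)
    xt = interpolant(t, x0, x1, z)
    target = interpolant_dot(t, x0, x1, z)
    return jnp.mean(jnp.sum((b(t, xt) - target) ** 2, axis=-1))


# Usage: evaluate the loss for a trivial model on two 2D Gaussian batches.
key = jax.random.PRNGKey(0)
k0, k1, k_loss = jax.random.split(key, 3)
x0 = jax.random.normal(k0, (256, 2))
x1 = jax.random.normal(k1, (256, 2)) + 3.0
zero_model = lambda t, x: jnp.zeros_like(x)
print(velocity_loss(zero_model, k_loss, x0, x1))
```

In practice b would be a neural network (for example a Flax module) trained with an Optax optimizer to minimize velocity_loss; new samples are then obtained by integrating dx/dt = b(t, x) from t = 0 to t = 1.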
Thanks to Paul Jeha (@pablo2909) for teaching us how to write a name with 2D samples.