This project is a fork of google-research/torchsde.

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.

License: Apache License 2.0

PyTorch Implementation of Differentiable SDE Solvers

This codebase provides stochastic differential equation (SDE) solvers with GPU support and efficient sensitivity analysis. Similar to torchdiffeq, algorithms in this repository are fully supported to run on GPUs.


Installation

pip install git+https://github.com/google-research/torchsde.git

Examples

demo.ipynb in the examples folder is a short guide to solving SDEs with the codebase, without considering gradient computation. It covers subtle points such as fixing the randomness in the solver and the consequences of different noise types. Unlike with ordinary differential equations (ODEs), where general high-order solvers can be applied, the numerical efficiency of SDE solvers depends heavily on the assumptions made about the noise (i.e. on the diffusion function). We encourage those interested in using this codebase to take a look at the notebook.

latent_sde.py in the examples folder learns a latent stochastic differential equation (cf. Section 5 of [1] and references therein). The example fits an SDE to data while regularizing it to stay close to an Ornstein-Uhlenbeck prior process. The model can be loosely viewed as a variational autoencoder whose prior and approximate posterior are both SDEs. Note that the example contains many simplifications and is meant to demonstrate basic usage of gradient computation.

To run the latent SDE example, execute the following command from the root folder of this repo:

python3 -m examples.latent_sde \
  --adjoint \
  --adaptive \
  --rtol 1e-2 \
  --atol 1e-3 \
  --train-dir ${TRAIN_DIR}

Once in a while, the program writes a figure to TRAIN_DIR, an environment variable that specifies a location on disk. Training should stabilize after 500 iterations with the default hyperparameters.
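For example, TRAIN_DIR can be set and created before launching the script (the path below is arbitrary):

```shell
export TRAIN_DIR=/tmp/latent_sde   # any writable directory works
mkdir -p "$TRAIN_DIR"              # create it so figures can be written
```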

Even though the example can be run with just CPUs, we recommend running it on a machine with GPUs.

Usage

The central functions of interest are sdeint and sdeint_adjoint. They can be imported as follows:

from torchsde import sdeint, sdeint_adjoint

Integrating SDEs

SDEs are defined with two vector fields, named respectively the drift and diffusion functions. The drift is analogous to the vector field in ODEs, whereas the diffusion controls how the Brownian motion affects the system.
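The roles of the two vector fields are easiest to see in the simplest scheme, Euler-Maruyama: the drift advances the state deterministically, while the diffusion scales a Gaussian Brownian increment. Below is a plain-Python sketch of that scheme, for intuition only; it is not how the library's solvers are implemented:

```python
import math
import random


def euler_maruyama(f, g, y0, t0, t1, steps, seed=0):
    """Integrate dy = f(t, y) dt + g(t, y) dW with fixed-step Euler-Maruyama."""
    rng = random.Random(seed)
    dt = (t1 - t0) / steps
    t, y = t0, y0
    path = [y0]
    for _ in range(steps):
        dw = rng.gauss(0.0, math.sqrt(dt))   # Brownian increment ~ N(0, dt)
        y = y + f(t, y) * dt + g(t, y) * dw  # drift step + diffusion-scaled noise
        t += dt
        path.append(y)
    return path


# Scalar geometric Brownian motion: f(t, y) = mu * y, g(t, y) = sigma * y.
path = euler_maruyama(lambda t, y: 0.5 * y, lambda t, y: 1.0 * y,
                      y0=0.1, t0=0.0, t1=1.0, steps=100)
```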

By default, solvers in the codebase integrate SDEs defined as torch.nn.Module objects with a drift function f and a diffusion function g. For instance, geometric Brownian motion can be integrated as follows:

import torch
from torchsde import sdeint


class SDE(torch.nn.Module):

    def __init__(self, mu, sigma):
        super().__init__()
        self.noise_type = "diagonal"
        self.sde_type = "ito"

        self.mu = mu
        self.sigma = sigma

    def f(self, t, y):
        return self.mu * y

    def g(self, t, y):
        return self.sigma * y

batch_size, d, m = 4, 1, 1  # State dimension d, Brownian motion dimension m.
geometric_bm = SDE(mu=0.5, sigma=1)
y0 = torch.zeros(batch_size, d).fill_(0.1)  # Initial state.
ts = torch.linspace(0, 1, 20)
ys = sdeint(geometric_bm, y0, ts)

To ensure the solvers work, f and g must take the time t and the state y, in that order. Most importantly, the SDE class must have the attributes sde_type and noise_type. The noise type of the SDE determines which numerical algorithms may be used. See demo.ipynb for more on this.

Possible noise types, for state dimension d:

  • diagonal: The diffusion g is element-wise and has output size (d,). The Brownian motion is d-dimensional.
  • additive: The diffusion g is constant w.r.t. y and has output size (d, m). The Brownian motion is m-dimensional.
  • scalar: The diffusion g has output size (d, 1). The Brownian motion is 1-dimensional.
  • general: The diffusion g has output size (d, m). The Brownian motion is m-dimensional.

In practice, we found diagonal and additive to produce a good trade-off between model flexibility and computational efficiency, and we recommend sticking to these two noise types if possible.
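The table above can be summarized as a small shape rule. The helper below is a hypothetical illustration (not part of the torchsde API) of the per-sample output shape expected of g for each noise type; in practice a leading batch dimension is added:

```python
def expected_g_shape(noise_type, d, m):
    """Per-sample output shape of the diffusion g(t, y) for each noise type.

    Hypothetical helper mirroring the table above: d is the state dimension,
    m the Brownian motion dimension.
    """
    if noise_type == "diagonal":
        return (d,)     # element-wise diffusion; requires m == d
    if noise_type == "scalar":
        return (d, 1)   # driven by a single Brownian motion
    if noise_type in ("additive", "general"):
        return (d, m)   # full diffusion matrix
    raise ValueError(f"unknown noise type: {noise_type!r}")


assert expected_g_shape("diagonal", 3, 3) == (3,)
assert expected_g_shape("general", 3, 2) == (3, 2)
```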

Integrating SDEs with KL penalty

Fitting SDEs directly to data is likely to give degenerate solutions that are close to deterministic. One benefit of using SDEs appears when they are treated as latent variables in a Bayesian formulation. With an appropriate prior SDE, a Kullback–Leibler (KL) divergence on the space of sample paths can be estimated efficiently. To achieve this, we can give the SDE class an additional method h(t, y), defining the drift of the prior. The KL penalty is estimated when we pass the argument logqp=True to sdeint:

import torch

from torchsde import sdeint

class SDE(torch.nn.Module):

    def __init__(self, mu, sigma):
        super().__init__()
        self.noise_type = "diagonal"
        self.sde_type = "ito"

        self.mu = mu
        self.sigma = sigma

    def f(self, t, y):
        return self.mu * y

    def g(self, t, y):
        return self.sigma * y

    def h(self, t, y):
        return self.mu * y * 0.5

batch_size, d, m = 4, 1, 1  # State dimension d, Brownian motion dimension m.
geometric_bm = SDE(mu=0.5, sigma=1)
y0 = torch.zeros(batch_size, d).fill_(0.1)  # Initial state.
ts = torch.linspace(0, 1, 20)
ys, logqp = sdeint(geometric_bm, y0, ts, logqp=True)  # Also returns estimated KL.
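The penalty comes from Girsanov's theorem: when the posterior drift f and the prior drift h share the diffusion g, the KL divergence between the induced path measures can be written in terms of the scaled drift difference. As a sketch of the standard result (cf. [1]):

```latex
u(t, y) = g(t, y)^{-1}\bigl(f(t, y) - h(t, y)\bigr),
\qquad
D_{\mathrm{KL}}\bigl(q \,\|\, p\bigr)
  = \mathbb{E}_{q}\!\left[\int_{t_0}^{t_1} \tfrac{1}{2}\,\lVert u(t, y_t)\rVert^2 \, dt\right].
```

The logqp output is a sample-based estimate of this path integral.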

To use the adjoint formulation for memory-efficient gradient computation, simply replace sdeint with sdeint_adjoint in the snippets above.

Keyword arguments of sdeint

  • bm: A BrownianPath or BrownianTree object. Optionally supply one to fix the randomness used by the solver.
  • logqp: If True, also return an estimate of the Radon-Nikodym derivative, a log-ratio penalty accumulated over the whole path.
  • method: One of the solvers listed below.
  • dt: A float for the constant step size or initial step size for adaptive time-stepping.
  • adaptive: If True, use adaptive time-stepping.
  • rtol: Relative tolerance.
  • atol: Absolute tolerance.
  • dt_min: Minimum step size.

List of SDE solvers

Note that stochastic Runge-Kutta methods are a class of numerical methods, and the precise formulation for one noise type may differ substantially from that for another. Internally, sdeint selects the algorithm based on the noise_type attribute of the SDE object.
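Conceptually, that selection amounts to a lookup keyed on the noise type. The sketch below is hypothetical; the scheme names and mapping are illustrative, not torchsde's actual internals (the strong orders cited in the comments are standard facts about these schemes):

```python
# Hypothetical default scheme per noise type; names are illustrative only.
DEFAULT_SCHEME = {
    "diagonal": "milstein",         # Milstein attains strong order 1.0 here
    "additive": "srk_additive",     # stochastic Runge-Kutta for additive noise
    "scalar":   "milstein",
    "general":  "euler_maruyama",   # Euler-Maruyama, strong order 0.5
}


def select_scheme(sde):
    """Pick a numerical scheme from the SDE's noise_type attribute."""
    try:
        return DEFAULT_SCHEME[sde.noise_type]
    except KeyError:
        raise ValueError(f"unsupported noise type: {sde.noise_type!r}")
```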

Known issues

  • Existing solvers prioritize having high strong order. High weak order solvers will be included in the future.
  • Existing solvers are based on Itô SDEs. Solvers for Stratonovich SDEs will be included in the future.
  • Adjoint mode is currently not supported for SDEs with noise type scalar and general. We expect the first case to be fixed soon. The second case requires more work and will be fixed after efficient Stratonovich solvers are in place.
  • Unlike the adjoint sensitivity method for ODEs, our proposed stochastic adjoint sensitivity method is, to the best of our knowledge, a new numerical method. Its theoretical properties when combined with adaptive time-stepping are still largely unknown, even though we have found the combination to typically work in practice.
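For reference, an Itô SDE can be rewritten in Stratonovich form (and vice versa) by adjusting only the drift. For diagonal noise the standard correction is, component-wise:

```latex
\tilde f_i(t, y) \;=\; f_i(t, y) \;-\; \tfrac{1}{2}\, g_i(t, y)\, \frac{\partial g_i}{\partial y_i}(t, y),
```

where \(\tilde f\) is the Stratonovich drift corresponding to Itô drift \(f\) and diffusion \(g\).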

References

[1] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations." International Conference on Artificial Intelligence and Statistics. 2020. arXiv:2001.01328.


This is a research project, not an official Google product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

If you found this codebase useful in your research, please consider citing

@article{li2020scalable,
  title={Scalable gradients for stochastic differential equations},
  author={Li, Xuechen and Wong, Ting-Kam Leonard and Chen, Ricky T. Q. and Duvenaud, David},
  journal={International Conference on Artificial Intelligence and Statistics},
  year={2020}
}

Contributors

lxuechen, patrick-kidger
