million-dimension-prob-ode-solver-experiments's People

Contributors

nathanaelbosch, pnkraemer, schmidtjonathan


million-dimension-prob-ode-solver-experiments's Issues

Lorenz versions

The different Lorenz versions do not coincide perfectly.

There is also a bad default initial value.

Error measures

The experiment scripts introduced in #7 and #6 do not contain measures of approximation error.
This might be a useful quantity to report, too.
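One hedged option for such a measure: compare each solution against a tightly-tolerated reference run and report a relative RMSE. A minimal sketch (the function name is made up, not part of the codebase):

```python
import numpy as np

def relative_rmse(approximation, reference):
    """Euclidean error, scaled by the magnitude of the reference solution."""
    approximation = np.asarray(approximation)
    reference = np.asarray(reference)
    abs_error = approximation - reference
    return np.linalg.norm(abs_error) / np.linalg.norm(reference)

# Example: an approximation that is off by 1% in every coordinate.
reference = np.ones(100)
approximation = 1.01 * reference
print(relative_rmse(approximation, reference))  # ~0.01
```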

Experiment 1 pre-final Todos

In my opinion, the following changes would make experiment 1 paper ready:

  • always JIT, but evaluate multiple orders; I suggest nu=2 ("no" init, i.e. closed-form init), nu=4 (RK init), and nu=8 (TM init)
  • push the KroneckerEK0 (and soon the DiagonalEK1) dimension really high
  • add a plot of Lorenz to the final figure

For now let's try to push every solver as far as possible -- this will hopefully change soon (as in that the solvers get more efficient).

Experiment 2 ToDo's

There are a few reasons why we think the plots yielded by experiment 2 do not look entirely as we expect:

  1. Use a fully JITted version of the run, in order to remove remaining noise in the run-time
  2. Track the number of evaluations of each solve
  3. The EK1 is at the moment faster than the EK0, because it is better optimized
  4. The diagonal EK1 builds the full Jacobian, which can (and will) be avoided
  5. We could neglect the cost of the initialization, which would probably reduce the overall runtime of each run
  6. We would probably see more interesting results for much higher dimensions than 10 (maybe use three different scales for the dimensions for the plots, e.g. 10, 50, 100)
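For point 2, one way to track evaluations is a counting wrapper around the right-hand side. This is only a sketch; the class name is invented, and one caveat applies: under jax.jit the Python-side counter only increments during tracing, so it has to be used on un-jitted runs.

```python
# Hypothetical wrapper for counting right-hand-side evaluations; the name
# CountingRHS is illustrative, not part of tornado's actual interface.
class CountingRHS:
    def __init__(self, f):
        self._f = f
        self.num_evaluations = 0

    def __call__(self, *args, **kwargs):
        # Note: inside jax.jit this only counts trace-time calls.
        self.num_evaluations += 1
        return self._f(*args, **kwargs)

f = CountingRHS(lambda t, y: -y)
for _ in range(5):
    f(0.0, 1.0)
print(f.num_evaluations)  # 5
```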

Is this the final repo name?

Just wanted to raise the question since it can still be changed, but in general I think I'm fine with this repo name.

The only alternative that comes to my mind is "ode-filter" instead of "ode-solver"?

But we can also keep it as it is, again I just wanted to make sure that we consciously decide on a public repo name.

multiple results folders

At the moment, the results are stored in results//results.{csv, pdf}. There are also a directory data/ and a directory figures/, which stem from the initial scaffolding and should be removed to make clear that there is a single place for outputs.

Slow jitting benchmarks

#14 introduces a benchmark-option that just-in-time-compiles the attempt_step functions.
This is for two reasons:

  1. It is interesting to see how much (wall-clock) run time one can gain by using a few additional tricks
  2. The run-time curves will look less noisy ("cleaner").

Since this makes the experiment really slow (compilation takes a while), it is turned off by default and should only be turned on when desired.

If we notice that we never turn it on, let's remove it again!
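As a sketch of why jitted timings are cleaner: the first call pays the compilation cost and must be excluded, and JAX's asynchronous dispatch means the clock should only stop after block_until_ready(). Minimal example with a generic function, not tornado's attempt_step:

```python
import timeit

import jax
import jax.numpy as jnp

def step(x):
    return jnp.sin(x) + x

step_jitted = jax.jit(step)
x = jnp.ones(1000)

# First call triggers compilation; run it once before timing.
step_jitted(x).block_until_ready()

# block_until_ready() forces the asynchronous dispatch to finish,
# otherwise we would only measure the dispatch cost, not the compute.
t = timeit.timeit(lambda: step_jitted(x).block_until_ready(), number=100) / 100
print(f"{t:.2e} s per call")
```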

Benchmark for stability loss of EK1-variants

Do we only lose UQ, or (more likely) stability? Likely depends on how the EK1 is made faster: diagonal Jacobians will screw us, full Jacobians with truncations might work. Check on the stiff Van der Pol problem? Maybe this would be the only section with a chi2 plot.
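For reference, the chi2 statistic mentioned here is typically the squared Mahalanobis distance of the truth under the Gaussian posterior; for a calibrated solver its expectation equals the state dimension. A small illustrative sketch, not tied to the solver API:

```python
import numpy as np

def chi2_statistic(mean, cov, truth):
    """Squared Mahalanobis distance of the truth under the posterior.

    For a calibrated Gaussian posterior this has expectation len(mean)."""
    residual = np.asarray(mean) - np.asarray(truth)
    return residual @ np.linalg.solve(cov, residual)

# Example: unit covariance and unit residuals give statistic == dimension.
d = 4
stat = chi2_statistic(np.ones(d), np.eye(d), np.zeros(d))
print(stat)  # 4.0
```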

Write benchmark for an artificial, sterilised problem

Goal: compare dimension of the problem (x-axis) to runtime (average per step?) on the y-axis. Fixed step. Perhaps different orders. Show that (a) the improvements over the vanilla implementation are significant; and (b) that the complexities from the propositions earlier are attained

From a research perspective, show RK45 and radau from scipy as well (where are we now?) Perhaps we need to show overall time, not time per step as well?
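A rough sketch of the scipy baseline part, using solve_ivp on a Lorenz-96 right-hand side; dimension, time span, and initial value below are placeholder choices:

```python
import timeit

import numpy as np
from scipy.integrate import solve_ivp

def lorenz96(t, y, forcing=8.0):
    """Standard Lorenz-96 right-hand side: (x_{i+1} - x_{i-2}) x_{i-1} - x_i + F."""
    return (np.roll(y, -1) - np.roll(y, 2)) * np.roll(y, 1) - y + forcing

d = 50
y0 = np.ones(d) + 1e-3 * np.arange(d)  # slightly perturbed equilibrium

for method in ["RK45", "Radau"]:
    t = timeit.timeit(
        lambda: solve_ivp(lorenz96, (0.0, 1.0), y0, method=method),
        number=1,
    )
    print(method, f"{t:.3f} s")
```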

Experiment runners

I have a feeling that perhaps there is a bit of incoming code duplication in the experiments.
If we take for example the experiment runner in 1_sterilised...

import timeit

import jax
import jax.numpy as jnp
import pandas as pd

import tornado
import problems  # local module providing the test problems


class SterilisedExperiment:
    def __init__(
        self,
        method,
        num_derivatives,
        ode_dimension,
        hyper_param_dict,
        jit,
        num_repetitions,
    ) -> None:
        self.hyper_param_dict = hyper_param_dict

        self.method = method
        self.num_derivatives = num_derivatives
        self.ode_dimension = ode_dimension

        self.ivp = problems.lorenz96_jax(
            params=(ode_dimension, hyper_param_dict["forcing"]),
            t0=hyper_param_dict["t0"],
            tmax=hyper_param_dict["tmax"],
            y0=jnp.arange(ode_dimension) * 1.0,
        )

        self.solver = tornado.ivpsolve._SOLVER_REGISTRY[method](
            ode_dimension=ode_dimension,
            steprule=tornado.step.ConstantSteps(hyper_param_dict["dt"]),
            num_derivatives=num_derivatives,
        )
        self.init_state = self.initial_state()

        self.result = dict()

        # Whether the step function is jitted before timing
        self.jit = jit

        # How often each experiment is run
        self.num_repetitions = num_repetitions

    def initial_state(self):
        m0, sc0 = tornado.init.stack_initial_state_jac(
            f=self.ivp.f, df=self.ivp.df, y0=self.ivp.y0, t0=self.ivp.t0, num_derivatives=self.num_derivatives
        )
        num_steps = self.num_derivatives + 1
        ts, ys = tornado.init.rk_data(f=self.ivp.f, t0=self.ivp.t0, dt=0.01, num_steps=num_steps, y0=self.ivp.y0, method="RK45")
        m, sc = tornado.init.rk_init_improve(
            m=m0,
            sc=sc0,
            t0=self.ivp.t0,
            ts=ts,
            ys=ys,
        )
        if isinstance(self.solver, (tornado.ek0.ReferenceEK0, tornado.ek1.ReferenceEK1)):
            y = tornado.rv.MultivariateNormal(mean=m.reshape((-1,), order="F"), cov_sqrtm=jnp.kron(jnp.eye(m.shape[1]), sc))
        elif isinstance(self.solver, tornado.ek0.KroneckerEK0):
            y = tornado.rv.MatrixNormal(mean=m, cov_sqrtm_1=jnp.eye(m.shape[1]), cov_sqrtm_2=sc)
        else:
            y = tornado.rv.BatchedMultivariateNormal(mean=m, cov_sqrtm=jnp.stack([sc]*m.shape[1]))

        return tornado.odesolver.ODEFilterState(
            ivp=self.ivp,
            y=y,
            t=self.ivp.t0,
            error_estimate=None,
            reference_state=None,
        )


    def time_initialize(self):
        def _run_initialize():
            self.solver.initialize(self.ivp)

        if self.jit:
            _run_initialize = jax.jit(_run_initialize)
        _run_initialize()

        elapsed_time = self.time_function(_run_initialize)
        self.result["time_initialize"] = elapsed_time
        return elapsed_time

    def time_attempt_step(self):
        def _run_attempt_step():
            """Manually do the repeated number of runs, because otherwise jax notices how the outputs are not reused anymore."""

            state = self.solver.attempt_step(
                state=self.init_state, dt=self.hyper_param_dict["dt"]
            )
            for _ in range(self.num_repetitions):
                state = self.solver.attempt_step(
                    state=state, dt=self.hyper_param_dict["dt"]
                )
            return state.y.mean

        if self.jit:
            _run_attempt_step = jax.jit(_run_attempt_step)
        _run_attempt_step()

        elapsed_time = self.time_function(_run_attempt_step)
        self.result["time_attempt_step"] = elapsed_time
        return elapsed_time

    def to_dataframe(self):
        def _aslist(arg):
            try:
                return list(arg)
            except TypeError:
                return [arg]

        results = {k: _aslist(v) for k, v in self.result.items()}
        return pd.DataFrame(
            dict(
                method=self.method,
                d=self.ode_dimension,
                nu=self.num_derivatives,
                jit=self.jit,
                **results,
            ),
        )

    @property
    def hyper_parameters(self):
        return pd.DataFrame(self.hyper_param_dict)

    def __repr__(self) -> str:
        s = f"{self.method} "
        s += "{\n"
        s += f"\td={self.ode_dimension}\n"
        s += f"\tnu={self.num_derivatives}\n"
        s += f"\tjit={self.jit}\n"
        s += f"\tresults={self.result}\n"
        return s + "}"

    def time_function(self, fun):
        # Average time, not minimum time, because we do not want to accidentally
        # run into some of JAX's lazy-execution optimisation.
        # Note: `fun` is assumed to perform `num_repetitions` repetitions
        # internally (as _run_attempt_step does), hence the division.
        avg_time = timeit.Timer(fun).timeit(number=1) / self.num_repetitions
        return avg_time

I think the bottom four methods (to_dataframe(), hyper_parameters, __repr__, and time_function()) will be the same for every experiment runner in every script. Maybe we save a bit of code and logic if we extract some of it.

Maybe we also only create unnecessary refactoring of temporary code (are we doing an oil change in a rental car??). I don't know. I think it would make things not only more efficient to work with, but also more fun when the scripts are nice and clean. But maybe this is a waste of time...
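To make the suggestion concrete, a hypothetical split could look like the following; the class names are invented and only to_dataframe() is shown:

```python
import pandas as pd

# Hypothetical base class extracting the shared methods; the split itself
# is only a suggestion, not how the repo is currently organised.
class ExperimentRunner:
    def __init__(self):
        self.result = dict()

    def to_dataframe(self):
        def _aslist(arg):
            try:
                return list(arg)
            except TypeError:
                return [arg]

        return pd.DataFrame({k: _aslist(v) for k, v in self.result.items()})

class MyExperiment(ExperimentRunner):
    """Toy subclass: only has to fill self.result."""
    def run(self):
        self.result["time_attempt_step"] = 0.01

exp = MyExperiment()
exp.run()
print(exp.to_dataframe())
```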

Requirements

Should we keep track of requirements? E.g. ProbNum versions, and all sorts of other dependencies (jax, tqdm, matplotlib, etc.)
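If we do, a minimal requirements.txt would already help; the package list below is a guess based on what the scripts import, without version pins:

```
# requirements.txt -- package list is a guess; add pins once versions settle
jax
jaxlib
numpy
pandas
matplotlib
tqdm
```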

CI

It might be useful to have minimal CI, at least with a black- and isort-check (to keep the diffs clean).
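A minimal sketch of such a workflow; the file name and action versions are placeholders:

```yaml
# .github/workflows/lint.yml -- minimal formatting check (sketch)
name: lint
on: [push, pull_request]
jobs:
  format:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
      - run: pip install black isort
      - run: black --check .
      - run: isort --check-only .
```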

How far can we push the dimension

How high a dimension can KroneckerEK0 and DiagonalEK1 handle?

On my laptop, I am able to push the KroneckerEK0 in the lorenz attempt step experiment to d=2048, nu=8 before the process gets killed when trying d=4096 (both for nu=3 and nu=8).

The step is still super fast, but I suppose I run out of RAM -- presumably in the Taylor-mode init?!

Can we do more somehow?
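A back-of-the-envelope check supports the RAM hypothesis: anything that materialises a dense covariance square-root of the full state (e.g. the jnp.kron in the reference initial state above) needs (d*(nu+1))**2 doubles:

```python
# Memory footprint of a dense covariance square-root for state dimension
# d*(nu+1), assuming 8-byte floats.
def dense_cov_bytes(d, nu, bytes_per_float=8):
    state_dim = d * (nu + 1)
    return state_dim**2 * bytes_per_float

for d in [2048, 4096]:
    gib = dense_cov_bytes(d, nu=8) / 2**30
    print(f"d={d}: {gib:.1f} GiB")
```

For d=4096 and nu=8 this is already above 10 GiB, which matches the process getting killed on an ordinary laptop.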
