pnkraemer / million-dimension-prob-ode-solver-experiments
License: MIT License
The different Lorenz versions do not coincide perfectly.
There is also a bad default initial value.
In my opinion, the following changes would make experiment 1 paper-ready:
For now, let's try to push every solver as far as possible -- this will hopefully change soon, in the sense that the solvers become more efficient.
Statement: ODE solves become faster.
This would be a classic work-precision diagram.
To begin with, answer this question with Lorenz.
Then perhaps move on to Pleiades?
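A work-precision experiment along these lines can be sketched as follows, using scipy's solvers as a stand-in baseline (the tornado solvers would slot into the same loop); the Lorenz-63 right-hand side, tolerances, and time span here are illustrative choices, not the repo's actual configuration:

```python
import timeit

import numpy as np
from scipy.integrate import solve_ivp


def lorenz(t, y, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Classic Lorenz-63 right-hand side."""
    return [
        sigma * (y[1] - y[0]),
        y[0] * (rho - y[2]) - y[1],
        y[0] * y[1] - beta * y[2],
    ]


y0, t_span = [1.0, 1.0, 1.0], (0.0, 5.0)
# Tight-tolerance reference solution to measure errors against.
reference = solve_ivp(lorenz, t_span, y0, rtol=1e-12, atol=1e-12).y[:, -1]

for tol in [1e-3, 1e-6, 1e-9]:
    sol = solve_ivp(lorenz, t_span, y0, rtol=tol, atol=tol)
    wall_time = timeit.timeit(
        lambda: solve_ivp(lorenz, t_span, y0, rtol=tol, atol=tol), number=1
    )
    error = np.linalg.norm(sol.y[:, -1] - reference)
    print(f"tol={tol:.0e}  time={wall_time:.3f}s  error={error:.2e}")
```

Plotting error (x) against wall time (y) over the tolerance sweep gives the classic work-precision diagram.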
There are a few reasons why we think the plots yielded by experiment 2 do not look entirely as we expect:
Just wanted to raise the question since it can still be changed, but in general I think I'm fine with this repo name.
The only alternative that comes to my mind is "ode-filter" instead of "ode-solver"?
But we can also keep it as it is, again I just wanted to make sure that we consciously decide on a public repo name.
At the moment, the results are stored in results//results.{csv, pdf}. There is also a directory data/ and a directory figures/, which stem from the initial scaffolding and should be removed to make clear that there is only one way of doing things.
#14 introduces a benchmark option that just-in-time compiles the attempt_step functions.
This is for two reasons:
Since this makes the experiment really slow (compilation takes a while), it is turned off by default and should only be turned on when desired.
If we notice that we never turn it on, let's remove it again!
Do we only lose UQ, or (more likely) stability? Likely depends on how the EK1 is made faster: diagonal Jacobians will screw us, full Jacobians with truncations might work. Check on stiff van der Pol? Maybe this would be the only section with a chi2 plot.
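For the chi2 plot, the statistic in question would presumably be something like the following whitened squared residual (names are illustrative; in tornado, the mean and covariance square root would come from the filter state rather than being passed in like this):

```python
import jax.numpy as jnp


def chi2_statistic(mean, cov_sqrtm, reference):
    """(ref - m)^T C^{-1} (ref - m) with C = L L^T, computed via whitening.

    For a well-calibrated d-dimensional Gaussian posterior, this should
    concentrate around d (chi-squared with d degrees of freedom).
    """
    whitened = jnp.linalg.solve(cov_sqrtm, reference - mean)
    return float(whitened @ whitened)


# With an identity covariance square root, this is just the squared norm:
print(chi2_statistic(jnp.zeros(3), jnp.eye(3), jnp.array([1.0, 2.0, 2.0])))  # 9.0
```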
Goal: compare the dimension of the problem (x-axis) to the runtime (average per step?) on the y-axis. Fixed steps. Perhaps different orders. Show that (a) the improvements over the vanilla implementation are significant; and (b) the complexities from the earlier propositions are attained.
From a research perspective, show RK45 and Radau from scipy as well (where are we now?). Perhaps we also need to show overall time, not only time per step?
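The scipy side of that comparison could look roughly like the sketch below (Lorenz96 right-hand side vectorised with np.roll; the perturbed-equilibrium initial value is a deliberate replacement for the bad arange default mentioned earlier; dimensions and time span are illustrative):

```python
import timeit

import numpy as np
from scipy.integrate import solve_ivp


def lorenz96(t, y, forcing=8.0):
    """Vectorised Lorenz96: dy_i = (y_{i+1} - y_{i-2}) * y_{i-1} - y_i + F."""
    return (np.roll(y, -1) - np.roll(y, 2)) * np.roll(y, 1) - y + forcing


for d in [64, 128, 256]:
    y0 = 8.0 * np.ones(d)
    y0[0] += 0.01  # small perturbation off the y = F equilibrium
    wall_time = timeit.timeit(
        lambda: solve_ivp(lorenz96, (0.0, 1.0), y0, method="RK45"), number=1
    )
    sol = solve_ivp(lorenz96, (0.0, 1.0), y0, method="RK45")
    print(f"d={d:4d}  total={wall_time:.3f}s  per_step={wall_time / sol.t.size:.2e}s")
```

Swapping in method="Radau" gives the stiff baseline, and printing both total and per-step time covers the "overall time vs. time per step" question.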
Full fledged simulation of a high-dimensional PDE (adaptive steps, initialisation, and all that jazz). Overall runtime, make it look impressive.
Statement: ODE solves become \emph{possible}.
I have a feeling that perhaps there is a bit of incoming code duplication in the experiments.
If we take for example the experiment runner in 1_sterilised...
import timeit

import jax
import jax.numpy as jnp
import pandas as pd

import tornado
import problems  # local module providing the test ODEs (import path assumed)


class SterilisedExperiment:
    def __init__(
        self,
        method,
        num_derivatives,
        ode_dimension,
        hyper_param_dict,
        jit,
        num_repetitions,
    ) -> None:
        self.hyper_param_dict = hyper_param_dict
        self.method = method
        self.num_derivatives = num_derivatives
        self.ode_dimension = ode_dimension
        self.ivp = problems.lorenz96_jax(
            params=(ode_dimension, hyper_param_dict["forcing"]),
            t0=hyper_param_dict["t0"],
            tmax=hyper_param_dict["tmax"],
            y0=jnp.arange(ode_dimension) * 1.0,
        )
        self.solver = tornado.ivpsolve._SOLVER_REGISTRY[method](
            ode_dimension=ode_dimension,
            steprule=tornado.step.ConstantSteps(hyper_param_dict["dt"]),
            num_derivatives=num_derivatives,
        )
        self.init_state = self.initial_state()
        self.result = dict()
        # Whether the step function is jitted before timing
        self.jit = jit
        # How often each experiment is run
        self.num_repetitions = num_repetitions

    def initial_state(self):
        m0, sc0 = tornado.init.stack_initial_state_jac(
            f=self.ivp.f,
            df=self.ivp.df,
            y0=self.ivp.y0,
            t0=self.ivp.t0,
            num_derivatives=self.num_derivatives,
        )
        num_steps = self.num_derivatives + 1
        ts, ys = tornado.init.rk_data(
            f=self.ivp.f,
            t0=self.ivp.t0,
            dt=0.01,
            num_steps=num_steps,
            y0=self.ivp.y0,
            method="RK45",
        )
        m, sc = tornado.init.rk_init_improve(
            m=m0,
            sc=sc0,
            t0=self.ivp.t0,
            ts=ts,
            ys=ys,
        )
        if isinstance(self.solver, (tornado.ek0.ReferenceEK0, tornado.ek1.ReferenceEK1)):
            y = tornado.rv.MultivariateNormal(
                mean=m.reshape((-1,), order="F"),
                cov_sqrtm=jnp.kron(jnp.eye(m.shape[1]), sc),
            )
        elif isinstance(self.solver, tornado.ek0.KroneckerEK0):
            y = tornado.rv.MatrixNormal(mean=m, cov_sqrtm_1=jnp.eye(m.shape[1]), cov_sqrtm_2=sc)
        else:
            y = tornado.rv.BatchedMultivariateNormal(mean=m, cov_sqrtm=jnp.stack([sc] * m.shape[1]))
        return tornado.odesolver.ODEFilterState(
            ivp=self.ivp,
            y=y,
            t=self.ivp.t0,
            error_estimate=None,
            reference_state=None,
        )

    def time_initialize(self):
        def _run_initialize():
            self.solver.initialize(self.ivp)

        if self.jit:
            _run_initialize = jax.jit(_run_initialize)
            _run_initialize()
        elapsed_time = self.time_function(_run_initialize)
        self.result["time_initialize"] = elapsed_time
        return elapsed_time

    def time_attempt_step(self):
        def _run_attempt_step():
            """Run the repetitions inside the function, because otherwise
            JAX notices that the outputs are not reused."""
            state = self.solver.attempt_step(
                state=self.init_state, dt=self.hyper_param_dict["dt"]
            )
            for _ in range(self.num_repetitions):
                state = self.solver.attempt_step(
                    state=state, dt=self.hyper_param_dict["dt"]
                )
            return state.y.mean

        if self.jit:
            _run_attempt_step = jax.jit(_run_attempt_step)
            _run_attempt_step()
        elapsed_time = self.time_function(_run_attempt_step)
        self.result["time_attempt_step"] = elapsed_time
        return elapsed_time

    def to_dataframe(self):
        def _aslist(arg):
            try:
                return list(arg)
            except TypeError:
                return [arg]

        results = {k: _aslist(v) for k, v in self.result.items()}
        return pd.DataFrame(
            dict(
                method=self.method,
                d=self.ode_dimension,
                nu=self.num_derivatives,
                jit=self.jit,
                **results,
            ),
        )

    @property
    def hyper_parameters(self):
        return pd.DataFrame(self.hyper_param_dict)

    def __repr__(self) -> str:
        s = f"{self.method} "
        s += "{\n"
        s += f"\td={self.ode_dimension}\n"
        s += f"\tnu={self.num_derivatives}\n"
        s += f"\tjit={self.jit}\n"
        s += f"\tresults={self.result}\n"
        return s + "}"

    def time_function(self, fun):
        # Average time, not minimum time, because we do not want to accidentally
        # run into some of JAX's lazy-execution optimisation.
        avg_time = timeit.Timer(fun).timeit(number=1) / self.num_repetitions
        return avg_time
I think the bottom four methods (to_dataframe(), hyper_parameters, __repr__, and time_function()) will be the same for every experiment runner in every script. Maybe we save a bit of code and logic if we extract some of it.
Maybe we would also just be refactoring temporary code unnecessarily (are we doing an oil change in a rental car??). I don't know. I think it would make things not only more efficient to work with, but also more fun when the scripts are nice and clean. But maybe this is a waste of time...
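If we do extract it, one possible shape is a small base class holding only the generic plumbing, which each script's runner then subclasses (class and attribute names here are illustrative, not existing code):

```python
import timeit

import pandas as pd


class ExperimentRunner:
    """Generic plumbing shared by all experiment runners."""

    def __init__(self, num_repetitions):
        self.num_repetitions = num_repetitions
        self.result = dict()

    def to_dataframe(self):
        def _aslist(arg):
            try:
                return list(arg)
            except TypeError:
                return [arg]

        return pd.DataFrame({k: _aslist(v) for k, v in self.result.items()})

    def time_function(self, fun):
        # Average, not minimum, to avoid JAX's lazy-execution optimisation.
        return timeit.Timer(fun).timeit(number=1) / self.num_repetitions
```

SterilisedExperiment would then inherit from this and keep only the initial-state construction and the experiment-specific timing targets.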
Should we keep track of requirements? E.g. ProbNum versions and all sorts of other dependencies (jax, tqdm, matplotlib, etc.)
It might be useful to have minimal CI, at least with a black- and isort-check (to keep the diffs clean).
How high a dimension can KroneckerEK0 and DiagonalEK1 handle?
On my laptop, I am able to push the KroneckerEK0 in the Lorenz attempt-step experiment to d=2048, nu=8; the process gets killed when trying d=4096 (both for nu=3 and nu=8).
The step is still super fast, but I suppose I run out of RAM -- presumably in the Taylor-mode init?!
Can we do more somehow?
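A quick back-of-envelope supports the RAM suspicion: if anything dense of size ((nu + 1) * d)^2 gets materialised during initialisation, the numbers grow fast. This is only an estimate of a dense float64 matrix, not a measurement of what the init actually allocates:

```python
# Size of a dense float64 ((nu + 1) * d) x ((nu + 1) * d) matrix, in GiB.
for d in [2048, 4096]:
    for nu in [3, 8]:
        n = (nu + 1) * d
        gib = n * n * 8 / 2**30
        print(f"d={d}, nu={nu}: dense {n}x{n} matrix ~ {gib:.2f} GiB")
```

At d=4096, nu=8 this is roughly 10 GiB for a single such matrix, which would plausibly kill the process on a laptop.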