drchainsaw / neuralode4j
Implementation of neural ordinary differential equations for java
License: MIT License
Lots of cool examples in https://github.com/rtqichen/ffjord/blob/master/lib/layers/diffeq_layers/basic.py
The only non-trivial part is hopefully how to serialize the wrapped Apache Commons solvers. It should be possible to remove support for them if #1 works out.
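Assuming the wrapped Apache Commons solvers cannot be serialized directly, one common pattern is to serialize only the parameters needed to rebuild the solver and recreate it on load. The sketch below uses plain Java serialization; `OdeSolver`, `SerializableSolverWrapper` and the `"euler-decay"` name are hypothetical stand-ins, not classes from this repo or from Apache Commons Math.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical stand-in for a wrapped third-party solver that is NOT Serializable.
interface OdeSolver {
    double step(double y, double h);
}

// Hypothetical wrapper: only the parameters needed to rebuild the solver are serialized.
class SerializableSolverWrapper implements Serializable {
    private static final long serialVersionUID = 1L;

    private final String solverName;     // identifies which solver to construct
    private final double stepSize;       // construction/usage parameter
    private transient OdeSolver solver;  // not serialized; rebuilt in readObject

    SerializableSolverWrapper(String solverName, double stepSize) {
        this.solverName = solverName;
        this.stepSize = stepSize;
        this.solver = create(solverName);
    }

    // Factory mapping the serialized name back to a concrete solver instance.
    private static OdeSolver create(String name) {
        switch (name) {
            case "euler-decay":
                // Explicit Euler step for dy/dt = -y
                return (y, h) -> y + h * (-y);
            default:
                throw new IllegalArgumentException("Unknown solver: " + name);
        }
    }

    // Invoked by Java serialization; restores the transient solver after the fields are read.
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.solver = create(solverName);
    }

    double advance(double y) {
        return solver.step(y, stepSize);
    }

    // Serializes and deserializes a wrapper, e.g. to check that a saved model can be restored.
    static SerializableSolverWrapper roundTrip(SerializableSolverWrapper original) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(original);
            }
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (SerializableSolverWrapper) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Serialization round trip failed", e);
        }
    }
}
```

Keeping the solver itself `transient` means the serialized form stays stable even if the underlying solver class changes, as long as the factory can still map the stored name to it.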
Memory consumption when using the ODE net in the MNIST example becomes quite high after a while, and it seems to scale with the number of steps taken by the solver.
Needs further investigation...
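One hypothesis consistent with this observation (not confirmed here): if every intermediate state of the solver is kept alive, e.g. for backpropagation through the individual solver steps, memory grows linearly with the number of steps. The toy sketch below uses a hypothetical `TrajectoryMemoryDemo` class and an explicit Euler solver for dy/dt = -y, and simply counts how many state arrays remain reachable in each mode.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration only; not code from this repo.
class TrajectoryMemoryDemo {
    // Explicit Euler for dy/dt = -y. If keepTrajectory is true, every intermediate state is
    // retained (as backprop through the solver steps would require); otherwise only the final one.
    static List<double[]> solve(double[] y0, double t0, double t1, int steps, boolean keepTrajectory) {
        double h = (t1 - t0) / steps;
        double[] y = y0.clone();
        List<double[]> retained = new ArrayList<>();
        if (keepTrajectory) {
            retained.add(y.clone());
        }
        for (int s = 0; s < steps; s++) {
            double[] next = new double[y.length];
            for (int i = 0; i < y.length; i++) {
                next[i] = y[i] + h * (-y[i]);
            }
            y = next;
            if (keepTrajectory) {
                retained.add(y.clone());
            }
        }
        if (!keepTrajectory) {
            retained.add(y);
        }
        return retained;
    }
}
```

If this turns out to be the cause, the adjoint method from the neural ODE paper, which recomputes states backwards in time instead of storing them, is the usual way to make memory independent of the step count.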
They currently end up in the root directory of the disk...
This seems to be what is used in the paper.
The performance of both the resnet reference model (which does not use any of the added stuff from this repo) and the odenet model is much worse than that of the same models in the reference repo.
After completing 200 epochs with the resnet model, the validation accuracy was still only about 92%.
For comparison, the same model from the reference repo reaches an accuracy of about 98% after the first epoch (vs about 88% in this repo). This holds true even after replacing layers which are not implemented in dl4j with ones which are (e.g. replacing group norm with batch norm).
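For reference, group norm (the layer mentioned above as missing from dl4j) normalizes each sample over groups of channels rather than over the batch. A minimal sketch for a single sample stored channel-major as `[channels][spatial]` in a flat array; `GroupNormSketch` is a hypothetical name and the learnable per-channel scale and shift are omitted.

```java
// Hypothetical sketch of group normalization for one sample; not a dl4j layer.
class GroupNormSketch {
    // x holds one sample laid out as [channels][spatial], flattened row-major.
    // Channels are split into numGroups contiguous groups, each normalized with its own mean/variance.
    static double[] groupNorm(double[] x, int channels, int spatial, int numGroups, double eps) {
        if (channels % numGroups != 0) {
            throw new IllegalArgumentException("channels must be divisible by numGroups");
        }
        if (x.length != channels * spatial) {
            throw new IllegalArgumentException("x.length must equal channels * spatial");
        }
        int groupSize = (channels / numGroups) * spatial;
        double[] out = new double[x.length];
        for (int g = 0; g < numGroups; g++) {
            int start = g * groupSize; // groups are contiguous because channels are the outer dimension
            int end = start + groupSize;
            double mean = 0.0;
            for (int i = start; i < end; i++) {
                mean += x[i];
            }
            mean /= groupSize;
            double var = 0.0;
            for (int i = start; i < end; i++) {
                double d = x[i] - mean;
                var += d * d;
            }
            var /= groupSize;
            double invStd = 1.0 / Math.sqrt(var + eps);
            for (int i = start; i < end; i++) {
                out[i] = (x[i] - mean) * invStd;
            }
        }
        return out;
    }
}
```

Because the statistics are per sample, group norm behaves the same regardless of batch size, which is one reason swapping it for batch norm can change training behaviour.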
ODENet spends a lot of time copying INDArrays to double[] and vice versa, as apache.commons.math operates on double[]. Writing solvers which use INDArrays directly should at least mitigate this issue. Hopefully the solvers themselves will be faster as well.
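A minimal sketch of the idea, using plain `double[]` as a stand-in for `INDArray` to keep it dependency-free: the derivative function writes into a caller-supplied buffer and a fixed-step RK4 solver reuses preallocated work arrays, so nothing is converted or allocated per step. `Derivative` and `InPlaceRk4` are hypothetical names, and a real solver would likely also need adaptive step size control.

```java
// Hypothetical derivative interface: writes dy/dt into 'out' instead of returning a new array.
interface Derivative {
    void eval(double t, double[] y, double[] out);
}

// Hypothetical fixed-step classic fourth-order Runge-Kutta solver that reuses its work buffers.
class InPlaceRk4 {
    private final double[] k1, k2, k3, k4, tmp;

    InPlaceRk4(int dim) {
        k1 = new double[dim];
        k2 = new double[dim];
        k3 = new double[dim];
        k4 = new double[dim];
        tmp = new double[dim];
    }

    // Integrates y in place from t0 to t1 using 'steps' equal steps.
    void integrate(Derivative f, double t0, double t1, int steps, double[] y) {
        if (y.length != k1.length) {
            throw new IllegalArgumentException("state dimension mismatch");
        }
        int n = y.length;
        double h = (t1 - t0) / steps;
        double t = t0;
        for (int s = 0; s < steps; s++) {
            f.eval(t, y, k1);
            for (int i = 0; i < n; i++) tmp[i] = y[i] + 0.5 * h * k1[i];
            f.eval(t + 0.5 * h, tmp, k2);
            for (int i = 0; i < n; i++) tmp[i] = y[i] + 0.5 * h * k2[i];
            f.eval(t + 0.5 * h, tmp, k3);
            for (int i = 0; i < n; i++) tmp[i] = y[i] + h * k3[i];
            f.eval(t + h, tmp, k4);
            for (int i = 0; i < n; i++) {
                y[i] += h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
            }
            t += h;
        }
    }
}
```

With ND4J the same structure would map the loops onto in-place array operations on the INDArray buffers, which is where the copying to and from double[] would disappear.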
This happens when evaluating after a training epoch has completed. It does not seem to happen if the number of training examples is reduced, which makes it pretty annoying to debug. The smallest setup in which the behaviour was reproduced was 80 iterations of 128 examples.
Furthermore, if sysouts are added to print variables suspected of being the source of the NaNs, the problem seems to go away, at least with the above number of examples. This holds for more or less any variable, so it did not work as a way to narrow down the error further.
Due to this, the current plan is to debug this with low effort during off hours and focus on #1.
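Since printing values seems to hide the problem, one option (an assumption, not something reported as tried) is a check that stays silent unless a NaN or infinity shows up and then fails fast with the variable name and index. Reading the values may still perturb whatever makes the bug disappear. `FiniteCheck` is a hypothetical helper operating on `double[]`.

```java
// Hypothetical helper for locating the first non-finite value without printing anything.
class FiniteCheck {
    // Returns the index of the first NaN or infinite entry, or -1 if all entries are finite.
    static int firstNonFinite(double[] values) {
        for (int i = 0; i < values.length; i++) {
            if (!Double.isFinite(values[i])) {
                return i;
            }
        }
        return -1;
    }

    // Throws with the variable name and offending index; silent when all values are finite.
    static void requireFinite(String name, double[] values) {
        int idx = firstNonFinite(values);
        if (idx >= 0) {
            throw new IllegalStateException(
                    name + " has non-finite value " + values[idx] + " at index " + idx);
        }
    }
}
```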
Tested:
Not tested:
The results from the spiral demo experiment seem to be of lower quality than what the reference implementation yields.
Simple experiments show that the whole experiment (both this implementation and the reference one) is very sensitive to hyperparameters (e.g. layer sizes and weight initialization), so perhaps some hyperparameter is not aligned.
Using L2 weight decay yields acceptable-looking spirals, but it does not seem to be used in the original implementation.
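L2 weight decay adds `l2 * w` to each weight's gradient, which is the gradient of an extra `(l2 / 2) * ||w||^2` term in the loss. A minimal plain SGD sketch; `SgdWithL2`, `lr` and `l2` are illustrative names, and the values in the test are arbitrary rather than settings from this repo or the reference implementation.

```java
// Hypothetical illustration of one SGD step with L2 weight decay.
class SgdWithL2 {
    // Updates weights in place: w <- w - lr * (grad + l2 * w).
    static void step(double[] weights, double[] grad, double lr, double l2) {
        if (weights.length != grad.length) {
            throw new IllegalArgumentException("weights and grad must have the same length");
        }
        for (int i = 0; i < weights.length; i++) {
            weights[i] -= lr * (grad[i] + l2 * weights[i]);
        }
    }
}
```

With a zero loss gradient each step shrinks the weights by a factor of `1 - lr * l2`, which is the regularizing pull toward smaller weights.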