neuralode4j's People

Contributors

drchainsaw

Forkers

mannykayy

neuralode4j's Issues

Make OdeVertex serializable

The only non-trivial part is hopefully how to serialize the wrapped Apache Commons solvers. It should be possible to remove support for them if #1 works out.
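A minimal sketch of one possible approach, assuming plain `java.io` serialization: persist only the settings needed to rebuild a wrapped solver and recreate the solver after deserialization, instead of serializing the solver object itself. `SolverConf`, `WrappedSolver`, `SerializableSolverHolder` and `RoundTrip` are hypothetical names, `WrappedSolver` is only a stand-in for a non-serializable Apache Commons solver, and this is not how OdeVertex or DL4J actually persist models.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical: the settings needed to rebuild a solver. Plain values, so serializable.
final class SolverConf implements Serializable {
    private static final long serialVersionUID = 1L;
    final double minStep;
    final double maxStep;
    final double absTol;
    final double relTol;

    SolverConf(double minStep, double maxStep, double absTol, double relTol) {
        this.minStep = minStep;
        this.maxStep = maxStep;
        this.absTol = absTol;
        this.relTol = relTol;
    }
}

// Hypothetical stand-in for a wrapped third-party solver that is NOT Serializable.
final class WrappedSolver {
    final SolverConf conf;

    WrappedSolver(SolverConf conf) {
        this.conf = conf;
    }
}

// Hypothetical holder: only the config is written; the solver field is transient
// and is rebuilt from the config when the object is read back.
final class SerializableSolverHolder implements Serializable {
    private static final long serialVersionUID = 1L;
    private final SolverConf conf;
    private transient WrappedSolver solver;

    SerializableSolverHolder(SolverConf conf) {
        this.conf = conf;
        this.solver = new WrappedSolver(conf);
    }

    WrappedSolver solver() {
        return solver;
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();                // restores conf
        this.solver = new WrappedSolver(conf); // recreates the non-serializable part
    }
}

// Serializes an object to bytes and reads it back, wrapping checked exceptions.
final class RoundTrip {
    @SuppressWarnings("unchecked")
    static <T extends Serializable> T roundTrip(T obj) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(obj);
            }
            try (ObjectInputStream in =
                     new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (T) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Round trip failed", e);
        }
    }
}
```

The design choice here is that the solver's settings, not its internal state, are what need to survive a save and load; anything the solver caches is simply rebuilt.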

High memory consumption with ODE net

Memory consumption when using the ODE net in the MNIST example becomes quite high after a while, and it seems to scale with the number of steps taken by the solver.

Needs further investigation...

Model performance is not same as reference implementation

The performance of both the resnet reference model (which does not use any of the additions from this repo) and the odenet model is much worse than that of the corresponding models in the reference repo.

After completing 200 epochs with the resnet model, the validation accuracy was still only about 92%.

As a comparison, the same model from the reference repo reaches an accuracy of about 98% after the first epoch (vs. about 88% in this repo). This holds true even after replacing layers which are not implemented in dl4j with ones which are (e.g. replacing group norm with batch norm).

Add ODE solvers with ND4J

ODENet spends a lot of time copying INDArrays to double[] and vice versa, as apache.commons.math operates on double[]. Writing solvers which use INDArrays will at least mitigate this issue. Hopefully the actual solvers will be faster as well.
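A minimal sketch of the idea, assuming a fixed-step classic fourth-order Runge-Kutta method (not necessarily the solvers this repo ends up with): the solver works directly on the array type the model uses and reuses preallocated buffers, so no per-step conversion is needed. Plain `double[]` stands in for `INDArray` here so the example runs without ND4J; with ND4J the inner loops would presumably become in-place array operations (e.g. `addi`) on preallocated `INDArray`s. `OdeFunction` and `Rk4Solver` are hypothetical names.

```java
// Hypothetical right-hand side f(t, y); writes dy/dt into a caller-provided buffer.
interface OdeFunction {
    void eval(double t, double[] y, double[] dydtOut);
}

// Classic fixed-step fourth-order Runge-Kutta. Scratch buffers are allocated once
// and reused, and the state is updated in place, so a step neither allocates nor
// converts between array types.
final class Rk4Solver {
    private final double[] k1, k2, k3, k4, tmp;

    Rk4Solver(int dim) {
        k1 = new double[dim];
        k2 = new double[dim];
        k3 = new double[dim];
        k4 = new double[dim];
        tmp = new double[dim];
    }

    /** Integrates y in place from t0 to t1 using nSteps equal steps. */
    void integrate(OdeFunction f, double t0, double t1, int nSteps, double[] y) {
        if (y.length != tmp.length) {
            throw new IllegalArgumentException("State size does not match solver dimension");
        }
        if (nSteps <= 0) {
            throw new IllegalArgumentException("nSteps must be positive");
        }
        final double h = (t1 - t0) / nSteps;
        final int n = y.length;
        for (int s = 0; s < nSteps; s++) {
            final double t = t0 + s * h;
            f.eval(t, y, k1);
            for (int i = 0; i < n; i++) tmp[i] = y[i] + 0.5 * h * k1[i];
            f.eval(t + 0.5 * h, tmp, k2);
            for (int i = 0; i < n; i++) tmp[i] = y[i] + 0.5 * h * k2[i];
            f.eval(t + 0.5 * h, tmp, k3);
            for (int i = 0; i < n; i++) tmp[i] = y[i] + h * k3[i];
            f.eval(t + h, tmp, k4);
            // Weighted combination of the four slope estimates.
            for (int i = 0; i < n; i++) {
                y[i] += h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
            }
        }
    }
}
```

For example, integrating dy/dt = -y from t = 0 to t = 1 in 10 steps, starting at y = 1, gives a value within 1e-6 of e^(-1).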

ODEnet produces NaN when evaluating

This happens when evaluating after having completed a training epoch. It does not seem to happen if the number of training examples is reduced, which makes it pretty annoying to debug. The smallest amount of training data with which the behaviour was reproduced was 80 iterations of 128 examples (10,240 examples).

Furthermore, if sysouts are added to print variables which are suspected of being the source of the NaNs, the problem seems to go away, at least with the above number of examples. This holds for more or less any variable, so it did not work as a method to narrow down the error further.

Due to this, the current plan is to debug this with low effort during off hours and focus on #1.

Tested:

  • Remove all special workspace handling in OdeVertex and just use the provided LayerWorkspaceMgr
  • Various permutations of scoping out of workspaces
  • Detaching activations, inputs and epsilons.
  • Running with the CPU backend for a small number of iterations to test whether a workspace error is caught there which is not caught with the CUDA backend (there was actually an issue with batchnorm putting state in the workspace for ArrayType.INPUT, but fixing it didn't solve this issue)
  • Running without async prefetch
  • Implementing #1 and hoping the problem just goes away
  • Running the whole model with workspace mode NONE for training and inference (don't know why I didn't think of this first)
  • Serializing the model before eval and seeing if the same issue persists with a deserialized graph. Probably requires #3
  • Running validation with training = true by just always giving "true" in constructor. This prevents the issue from surfacing and accuracy seems sensible. Drawback is that batchnorm is most likely not as effective as it should be. See post below about root cause.

Not tested:

  • Try another data set
  • ...

Spiral demo does not produce same quality results as original implementation

The results from the spiral demo experiment seem to be of lower quality than what the reference implementation yields.

Simple experiments show that the whole experiment (both this implementation and the reference one) is very sensitive to hyperparameters (e.g. layer sizes and weight initialization), so perhaps there is some hyperparameter which is not aligned.

Using L2 weight decay yields acceptable-looking spirals, but this does not seem to be used in the original implementation.
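For reference, a minimal sketch of what L2 weight decay does in a plain SGD update, assuming the penalty (lambda / 2) * ||w||^2 is added to the loss. `L2Sgd` is a hypothetical name; the spiral demo's actual optimizer and decay coefficient are not shown here.

```java
// Plain SGD with L2 weight decay, updating the weights in place.
// Minimizing loss + (lambda / 2) * ||w||^2 adds lambda * w to the gradient,
// which shrinks every weight towards zero on each step.
final class L2Sgd {
    static void step(double[] w, double[] grad, double learningRate, double lambda) {
        if (w.length != grad.length) {
            throw new IllegalArgumentException("Weights and gradient must have the same length");
        }
        for (int i = 0; i < w.length; i++) {
            w[i] -= learningRate * (grad[i] + lambda * w[i]);
        }
    }
}
```

With a zero gradient, a learning rate of 0.1 and lambda = 0.5, a weight of 1.0 becomes 0.95 after one step, which shows the penalty acting on its own to keep weights small.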
