Neural Network Potentials (molly.jl, open, 7 comments)

juliamolsim commented on May 28, 2024
Neural Network Potentials

from molly.jl.

Comments (7)

jgreener64 commented on May 28, 2024

Some docs and examples for the experimental differentiable branch are up. Specific interactions now work.

https://juliamolsim.github.io/Molly.jl/latest/differentiable.html


jgreener64 commented on May 28, 2024

> Would it please be possible to add an example of a neural network as the simulator?

I intend to put some examples of making custom potentials in the docs soon. Having a call to a neural network would work just fine.
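To illustrate the idea, a pair potential can be any differentiable function of the interatomic distance, including a small neural network. The sketch below is hand-rolled pure Julia with arbitrary made-up weights, not Molly's actual interaction API:

```julia
# A toy neural network pair potential: energy as a function of distance r.
# Weights are arbitrary illustrative values, not trained parameters.
W1 = [1.0, -2.0, 0.5]    # first layer weights (3 hidden units)
b1 = [0.0, 0.1, -0.1]    # first layer biases
W2 = [0.8 -0.4 1.2]      # second layer weights (1x3)
b2 = 0.0

nn_energy(r) = (W2 * tanh.(W1 .* r .+ b1))[1] + b2

# The force is the negative derivative of the energy; approximated here by
# a central difference (an AD tool such as Zygote would normally supply it).
nn_force(r; h=1e-6) = -(nn_energy(r + h) - nn_energy(r - h)) / (2h)

E = nn_energy(1.2)
F = nn_force(1.2)
```

In a custom potential, `nn_force(r)` would be evaluated in place of the analytic Lennard-Jones force for each interacting pair.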

However, training a neural network by autodifferentiating back through the simulation is more problematic: I was playing around with it last night and ran into problems because we currently mutate a lot of arrays.

Hopefully mutation support in Zygote will improve soon, and then it will be possible. I also tried ReverseDiff but ran into some typing problems.
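The mutation issue can be seen without Zygote itself: reverse-mode AD at the time could not trace in-place array writes like the first function below, while the functional rewrite computes the same value without mutation. Function names here are illustrative:

```julia
# In-place accumulation: fast, but Zygote's reverse mode (at the time)
# raised "Mutating arrays is not supported" on patterns like this.
function potential_mut(coords)
    acc = zeros(length(coords))
    for i in eachindex(coords)
        acc[i] = coords[i]^2    # array mutation breaks reverse-mode tracing
    end
    return sum(acc)
end

# Functional rewrite: same result, no mutation, AD-friendly.
potential_fun(coords) = sum(c -> c^2, coords)

coords = [0.5, 1.0, 1.5]
@assert potential_mut(coords) == potential_fun(coords)
```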

Long term it is a priority as it was one of the motivations of making this package.

> Also, how can we make a nice GIF?

Now that the link to the docs works, you can find that information there.


bionicles commented on May 28, 2024

They're working on a big chain-rule update for Zygote, which may indeed help. Thanks for fixing the docs.


jgreener64 commented on May 28, 2024

Yeah there is lots of exciting development going on.

I did manage to refactor the code to remove mutation and get autodiff working with Zygote, but there was a performance hit. I'll keep looking at that.


ChrisRackauckas commented on May 28, 2024

It should be possible to set up DiffEq for timestepping and then use the adjoint method for the differentiation.
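What the adjoint approach does can be sketched without DiffEq itself: the code below hand-rolls a discrete adjoint, sweeping backwards through an explicit-Euler integration of dx/dt = -p*x to recover dL/dp for L = x(T)^2, and checks it against the closed form. This is a stand-in for the general sensitivity machinery, not DiffEq's actual API:

```julia
# Discrete adjoint through explicit Euler for dx/dt = -p*x.
# Forward pass stores the trajectory; the reverse pass propagates the
# adjoint variable back through the timesteps to accumulate dL/dp.
function grad_adjoint(p; x0=1.0, dt=0.01, nsteps=100)
    xs = Vector{Float64}(undef, nsteps + 1)
    xs[1] = x0
    for k in 1:nsteps                # forward: x_{k+1} = x_k * (1 - dt*p)
        xs[k+1] = xs[k] * (1 - dt * p)
    end
    λ = 2 * xs[end]                  # dL/dx_N for L = x_N^2
    g = 0.0
    for k in nsteps:-1:1             # reverse sweep over the timesteps
        g += λ * (-dt * xs[k])       # ∂x_{k+1}/∂p = -dt * x_k
        λ *= (1 - dt * p)            # ∂x_{k+1}/∂x_k = 1 - dt*p
    end
    return g
end

# Check against the closed form: x_N = x0 * (1 - dt*p)^N.
p, dt, N = 0.5, 0.01, 100
xN = (1 - dt * p)^N
analytic = 2 * xN * N * (1 - dt * p)^(N - 1) * (-dt)
@assert abs(grad_adjoint(p) - analytic) < 1e-10
```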


jgreener64 commented on May 28, 2024

I have a prototype working on the differentiable branch. The following code repeatedly runs a simulation of 50 atoms for 500 steps and optimises the Lennard-Jones σ value to match a desired mean minimum separation of the atoms at the end of the simulation. It uses a neighbour list and periodic boundary conditions, and reaches a low loss within 25 epochs.

EDIT: this is out of date now, see link to docs below.

using Molly
using Zygote

function meanminseparation(final_coords, box_size)
    n_atoms = length(final_coords)
    sum_dists = 0.0
    for i in 1:n_atoms
        min_dist = 100.0
        for j in 1:n_atoms
            i == j && continue
            dist = sqrt(square_distance(i, j, final_coords, box_size))
            min_dist = min(dist, min_dist)
        end
        sum_dists += min_dist
    end
    return sum_dists / n_atoms
end

dist_true = 1.0
σtrue = dist_true / (2 ^ (1 / 6))

n_atoms = 50
mass = 10.0
box_size = 5.0
coords = [box_size .* rand(SVector{3}) for i in 1:n_atoms]
temperature = 0.1
velocities = [velocity(mass, temperature) .* 0.0 for i in 1:n_atoms]
general_inters = Dict{String, GeneralInteraction}("LJ" => LennardJones(true))
neighbour_finder = DistanceNeighbourFinder(trues(n_atoms, n_atoms), 20, 2.0)

function loss(σ)
    s = Simulation{typeof(coords)}(
        VelocityVerlet(),
        [Atom("", "", 0, "", 0.0, mass, σ, 0.2) for i in 1:n_atoms],
        Dict{String, Vector{SpecificInteraction}}(),
        general_inters,
        coords,
        velocities,
        temperature,
        box_size,
        Tuple{Int, Int}[],
        neighbour_finder,
        NoThermostat(),
        Logger[],
        0.05,
        10,
        [0]
    )
    mms_start = meanminseparation(coords, box_size)
    c = simulate!(s, 500, parallel=false)
    mms_end = meanminseparation(c, box_size)
    l = abs(mms_end - dist_true)
    println("σ                      ", round(σ, digits=3))
    println("mean min sep expected  ", round(σ * (2 ^ (1 / 6)), digits=3))
    println("mean min sep start     ", round(mms_start, digits=3))
    println("mean min sep end       ", round(mms_end, digits=3))
    println("loss                   ", round(l, digits=3))
    return l
end

grad = gradient(loss, σtrue)[1]

# Simple training loop
function train()
    σlearn = 0.8 / (2 ^ (1 / 6))
    for epoch_n in 1:25
        println("Epoch ", epoch_n)
        grad = gradient(loss, σlearn)[1]
        σlearn -= grad * 1e-2
        println()
    end
    return σlearn
end

σlearn = train()


jgreener64 commented on May 28, 2024

The integration steps were okay to implement; the harder part was changing all the devectorised code in the force calculation that made it fast before. I have found there is a performance/memory hit in making it Zygote-friendly.

The next step is to work out how to code specific interactions in this scheme, e.g. a covalent bond between two specific atoms. I was hoping I could use a sparse matrix constructor to turn the results of a broadcast over bonds into a standard dense vector, but I haven't got that working with Zygote yet.
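The scatter problem can also be avoided without a sparse constructor: build the dense per-atom force vector as a reduction over bonds rather than mutating an accumulator. A pure-Julia sketch with made-up bond data (names are illustrative):

```julia
# Scatter per-bond scalar forces onto per-atom totals without mutating
# an accumulator array: each atom's total is a reduction over bonds.
bonds = [(1, 2), (2, 3)]      # (i, j) atom index pairs for each bond
bond_forces = [0.5, -1.25]    # force on atom i from each bond;
                              # atom j feels the opposite sign

n_atoms = 3
atom_forces = [sum(f * ((i == a) - (j == a))
                   for ((i, j), f) in zip(bonds, bond_forces))
               for a in 1:n_atoms]
# atom 1: +0.5, atom 2: -0.5 - 1.25 = -1.75, atom 3: +1.25
```

The `(i == a) - (j == a)` term plays the role of the sparse scatter matrix entry, so reverse-mode AD only ever sees reductions, never in-place writes.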

Also I'd like to look at using forward-mode autodiff as the memory requirements don't scale so badly with the length of the simulation.
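Forward mode carries the derivative alongside the value, so only the current state needs to be held in memory regardless of how many timesteps are taken. A minimal hand-rolled dual-number sketch of this (not ForwardDiff's actual types):

```julia
# Minimal dual number: a value plus its derivative w.r.t. one parameter.
struct Dual
    v::Float64   # value
    d::Float64   # derivative
end
Base.:*(a::Dual, b::Dual) = Dual(a.v * b.v, a.v * b.d + a.d * b.v)
Base.:*(a::Real, b::Dual) = Dual(a * b.v, a * b.d)
Base.:-(a::Real, b::Dual) = Dual(a - b.v, -b.d)

# Differentiate x(T) w.r.t. p through N Euler steps of dx/dt = -p*x,
# keeping only the current state in memory (no stored trajectory).
function forward_sensitivity(pval; dt=0.01, nsteps=100)
    p = Dual(pval, 1.0)          # seed: dp/dp = 1
    x = Dual(1.0, 0.0)           # x0 = 1, dx0/dp = 0
    for _ in 1:nsteps
        x = (1 - dt * p) * x     # x_{k+1} = x_k * (1 - dt*p)
    end
    return x.v, x.d              # x(T) and dx(T)/dp
end

xT, dxdp = forward_sensitivity(0.5)
```

The trade-off is that forward mode costs one pass per parameter, so it suits few parameters and long simulations, while reverse mode suits many parameters.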

