
Functional AD (flux.jl, closed)

fluxml commented on August 17, 2024
Functional AD

from flux.jl.

Comments (7)

MikeInnes commented on August 17, 2024

I guess that would have to be grad(loss, (W, b, x), x, y), although that really makes it clear that loss should just be a zero-argument function, as in grad(() -> loss(x, y), [W, b, x]).
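A minimal Julia sketch of the zero-argument style. The `grad` here is hypothetical (not Flux's API) and uses central finite differences as a stand-in for real AD, so only the calling convention carries over: the closure captures W, b and x, `grad` nudges each captured array in place, and it returns one gradient array per entry of the list. The `mse` helper and the numbers are made up for illustration.

```julia
# Hypothetical stand-in for the proposed `grad` (not Flux's API).
# Differentiates a zero-argument function `f` with respect to the arrays it
# captures, using central finite differences instead of real AD.
function grad(f, params::AbstractVector; h = 1e-6)
    map(params) do p
        g = zero(p)
        for i in eachindex(p)
            old = p[i]
            p[i] = old + h; fplus = f()    # perturb the captured array in place
            p[i] = old - h; fminus = f()
            p[i] = old                     # restore the original value
            g[i] = (fplus - fminus) / (2h)
        end
        g
    end
end

mse(yhat, y) = sum(abs2, yhat .- y) / length(y)

W = [1.0 2.0; 3.0 4.0]
b = [0.5, -0.5]
x = [1.0, 1.0]
y = [0.0, 0.0]
loss(x, y) = mse(W * x .+ b, y)   # W and b are implicit (captured) arguments

# Gradients w.r.t. the parameters *and* the input, in the order listed.
dW, db, dx = grad(() -> loss(x, y), [W, b, x])
```

For this quadratic loss the central differences agree with the exact gradients up to rounding, e.g. db ≈ [3.5, 6.5] and dx ≈ [23.0, 33.0].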

dfdx commented on August 17, 2024

> W and b are treated as implicit arguments to the function; this is nice in that it's essentially the ideal functional interface but without the mess of hundreds of explicit arguments.

Does this mean that a user will need to make W and b global and loss to be defined in the same context? This doesn't sound very flexible, to be honest. Recently I've been playing around (e.g. in VariationalAE.jl) with models as mutable structs. For your case it would look something like:

m = Linear(W, b)
loss(m::Linear, x, y) = mse(m.W * x .+ m.b, y)
dm = grad(loss, m, x, y)

where dm is another instance of Linear holding derivatives, i.e.:

dW = dm.W
db = dm.b
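The struct-based alternative can be sketched the same way. Again `grad` is a hypothetical finite-difference stand-in, not a real library function; what matters is the return type: it differentiates `loss(m, x, y)` with respect to every array field of `m` and packs the results into a new `Linear`, so `dm.W` and `dm.b` line up with `m.W` and `m.b`. The field types and values are assumptions for illustration.

```julia
struct Linear
    W::Matrix{Float64}
    b::Vector{Float64}
end

mse(yhat, y) = sum(abs2, yhat .- y) / length(y)
loss(m::Linear, x, y) = mse(m.W * x .+ m.b, y)

# Hypothetical struct-returning `grad` (not Flux's API): central finite
# differences over each array field of `m`, repackaged as a new `T`.
function grad(f, m::T, args...; h = 1e-6) where {T}
    grads = map(fieldnames(T)) do name
        p = getfield(m, name)              # the field's array, mutated in place below
        g = zero(p)
        for i in eachindex(p)
            old = p[i]
            p[i] = old + h; fplus = f(m, args...)
            p[i] = old - h; fminus = f(m, args...)
            p[i] = old                     # restore the original value
            g[i] = (fplus - fminus) / (2h)
        end
        g
    end
    T(grads...)                            # same type as the model, holding derivatives
end

m = Linear([1.0 2.0; 3.0 4.0], [0.5, -0.5])
x, y = [1.0, 1.0], [0.0, 0.0]
dm = grad(loss, m, x, y)
dW = dm.W
db = dm.b
```

`T(grads...)` assumes every field is a numeric array; a model with non-array fields would need more machinery.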

MikeInnes commented on August 17, 2024

> Does this mean that a user will need to make W and b global and loss to be defined in the same context?

Er, no? For the most part I'm not expecting anything else to look different, so the MNIST example would stay exactly the same. It's really no different to the current TrackedArray approach in that sense, just without the hacky overloading.

Your API is something we discussed, as it's closer to what Knet currently has. At a minimum, it only scales up well if you let the structure define the forward pass (e.g. via call overloading). Even then it imposes a bigger burden on user-defined types and small models, and it's harder to figure out how it plays out when you get to really complex models (as one example, higher-order models that take another model as input).
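Call overloading in Julia lets a struct define its own forward pass. A minimal sketch, where `Linear` and the container `Seq` are hypothetical stand-ins (not Flux's layer types): once every layer is callable, a container only has to call its layers in order, and a loss can be written once for any callable model.

```julia
struct Linear
    W::Matrix{Float64}
    b::Vector{Float64}
end

# Call overloading: an instance of `Linear` can be applied like a function.
(m::Linear)(x) = m.W * x .+ m.b

# Hypothetical container: applies its layers left to right.
struct Seq{T<:Tuple}
    layers::T
end
(s::Seq)(x) = foldl((h, layer) -> layer(h), s.layers; init = x)

mse(yhat, y) = sum(abs2, yhat .- y) / length(y)
loss(model, x, y) = mse(model(x), y)   # works for any callable model

l1 = Linear([1.0 0.0; 0.0 1.0], [1.0, 1.0])
l2 = Linear([2.0 0.0; 0.0 2.0], [0.0, 0.0])
model = Seq((l1, l2))
model([1.0, 2.0])
```

Here `model([1.0, 2.0])` first shifts the input to [2.0, 3.0] and then doubles it to [4.0, 6.0]; the same `loss` accepts `l1` on its own or the whole `Seq`.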

dfdx commented on August 17, 2024

Ah, I re-read the MNIST example. Do I understand correctly that loss() is actually a closure bound to an instance of object m::Chain? In this case my comment is indeed irrelevant.

MikeInnes commented on August 17, 2024

Essentially, yes. It's not actually a closure in this case because m is global, but it could be. The docs have some examples of closing over parameters, and I expect those to work with Cassette as well.
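A minimal sketch of the two forms, using a hypothetical callable `Linear` in place of the MNIST example's `m::Chain`: both loss functions read the same model, one by global name and one through a closure. `make_loss` is an illustrative helper, not part of Flux.

```julia
struct Linear
    W::Matrix{Float64}
    b::Vector{Float64}
end
(m::Linear)(x) = m.W * x .+ m.b

mse(yhat, y) = sum(abs2, yhat .- y) / length(y)

m = Linear([1.0 2.0; 3.0 4.0], [0.5, -0.5])

# Global style: `loss` looks up the global `m` each time it runs.
loss(x, y) = mse(m(x), y)

# Closure style: `make_loss` returns a function that has captured `model`.
make_loss(model) = (x, y) -> mse(model(x), y)
closed_loss = make_loss(m)

x, y = [1.0, 1.0], [0.0, 0.0]
loss(x, y), closed_loss(x, y)
```

Both calls evaluate to 27.25 here. Because the closure captures the same `m` object, in-place updates to `m.W` or `m.b` are visible to both functions.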

baggepinnen commented on August 17, 2024

How would the user go about taking the gradient of a model output with respect to a non-parameter like the input? This is common when creating adversarial examples, linearizing dynamical models, etc.
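For a linear model with mean-squared error, the input gradient can be written out by hand, which shows what an AD system would return when the input is among the variables being differentiated (as in `[W, b, x]` earlier in the thread). A minimal sketch: `input_grad` and `adversarial` are hypothetical helper names, and the update is the fast gradient sign method, one common way of building adversarial examples.

```julia
mse(yhat, y) = sum(abs2, yhat .- y) / length(y)

W = [1.0 2.0; 3.0 4.0]
b = [0.5, -0.5]
loss(x, y) = mse(W * x .+ b, y)

# Hand-derived gradient of the loss w.r.t. the *input* x (not the parameters):
# d/dx mean((W*x .+ b .- y).^2) = W' * (2 .* (W*x .+ b .- y) ./ n)
input_grad(x, y) = W' * (2 .* (W * x .+ b .- y) ./ length(y))

# Fast gradient sign step: move x by epsilon in the direction that increases the loss.
adversarial(x, y; epsilon = 0.1) = x .+ epsilon .* sign.(input_grad(x, y))

x, y = [1.0, 1.0], [0.0, 0.0]
x_adv = adversarial(x, y)
loss(x_adv, y) > loss(x, y)
```

With these numbers `input_grad(x, y)` is [23.0, 33.0], so the step moves x to [1.1, 1.1] and the loss increases.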

MikeInnes commented on August 17, 2024

A year on we can do some much cooler things here. Closing in favour of #628.
