fluxml / daggerflux.jl


Distributed computation of differentiation pipelines across multiple workers, devices, GPUs, etc., since Julia wasn't fast enough already

License: Other

Julia 100.00%
julia machine-learning fluxml flux deeplearning

daggerflux.jl's Introduction

DaggerFlux.jl

This is currently an early-stage integration between Dagger.jl and Flux.jl that distributes the computation of differentiation pipelines across multiple workers, devices, GPUs, etc. This package enables model parallelism for Flux models.

Basic Usage

To see the package in action, start Julia with multiple workers.

Also make sure that the workers have access to the environment and the code that is going to be run. This is typically done with the exeflags keyword of addprocs; something like addprocs(2, exeflags = "--project") is usually enough. Please ensure that the environment has DaggerFlux installed.
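The setup described above can be sketched as follows (the worker count of 2 is arbitrary, and this assumes the active project environment already has DaggerFlux and its dependencies installed):

```julia
using Distributed

# Spawn two worker processes that inherit the active project environment,
# so they can load the same packages as the master process.
addprocs(2, exeflags = "--project")

# Load the packages on the master and on every worker.
@everywhere using DaggerFlux, Dagger, Flux, Zygote
```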

julia> using DaggerFlux, Dagger, Flux, Zygote

julia> @everywhere function layer(x)
         @show myid()
         x
       end

julia> ip = rand(3,3);

julia> c = Chain(layer, layer, layer, layer)
Chain(layer, layer, layer, layer)

julia> dc = DaggerChain(c)
DaggerChain(Chain(layer, layer, layer, layer))

julia> dc(ip) # notice the output is a Dagger Thunk rather than an eager evaluation
Thunk[4](layer, (Thunk[3](layer, ...),))

julia> collect(dc(ip))
      From worker 2:    myid() = 2
      From worker 3:    myid() = 3
      From worker 2:    myid() = 2
      From worker 3:    myid() = 3
3×3 Matrix{Float64}:
 0.813575   0.828228  0.0630336
 0.0755053  0.215495  0.64503
 0.462957   0.345485  0.83312

Notice that the model was evaluated across multiple workers.

Flux models

This works the same way as before; here we demonstrate how to differentiate through Flux models.

julia> y, back = Zygote.pullback((m,x) -> m(x), dc, ip)
(Thunk[135](layer, (Thunk[131](layer, ...),)), Zygote.var"#46#47"{typeof((#11))}(∂(#11)))

julia> collect(y)
      From worker 3:    myid() = 3
      From worker 3:    myid() = 3
      From worker 2:    myid() = 2
      From worker 2:    myid() = 2
3×3 Matrix{Float64}:
 0.813575   0.828228  0.0630336
 0.0755053  0.215495  0.64503
 0.462957   0.345485  0.83312

julia> back(one.(y))
      From worker 2:    myid() = 2
      From worker 2:    myid() = 2
      From worker 3:    myid() = 3
      [...]
      From worker 2:    myid() = 2
      From worker 3:    myid() = 3
      From worker 2:    myid() = 2
((chain = (layers = (nothing, nothing, nothing, nothing),),), [1.0 1.0 1.0; 1.0 1.0 1.0; 1.0 1.0 1.0])

And now one can optimise over entire models!
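A full training step might look like the following sketch. It uses Flux's implicit-parameter API (Flux.params and Flux.Optimise.update!); whether params on the wrapped chain interacts correctly with a DaggerChain is an assumption here, and the loss and optimiser are placeholders:

```julia
using DaggerFlux, Dagger, Flux, Zygote

m   = Chain(Dense(2, 2), Dense(2, 2))
dm  = DaggerChain(m)
opt = Descent(0.1)

x, y = rand(Float32, 2, 8), rand(Float32, 2, 8)

# Parameters of the underlying chain (assumed to be shared with dm).
ps = Flux.params(m)

gs = Zygote.gradient(ps) do
    # collect forces the lazy Thunk so the loss is a plain scalar.
    sum(abs2, collect(dm(x)) .- y)
end

Flux.Optimise.update!(opt, ps, gs)
```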

Of course, one can replace the dummy model above with more realistic models such as ResNet from Metalhead.jl. Here's a slightly simpler model as an example.

julia> m = Chain(Dense(2,2), Dense(2,2))
Chain(
  Dense(2, 2),                          # 6 parameters
  Dense(2, 2),                          # 6 parameters
)                   # Total: 4 arrays, 12 parameters, 304 bytes.

julia> dm = DaggerChain(m)
DaggerChain(Chain(Dense(2, 2), Dense(2, 2)))

julia> y, b = Zygote.pullback((m,x) -> m(x), dm, rand(Float32, 2, 2))
(Thunk[150](Dense(2, 2), (Thunk[149](Dense(2, 2), ...),)), Zygote.var"#46#47"{typeof((#13))}(∂(#13)))

julia> b(one.(y))
((chain = (layers = ((weight = Float32[1.0398567 0.45392603; 0.4867683 0.21248773], bias = Float32[1.6065784, 0.75205684], σ = nothing), (weight = Float32[-1.247205 1.2783735; -1.247205 1.2783735], bias = Float32[2.0, 2.0], σ = nothing)),),), Float32[-0.14533046 -0.14533046; -0.58934844 -0.58934844])

Contributions are welcome on the GitHub repository!

daggerflux.jl's People

Contributors

dhairyalgandhi, github-actions[bot], jpsamaroo, skyleaworlder


daggerflux.jl's Issues

The compatibility needs to be updated

I get this error when I try to install the package:

ERROR: Unsatisfiable requirements detected for package Dagger [d58978e5]:
 Dagger [d58978e5] log:
 ├─possible versions are: 0.6.2-0.16.1 or uninstalled
 ├─restricted to versions 0.14.4-0.14 by DaggerFlux [30118311], leaving only versions 0.14.4
 │ └─DaggerFlux [30118311] log:
 │   ├─possible versions are: 0.1.0 or uninstalled
 │   └─DaggerFlux [30118311] is fixed to version 0.1.0
 └─restricted by compatibility requirements with DaggerGPU [68e73e28] to versions: [0.8.0, 0.10.0-0.10.2, 0.13.3-0.13.7] — no versions left
   └─DaggerGPU [68e73e28] log:
     ├─possible versions are: 0.1.0-0.1.3 or uninstalled
     └─restricted to versions 0.1 by DaggerFlux [30118311], leaving only versions 0.1.0-0.1.3
       └─DaggerFlux [30118311] log: see above
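Reading the resolver log, DaggerFlux pins Dagger to 0.14.4 while DaggerGPU 0.1.x requires an incompatible Dagger range, so no version satisfies both. One likely fix (untested; the exact bounds below are assumptions, not verified against the real registries) is to widen the [compat] entries in DaggerFlux's Project.toml so a mutually satisfiable set exists:

```toml
[compat]
# Hypothetical widened bounds; the real compatible ranges must be
# confirmed against Dagger's and DaggerGPU's changelogs.
Dagger = "0.14, 0.15, 0.16"
DaggerGPU = "0.1, 0.2"
```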

Don't assume the model is on a CUDA device

Currently, DaggerChain communicates to Dagger that the wrapped model is located on a CUDA GPU, which is not necessarily true (and shouldn't be a requirement anyway). We should provide functions which can move the model to the GPU and communicate the correct location to Dagger, and/or auto-detect where a model currently resides.
