Conflux

Latest Release · MIT license · Documentation

Conflux.jl is a toolkit designed to enable data parallelism for Flux.jl models by simplifying the process of replicating them across multiple GPUs on a single node, and by leveraging NCCL.jl for efficient inter-GPU communication. This package aims to provide a straightforward and intuitive interface for multi-GPU training, requiring minimal changes to existing code and training loops.

Features

  • Easy replication of objects across multiple GPUs with the replicate function.
  • Efficient synchronization of models and averaging of gradients with the allreduce! function, which takes an operator (e.g. +, *, avg) and a set of replicas, reduces all of their parameters with that operator, and leaves the replicas identical (see the sketch below).
  • A withdevices function for running code on each device asynchronously.
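
A minimal sketch of how these pieces fit together, assuming two visible GPUs (the full training example below shows them inside a training loop):

using Conflux, Flux

model = Dense(4 => 4)
replicas = replicate(model)   # one copy of the model per visible device

# Run a closure on each device; `i` indexes the replica and `device` moves data to that GPU
Conflux.withdevices() do (i, device)
    @show i sum(replicas[i].weight)
end

# Average the parameters in place (a no-op here, since fresh replicas are identical,
# but this is how gradients or drifted parameters are merged)
allreduce!(avg, replicas...)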

See the documentation for more details, examples, and important caveats.

Installation

The package can be installed with the Julia package manager. From the Julia REPL, type ] to enter the Pkg REPL mode and run:

pkg> add https://github.com/MurrellGroup/Conflux.jl#main

Example usage

# Specify the default devices to use
ENV["CUDA_VISIBLE_DEVICES"] = "0,1"

using Conflux

using Flux, Optimisers

model = Chain(Dense(1 => 256, tanh), Dense(256 => 512, tanh), Dense(512 => 256, tanh), Dense(256 => 1))

# This will use the available devices. If you want to use specific devices, you can pass them as a second argument.
models = replicate(model)

opt = Optimisers.Adam(0.001f0)

# Instantiate the optimiser states on each device
states = Conflux.withdevices() do (i, device)
    Optimisers.setup(opt, model) |> device
end

# A single batch, stored on the CPU. A more sophisticated mechanism could be used to
# distribute different batches to each device (see the sketch after this example).
X = rand(Float32, 1, 16)
Y = X .^ 2

loss(y, Y) = sum(abs2, y .- Y)

losses = []
for epoch in 1:10
    # Get the gradients for each batch on each device
    ∇models = Conflux.withdevices() do (i, device)
        x, y = device(X), device(Y)
        # The second return value is a tuple because `Flux.withgradient` takes `args...`, and the model is the first argument.
        l, (∇model,) = Flux.withgradient(m -> loss(m(x), y), models[i])
        push!(losses, l)
        ∇model
    end

    # Average the gradients across devices
    allreduce!(avg, ∇models...)

    # Update the models on each device
    Conflux.withdevices() do (i, device)
        Optimisers.update!(states[i], models[i], ∇models[i])
    end

    # Optionally synchronize the models and optimiser states, in case the parameters diverge
    #allreduce!(avg, models...)
    #allreduce!(avg, states...)
end
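
As the comment inside the loop notes, this example feeds the same batch to every device. One way to shard a batch across replicas, sketched here by slicing the arrays by hand and reusing models, X, Y, and loss from the example above (and assuming the collection returned by replicate supports length), would be:

# Split the batch along the batch dimension, one shard per replica
nshards = length(models)
shards = [(X[:, i:nshards:end], Y[:, i:nshards:end]) for i in 1:nshards]

∇models = Conflux.withdevices() do (i, device)
    xs, ys = shards[i]
    x, y = device(xs), device(ys)  # move this replica's shard to its device
    _, (∇model,) = Flux.withgradient(m -> loss(m(x), y), models[i])
    ∇model
end

# Average the per-shard gradients, exactly as in the loop above
allreduce!(avg, ∇models...)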
