cyang49 / float8_experimental

This project is a fork of pytorch-labs/float8_experimental.

This repository contains the experimental PyTorch native float8 training UX

License: BSD 3-Clause "New" or "Revised" License

float8_experimental

This is an early version of a library for accelerating training with float8 in native PyTorch, following the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf. The codebase strives to stay small, easily hackable, and debuggable with native PyTorch tooling. torch.compile is supported out of the box. With torch.compile enabled, initial results show throughput speedups of up to 1.2x on small-scale (8 GPU) LLaMa pretraining jobs.

⚠️ See the feature tracker for upcoming features. Key features such as weight cast recomputation in the backward pass and large-scale distributed support are not ready yet.

⚠️ Backwards compatibility is not guaranteed at this point. The codebase is in active development and will change rapidly.

Installation

⚠️ For now, use the latest PyTorch nightly for best results with torch.compile.
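
For example, one way to install a recent nightly with pip (the CUDA 12.1 index URL below is an assumption; pick the index that matches your setup from pytorch.org):

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121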

pip install .

# Optionally, install in editable mode
pip install -e .

# Optionally, install dev tooling
pip install -e ".[dev]"

User API

We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
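
For intuition: with dynamic scaling, the scale for a tensor is computed from that tensor's current maximum absolute value (amax) on every iteration, while with delayed scaling the scale is derived from a history of amaxes observed in previous iterations. Below is a minimal sketch of the underlying math, assuming the e4m3 format for the forward pass; the helper names are illustrative, not the library's API:

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # largest representable e4m3 value

def amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
    # map the observed value range onto the representable float8 range
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

def quantize_dynamic(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # dynamic scaling: amax comes from the tensor being quantized right now
    scale = amax_to_scale(x.abs().max())
    return (x * scale).to(torch.float8_e4m3fn), scale

x = torch.randn(16, 16)
x_fp8, scale = quantize_dynamic(x)
x_approx = x_fp8.to(torch.float32) / scale  # dequantize to inspect rounding error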

float8 linear with dynamic scaling

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
)
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# create model
m = Model(...)

# convert all `torch.nn.Linear` modules to `Float8DynamicLinear`
swap_linear_with_float8_linear(m, Float8DynamicLinear)

# optional: wrap with FSDP (requires an initialized distributed process group)
m = FSDP(m, use_orig_params=True)

# optional: enable torch.compile for improved performance
m = torch.compile(m)

# train/finetune (not shown)

float8 linear with delayed scaling

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_linear import Float8Linear

# create model
m = Model(...)

# convert all `torch.nn.Linear` modules to `Float8Linear`
swap_linear_with_float8_linear(m, Float8Linear)

# optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
# config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
from float8_experimental import config
config.enable_amax_init = False  # only needed for autocast + compile + FSDP + float8 delayed
config.enable_pre_and_post_forward = False  # only needed for autocast + compile + FSDP + float8 delayed
m = FSDP(m, use_orig_params=True)

# optional: enable torch.compile for improved performance
m = torch.compile(m)

# toy training loop
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()

    # specific to float8 with delayed scaling: separate step to sync scales/amaxes
    # in the future, this may move to a context manager
    sync_float8_amax_and_scale_history(m)

    optimizer.step()
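
With delayed scaling, the scales used in a given iteration are derived from amax values recorded in earlier iterations, which is why the separate sync step above has to run every iteration: it updates each Float8Linear module's amax history and recomputes its scales.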

🧭 Code Organization

  • float8_experimental/float8_linear.py
    • Float8Linear (main user facing entry point for delayed scaling)
  • float8_experimental/float8_dynamic_linear.py
    • Float8DynamicLinear (main user facing entry point for dynamic scaling)
  • float8_experimental/float8_tensor.py
    • Float8Tensor, which allows Float8Linear to abide by the x.dtype == x.grad.dtype restriction
    • ScaledMMConfig defines the semantics for matmul in the forward and backward passes
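
To make the Float8Tensor idea concrete, here is a rough sketch of what such a tensor conceptually carries; the field and method names below are illustrative simplifications, not the actual implementation in float8_tensor.py:

import torch
from dataclasses import dataclass

@dataclass
class Float8TensorSketch:
    _data: torch.Tensor       # raw float8 payload, e.g. torch.float8_e4m3fn
    _scale: torch.Tensor      # scale applied when the payload was produced
    _orig_dtype: torch.dtype  # dtype the tensor presents to autograd

    def to_original_precision(self) -> torch.Tensor:
        # undo the scaling applied at quantization time
        return self._data.to(self._orig_dtype) / self._scale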

Testing

# run single-GPU unit tests
pytest test/test_base.py

# run a single-GPU integration test on SAM
pytest test/test_sam.py

# run single-GPU compile tests
pytest test/test_compile.py

# run a two-GPU integration test on FSDP
./test/test_fsdp.sh

# run integration tests for TP/SP (outdated)
./test/test_tp.sh

# run all of these tests
./test/run_everything.sh

Benchmarking

# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
./benchmarks/bench_matmul.py

# benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes
# make sure to turn on torch.compile to get the best performance
./benchmarks/bench_linear_float8.py -o ../tmp/test.txt --compile

License

PyTorch has a BSD 3-Clause License, as found in the LICENSE file.

