
FP6 dtype! (pytorch/ao · open · 7 comments)

Opened by NicolasMejiaPetit on June 2, 2024

Comments (7)

vkuzo commented on June 2, 2024

This is great, and the inference e2e integration looks like a good candidate for addition to https://github.com/pytorch/ao . Let us know if you are interested in contributing!

As far as an fp6 dtype in PT core goes, check out https://dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833 for the current thinking on adding new dtypes. We do expect fp6 to get silicon support in the future, so it would be a good candidate to add when that silicon support is closer. We don't actually need an fp6 dtype in core to enable w6a16 as implemented in the code linked to this issue.


msaroufim commented on June 2, 2024

Keeping this open because we still need to do the subclass work and the end-to-end integration.


gau-nernst commented on June 2, 2024

Tracker:

  • FP16 act - FP6 weight linear CUDA kernel (#223)
  • Improve FP32/FP16/BF16 <-> FP6 conversion (with CUDA support) (#248)
  • Improve weight splitting (with CUDA support) (#279)
  • User-friendly API (either Tensor subclass or FP6Linear module) (#279 #283)
  • End2end benchmark
  • Remove unnecessary code (e.g. weight_quant.cu, weight_prepacking.cpp)


gau-nernst commented on June 2, 2024

Just to update people here on the progress: we have added a user API for FP6-LLM:

from torchao.quantization.fp6_llm import convert_fp6_llm

convert_fp6_llm(model)  # convert model in-place, replacing nn.Linear modules with Fp6LlmLinear

Everything should work (in eager mode). Some local end-to-end testing by me and @Iron-Bound shows that it works as expected. We will probably close this issue once we have an LLM eval in this repo for uniform evaluation across quantization methods (there is also a small difference in how we handle FP16->FP6 quantization compared to the released code, so I want to make sure this difference is not significant).
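For concreteness, here is a minimal usage sketch built around the snippet above. It needs a CUDA GPU and an FP16 model; the toy Sequential, the 4096-wide layers, and bias=False are my assumptions (stand-ins for a real LLM and for whatever shape constraints the kernel has), not part of the API beyond convert_fp6_llm itself:

import torch
import torch.nn as nn
from torchao.quantization.fp6_llm import convert_fp6_llm

# Toy stand-in for an FP16 LLM on CUDA. convert_fp6_llm replaces the
# nn.Linear modules in-place (per the comment above).
model = nn.Sequential(
    nn.Linear(4096, 4096, bias=False),
    nn.ReLU(),
    nn.Linear(4096, 4096, bias=False),
).half().cuda()
convert_fp6_llm(model)

with torch.no_grad():
    out = model(torch.randn(8, 4096, dtype=torch.float16, device="cuda"))
print(out.shape)  # torch.Size([8, 4096])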

Some known limitations:

  • The kernel is for FP16 activation - FP6_E3M2 weight. If your model is BF16, it should still work, but you will incur some small overhead converting BF16 <-> FP16. (Perhaps we can implement a BF16 version in a future PR? Not sure how much work is required - only need to change the weight dequant logic and call the correct tensor core instruction?)
  • When tested with gpt-fast, torch.compile does not work for an FP6-LLM end-to-end model (it does work for small test cases though - we have CI for that). Need to debug this.
    • UPDATE: setting torch._inductor.config.triton.cudagraph_trees = False fixes the issue; see the sketch below.
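For reference, a minimal sketch of the workaround applied before torch.compile (the toy module here is just a placeholder for the converted model):

import torch
import torch._inductor.config
import torch.nn as nn

# Workaround from the update above: disable inductor's cudagraph trees
# before compiling the FP6-LLM converted model.
torch._inductor.config.triton.cudagraph_trees = False

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
compiled = torch.compile(model, mode="reduce-overhead")
print(compiled(torch.randn(2, 64)).shape)  # torch.Size([2, 64])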

Data from gpt-fast for meta-llama/Llama-2-7b-chat-hf on an RTX 4080:

name                         tokens/s
BF16 baseline (w/ compile)   49.15
FP6-LLM (no compile)         82.55
int8 (w/ compile)            91.12

hellaswag eval (from https://github.com/EleutherAI/lm-evaluation-harness) for meta-llama/Llama-2-7b-chat-hf (credits to @Iron-Bound):

name      acc_norm
baseline  75.50
FP6-LLM   75.36


JiHa-Kim commented on June 2, 2024

Seconded!


cpuhrsch commented on June 2, 2024

Just a nit on "User-friendly API (either Tensor subclass or FP6Linear module)": you can implement an FP6Linear module using a Tensor-subclass-based fp6 dtype. Just call self.weight = nn.Parameter(to_fp6(self.weight)) within the __init__ of your nn.Linear replacement. The FP6Linear module is then one way of injecting that code into the model. It seems like a very popular way of doing that, so it's reasonable to provide as a primitive. Mostly I'm pointing out that you don't need to duplicate work by doing both :) You can then also make it easier for people to add coverage, as shown in our toy example:

import torch
import torchao
from torchao.dtypes.nf4tensor import to_nf4

@torchao.dtypes.nf4tensor.implements([torch.ops.aten.gelu.default])
def gelu(func, *args, **kwargs):
    # The torch dispatch convention is to pass all args and kwargs via the
    # args input.
    # args[0] here corresponds to the original *args
    # args[1] here corresponds to the original *kwargs
    # We're getting the first argument of the original args
    inp = args[0][0]
    # This is a very inefficient way to implement it: dequantize to fp32,
    # apply gelu, then re-quantize to nf4.
    return to_nf4(torch.nn.functional.gelu(inp.to(torch.float32)), inp.block_size, inp.scaler_block_size)

# Toy inputs for the example (the a / a_nf4 referenced below)
a = torch.randn(512, 512, dtype=torch.bfloat16)
a_nf4 = to_nf4(a)
print(f"gelu(a): {torch.nn.functional.gelu(a)}")
print(f"gelu(a_nf4): {torch.nn.functional.gelu(a_nf4)}")


gau-nernst commented on June 2, 2024

Thank you for your feedback. These are just a few suggested approaches discussed with @msaroufim; we haven't decided on the final API for FP6 yet.

Of course, if we have an FP6 subclass, we don't need FP6Linear anymore. But implementing a subclass is harder, and almost all ops except F.linear do not make sense for FP6, because in FP6-LLM the weight is split and re-arranged in a certain way to optimize global memory access for tensor cores.

I tried implementing an FP6 subclass in #223 (and removed it in the end). Even implementing dispatch for aten.linear feels finicky, because PyTorch dispatches aten.mm (or aten.addmm) instead, so I have to store a transposed flag and set and check it correctly before calling the FP6 linear kernel (the CUDA kernel only works with A @ W.T, i.e. a Linear layer). To support other ops, we would need to (1) re-arrange the weight back into natural order and (2) dequantize to FP32/FP16/BF16 (and then reverse both steps back to FP6). That would be too expensive.
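As an illustration of the dispatch point above (not torchao code): a minimal wrapper subclass that logs which aten ops F.linear actually hits under __torch_dispatch__. On my understanding, aten.linear decomposes before reaching Python dispatch, so you typically see aten.t followed by aten.mm (or aten.addmm when a bias is present), which is why the transposed flag matters.

import torch
import torch.nn.functional as F
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    # Wrapper subclass that just logs every aten op it sees.
    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"dispatched: {func}")  # e.g. aten.t.default, aten.mm.default
        unwrap = lambda x: x._data if isinstance(x, LoggingTensor) else x
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        wrap = lambda x: LoggingTensor(x) if isinstance(x, torch.Tensor) else x
        return tree_map(wrap, out)

w = LoggingTensor(torch.randn(8, 4))
x = torch.randn(2, 4)
F.linear(x, w)  # typically logs aten.t then aten.mm, rather than aten.linear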

So I think implementing a custom FP6Linear layer would be easier, since we don't need to guarantee anything about the weight, i.e. the weight itself is an internal implementation detail.

Just some of my thoughts from working on this. Once #248 is merged, I will work on adapting the weight-splitting logic (currently it's a CPU-only C++ extension). Note that the original code does not have weight un-splitting logic.


