
ao's Introduction

torchao: PyTorch Architecture Optimization

This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open an issue

Introduction

torchao is a PyTorch library for quantization and sparsity.

Get Started

Installation

torchao makes liberal use of several new features in PyTorch, so we recommend using it with the current nightly or the latest stable version of PyTorch.

Stable Release

pip install torchao

Nightly Release

pip install torchao-nightly

From source

git clone https://github.com/pytorch/ao
cd ao
pip install -r requirements.txt
pip install -r dev-requirements.txt
pip install .

Quantization

import torch
import torchao

# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization and compilation
q_model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
q_model(input)

Sparsity

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier

# bfloat16 CUDA model
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)

# Accuracy: Finding a sparse subnetwork
sparse_config = []
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        sparse_config.append({"tensor_fqn": f"{name}.weight"})

sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                  sparse_block_shape=(1,4),
                                  zeros_per_block=2)

# attach FakeSparsity
sparsifier.prepare(model, sparse_config)
sparsifier.step()
sparsifier.squash_mask()
# now we have a dense model with sparse weights

# Performance: Accelerated sparse inference
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
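The `sparse_block_shape=(1,4)` / `zeros_per_block=2` configuration above asks for the 2:4 pattern: two of every four consecutive weights in a row are zeroed. The pure-Python sketch below shows what such a mask looks like. It keeps the two largest-magnitude entries per block of four. It is not the `WeightNormSparsifier` implementation. `prune_n_m` is a hypothetical helper, and ties are broken by position.

```python
def prune_n_m(row, n=2, m=4):
    """Keep the n largest-magnitude values in each block of m consecutive values; zero the rest."""
    if len(row) % m != 0:
        raise ValueError(f"row length {len(row)} is not a multiple of {m}")
    pruned = []
    for start in range(0, len(row), m):
        block = row[start:start + m]
        # sorted() is stable even with reverse=True, so equal magnitudes keep the earlier index
        keep = set(sorted(range(m), key=lambda j: abs(block[j]), reverse=True)[:n])
        pruned.extend(v if j in keep else 0.0 for j, v in enumerate(block))
    return pruned


weights = [0.1, -0.9, 0.3, 0.05, 1.0, 2.0, 3.0, 4.0]
print(prune_n_m(weights))  # [0.0, -0.9, 0.3, 0.0, 0.0, 0.0, 3.0, 4.0]
```

Each block of four ends up with exactly two nonzeros. This is the layout that `to_sparse_semi_structured` can then compress.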

To learn more and try out our APIs, check out the API examples in

Supported Features

  1. Quantization algorithms
  2. Sparsity algorithms such as Wanda that help improve accuracy of sparse networks
  3. Support for lower precision dtypes such as
  4. Bleeding Edge Kernels for experimental kernels without backwards compatibility guarantees

Our Goals

  • Composability with torch.compile: We rely heavily on torch.compile to write pure PyTorch code and codegen efficient kernels. There are, however, limits to what a compiler can do, so we don't shy away from writing our own custom CUDA/Triton kernels.
  • Composability with FSDP: The new support for FSDP per-parameter sharding means engineers and researchers alike can experiment with different quantization and distributed strategies concurrently.
  • Performance: We measure our performance on every commit using an A10G. We also regularly run performance benchmarks on the torchbench suite.
  • Heterogeneous Hardware: Efficient kernels that can run on CPU/GPU-based servers (w/ torch.compile) and mobile backends (w/ ExecuTorch).
  • Packaging kernels should be easy: We support custom CUDA and Triton extensions so you can focus on writing your kernels, and we'll ensure that they work on most operating systems and devices.

Integrations

torchao has been integrated with other libraries, including:

  • torchtune leverages our 8- and 4-bit weight-only quantization techniques with optional support for GPTQ
  • ExecuTorch leverages our GPTQ implementation for both 8da4w (int8 dynamic activation with int4 weight) and int4 weight-only quantization.
  • HQQ leverages our int4mm kernel for low-latency inference

Success stories

Our kernels have been used to achieve SOTA inference performance on

License

torchao is released under the BSD 3-Clause license.


ao's Issues

[RFC] More general affine quantization primitives

PR is here, please feel free to comment in PR directly: #159

Context

Currently there are many q/dq functions in torchao and PyTorch. They mainly differ in the following dimensions:

  • dtype/bitwidth + quant_min/quant_max: e.g. torch.uint8 with quant_min=0 and quant_max = 255
  • symmetric/asymmetric quantization
  • granularity: per_tensor, per_channel, per_channel_group
  • dtype for scales and zero_points

Ideally, I think we should unify them. This might complicate the operator patterns used by backends like XNNPACK, but the code sharing and simpler representation it brings will be beneficial in the long term.

We defined three functions: choose_qparams_affine_per_block, quantize_affine_per_block and dequantize_affine_per_block. Please check out the docstrings of these functions in the PR for the definitions.
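Here is a minimal pure-Python sketch that makes the three operations concrete. It applies asymmetric affine quantization to each block of a flat list. The function names echo the RFC, but the signatures are hypothetical and far simpler than the PR's: there are no tensors, only one dimension, and only the asymmetric case. The range is widened to include 0.0 so that zero stays exactly representable.

```python
def choose_qparams_affine(values, quant_min, quant_max):
    """Asymmetric (scale, zero_point) for one block; the range always includes 0.0."""
    min_val = min(min(values), 0.0)
    max_val = max(max(values), 0.0)
    scale = (max_val - min_val) / (quant_max - quant_min)
    if scale == 0.0:  # all-zero block: any positive scale works
        scale = 1.0
    zero_point = quant_min - round(min_val / scale)
    return scale, max(quant_min, min(quant_max, zero_point))


def quantize_affine(values, scale, zero_point, quant_min, quant_max):
    # q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
    return [max(quant_min, min(quant_max, round(v / scale) + zero_point)) for v in values]


def dequantize_affine(q_values, scale, zero_point):
    # x_hat = (q - zero_point) * scale
    return [(q - zero_point) * scale for q in q_values]


def quantize_per_block(values, block_size, quant_min=0, quant_max=15):
    """Quantize each block of a flat list with its own qparams (uint4 range by default)."""
    q, qparams = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        scale, zero_point = choose_qparams_affine(block, quant_min, quant_max)
        q.extend(quantize_affine(block, scale, zero_point, quant_min, quant_max))
        qparams.append((scale, zero_point))
    return q, qparams


def dequantize_per_block(q, qparams, block_size):
    out = []
    for i, (scale, zero_point) in enumerate(qparams):
        out.extend(dequantize_affine(q[i * block_size:(i + 1) * block_size], scale, zero_point))
    return out


x = [-1.0, 0.0, 0.5, 2.0, 0.0, 0.3, 0.6, 0.9]
q, qparams = quantize_per_block(x, block_size=4)
print(q, qparams)
print(dequantize_per_block(q, qparams, block_size=4))
```

Python's `round` rounds half to even, as `torch.round` does. The "different dtypes" question above would show up here as the precision used for `v / scale`.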

Some Questions

  • for input and scale/zero_point, what do we do when they have different dtypes, e.g. when input is fp16, scales and zero_points are fp32? do we always convert to fp32 and then do the computation?
  • Concerns about using torch.Tensor for per_tensor quantization instead of float/int numbers?
  • It may run slower, is there any concerns on perf?
  • Other ways to choose qparams apart from symmetric and asymmetric?
  • clamping for quant_min/quant_max: should we include this in the quantize op or leave it out?
  • I'm also thinking of API for end users, I think we could provide a util function to get the block size, e.g. get_block_size(input, {"quant_type": "per_channel_group", "group_size": 32, "axis": -1})
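Below is one possible shape for the `get_block_size` util proposed in the last question, sketched under our own assumptions. It takes a shape tuple rather than a tensor. For `per_channel`, `axis` names the channel axis, so each channel gets its own block. For `per_channel_group`, `axis` names the dimension that is split into groups of `group_size`.

```python
def get_block_size(shape, config):
    """Map a granularity config to a block_size tuple (hypothetical helper)."""
    ndim = len(shape)
    quant_type = config["quant_type"]
    if quant_type == "per_tensor":
        # One block covering the whole tensor
        return tuple(shape)
    if quant_type == "per_channel":
        # One block per slice along the channel axis
        axis = config.get("axis", 0) % ndim
        return tuple(1 if i == axis else size for i, size in enumerate(shape))
    if quant_type == "per_channel_group":
        # Groups of group_size along `axis`, one block per group in every other dim
        axis = config.get("axis", -1) % ndim
        group_size = config["group_size"]
        if shape[axis] % group_size != 0:
            raise ValueError(f"dim {axis} of size {shape[axis]} is not divisible by {group_size}")
        return tuple(group_size if i == axis else 1 for i in range(ndim))
    raise ValueError(f"unknown quant_type: {quant_type!r}")


# A (out_features, in_features) weight, grouped along the input dimension
print(get_block_size((4096, 4096), {"quant_type": "per_channel_group", "group_size": 32, "axis": -1}))
# (1, 32)
```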

FP6 dtype!

🚀 The feature, motivation and pitch

https://arxiv.org/abs/2401.14112

I think you guys are really going to like this.
The DeepSpeed developers introduce an FP6 datatype for cards without FP8 support, while maintaining full tensor core support, using a kernel they created called tc-FPx. Tests were done on an A100, and they achieved 1.69x-2.65x better inference performance! I assume this can be transferred over to training as well (with the possible exception of the KV cache and the embedding module). This is really exciting: it could breathe new life into the rapidly aging A100 now that the H100 has introduced FP8.

It was merged into deepspeed in this commit:
microsoft/DeepSpeed@ccfdb84

Getting this into PyTorch as a dtype would be a major win, given the benefits FP6 provides.

Alternatives

These kernels shouldn’t be limited to the A100; in theory they could work on any card with uint8_t and FP16 support. That said, the kernels were written only for the A100, so without modification they might only work on Ampere cards.

Additional context

The tc-FPx kernel essentially takes 4 FP16 values, quantizes them to FP6 with some placeholders, and then packs them into an array of 3 × uint8_t.
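The 3 × uint8 grouping follows from the storage arithmetic: four 6-bit codes occupy exactly 24 bits. The minimal sketch below shows that packing step only. It assumes the FP16 to FP6 rounding has already produced 6-bit codes (0 to 63), and it uses a simple big-endian bit order. It is not meant to reproduce tc-FPx's actual bit layout.

```python
def pack_fp6(codes):
    """Pack 6-bit codes, four at a time, into three bytes each."""
    if len(codes) % 4 != 0:
        raise ValueError("number of codes must be a multiple of 4")
    out = bytearray()
    for i in range(0, len(codes), 4):
        a, b, c, d = codes[i:i + 4]
        for code in (a, b, c, d):
            if not 0 <= code < 64:
                raise ValueError(f"{code} does not fit in 6 bits")
        word = (a << 18) | (b << 12) | (c << 6) | d  # 4 x 6 = 24 bits
        out += bytes(((word >> 16) & 0xFF, (word >> 8) & 0xFF, word & 0xFF))
    return bytes(out)


def unpack_fp6(data):
    """Inverse of pack_fp6: every 3 bytes yield four 6-bit codes."""
    if len(data) % 3 != 0:
        raise ValueError("packed length must be a multiple of 3")
    codes = []
    for i in range(0, len(data), 3):
        word = (data[i] << 16) | (data[i + 1] << 8) | data[i + 2]
        codes += [(word >> 18) & 0x3F, (word >> 12) & 0x3F, (word >> 6) & 0x3F, word & 0x3F]
    return codes
```

Four FP16 values take 64 bits, while the same four values in FP6 take 24 bits (3 bytes). That is the 37.5% storage footprint behind the memory-bandwidth argument.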

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

[BUG] No module named 'expecttest' when import `torchao`

To reproduce, install torchao from main, then import torchao

pip install git+https://github.com/pytorch/ao
python -c "import torchao"

Error

ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 1
----> 1 import torchao

File ~/code/ao/torchao/__init__.py:8
      6 from . import dtypes
      7 import torch
----> 8 from torch.testing._internal.common_utils import IS_FBCODE
      9 if not IS_FBCODE:
     10     from . import _C

File ~/miniconda3/envs/dev2.3/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py:62
     47 from typing import (
     48     Any,
     49     Callable,
   (...)
     58     Union,
     59 )
     60 from unittest.mock import MagicMock
---> 62 import expecttest
     63 import numpy as np
     65 import __main__  # type: ignore[import]

ModuleNotFoundError: No module named 'expecttest'

The problem is self-explanatory. 2 possible solutions:

  • Add expecttest to requirements.txt
  • Add a try-except clause
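The second option could look roughly like the sketch below. It attempts the import and falls back to a default when the import fails. `optional_import_attr` is a hypothetical helper. For this bug it would be called with `"torch.testing._internal.common_utils"`, `"IS_FBCODE"` and a default of `False`. Catching `ImportError` also covers `ModuleNotFoundError`, which is a subclass of it.

```python
import importlib


def optional_import_attr(module_name, attr, default):
    """Return module_name.attr, or `default` if the module (or one of its imports) is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # includes ModuleNotFoundError, e.g. for 'expecttest'
        return default
    return getattr(module, attr, default)


# Hypothetical use inside torchao/__init__.py:
# IS_FBCODE = optional_import_attr("torch.testing._internal.common_utils", "IS_FBCODE", False)
```

Defaulting to `False` keeps the open-source behaviour of loading `_C`. Adding `expecttest` to the requirements is still the simpler fix.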

[Tracker] WIP Features for torchao v0.2

New Features

Better Engineering

Repo Health

Project implicitly depends on torch nightly

Traceback (most recent call last):
  File "C:\code\foo\scripts\quantize.py", line 4, in <module>
    from torchao.quantization.smoothquant import (
  File "C:\code\py-envs\foo\lib\site-packages\torchao\quantization\__init__.py", line 7, in <module>
    from .smoothquant import *  # noqa: F403
  File "C:\code\py-envs\foo\lib\site-packages\torchao\quantization\smoothquant.py", line 17, in <module>
    import torchao.quantization.quant_api as quant_api
  File "C:\code\py-envs\foo\lib\site-packages\torchao\quantization\quant_api.py", line 18, in <module>
    from .subclass import (
  File "C:\code\py-envs\foo\lib\site-packages\torchao\quantization\subclass.py", line 13, in <module>
    from torch.utils._python_dispatch import return_and_correct_aliasing
ImportError: cannot import name 'return_and_correct_aliasing' from 'torch.utils._python_dispatch' (C:\code\py-envs\foo\lib\site-packages\torch\utils\_python_dispatch.py)

This project seems to rely on torch nightly, which exports return_and_correct_aliasing. It might be worthwhile to document this. I suppose one could argue it's obvious enough from this being an experimental repo, but it was surprising to me.
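Besides documenting the requirement, the dependency could fail fast with an actionable message. The sketch below shows one way to do that. `require_attr` is a hypothetical helper, and the hint text in the usage comment is only an example.

```python
import importlib


def require_attr(module_name, attr, hint):
    """Import module_name.attr or raise an ImportError that says what to do about it."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(f"{module_name} could not be imported. {hint}") from e
    if not hasattr(module, attr):
        raise ImportError(f"{module_name}.{attr} is missing. {hint}")
    return getattr(module, attr)


# Hypothetical use in torchao/quantization/subclass.py:
# return_and_correct_aliasing = require_attr(
#     "torch.utils._python_dispatch", "return_and_correct_aliasing",
#     "This version of torchao needs a newer PyTorch (a nightly build at the time of this report).")
```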

Nice work team, I'm looking forward to using this package.

[Tracker] WIP features for torchao 0.3

Focus - benchmarking, documentation, tutorials, prototype to beta

Spillover from 0.2.0

Benchmarking

  • Setup model level benchmark for accuracy and performance in torchbench for single quantization API, so that we can start deprecating quant primitives and quant apis after making sure no regressions (@HDCharles)
  • Benchmarks for auto quant on pytorch benchmark's inference quant pane (@HDCharles)
  • Int8 + 2:4 results on segment-anything-fast @jcaip

Documentation

Tutorials

  • Tutorial for affine quantization dtype and unified quant primitives - Found lots of subtle differences, especially w.r.t. preserving zeros and tinygemm (@jerryzh168)

Core

  • QAT workflow (@andrewor14)
  • Deprecating quant primitives (@jerryzh168)
  • Deprecating quant APIs (@jerryzh168)
  • Deduplicate int4 workflows
  • Factory function and implements decorator for affine quantization dtype

Doc build failing on main

This is the error https://github.com/pytorch/ao/actions/runs/8977543410/job/24656486432?pr=216

Don't have time to debug tonight but cc @svekars who might have some ideas

Run cd docs
Running Sphinx v5.0.0
torchao_version_docs: refs/pull/216/merge
Version: main
making output directory... done
Using Sphinx-Gallery to convert rst text blocks to markdown for .ipynb files.
[autosummary] generating autosummary for: api_ref_dtypes.rst, api_ref_intro.rst, api_ref_kernel.rst, api_ref_quantization.rst, api_ref_sparsity.rst, dtypes.rst, getting-started.rst, index.rst, overview.rst, performant_kernels.rst, quantization.rst, sparsity.rst

Extension error (sphinx.ext.autosummary):
Handler <function process_generate_options at 0x7f0bd5e8cfe0> for event 'builder-inited' threw an exception (exception: no module named torchao.sparsity)
make: *** [Makefile:41: html] Error 2
Error: Process completed with exit code 2.

HQQ Tracker

  • A16W4 axis=1

    • Low-hanging fruit we can add to int4wo quant, either as a flag or as a replacement for the quant method
      • test eval with HQQ axis=1 and compare to existing version
    • if axis = 1 doesn't get enough accuracy improvement, we could also combine with equalization
      • test perf/eval with HQQ axis=1 + equalization
  • A16W4+ axis=1

    • Can quantize certain columns of W to 4/8 bit
      • may be faster to do a 4 bit matmul on all of W and a sparse 8 bit matmul?
      • test perf for int4wo + int8 matmul for n columns
    • HQQ+ end result is an int4wo matmul + lora matmul
      • back-of-the-envelope numbers suggest a ~1/3 slowdown over int4, which is still better than int8
      • test perf for int4wo + lora
  • A8W4 axis=1

    • test eval accuracy with HQQ axis=1 and compare to existing version
  • A16W3 and A16W5

    • existing numbers were measured with axis = 0; how do they look with axis = 1?
      • also relevant whether these numbers scale to llama3 since some quantization difficulty has been reported there
    • get numbers for 3/5 bit quantization with axis = 1, ideally for llama 3

2:4 sparsity + PTQ(int8) model's inference

Are there any runnable demos of using Sparse-QAT/PTQ (2:4) to accelerate inference, such as applying PTQ to a 2:4 sparse LLaMA for inference acceleration? I am curious about the potential speedup ratio this could achieve.
The overall pipeline might be: compress the weight matrix with 2:4 sparsity and quantize it to INT8 through PTQ/QAT. The activation matrix would also be quantized to INT8 through PTQ/QAT. After such processing, the main computation would be INT8 × INT8.
I would like to know if there is a tutorial document available, as I am a beginner in the field of quantization.
Thx!
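The arithmetic of the proposed pipeline can be sketched in a few lines of pure Python. The sketch assumes the weights are already 2:4 pruned, uses a symmetric int8 scale per weight row, and uses a single dynamic per-tensor scale for the activations. All helpers are hypothetical. Real inference would call a sparse int8 GPU kernel instead of Python loops, so this code says nothing about achievable speedups.

```python
def is_2_4_sparse(row):
    """True if every block of 4 consecutive values has at least 2 zeros."""
    return all(sum(v == 0 for v in row[i:i + 4]) >= 2 for i in range(0, len(row), 4))


def quantize_int8_symmetric(values):
    """Symmetric int8: scale = max|x| / 127, q = clamp(round(x / scale), -128, 127)."""
    amax = max(abs(v) for v in values)
    scale = amax / 127 if amax > 0 else 1.0
    return [max(-128, min(127, round(v / scale))) for v in values], scale


def sparse_int8_linear(x, weight_rows):
    """y = W @ x with int8 x int8 products accumulated as integers, then rescaled to float."""
    if not all(is_2_4_sparse(row) for row in weight_rows):
        raise ValueError("weights must follow the 2:4 pattern")
    x_q, x_scale = quantize_int8_symmetric(x)  # dynamic, per-tensor activation scale
    out = []
    for row in weight_rows:
        w_q, w_scale = quantize_int8_symmetric(row)  # per-row weight scale; zeros stay 0
        acc = sum(a * b for a, b in zip(x_q, w_q))  # integer accumulation (int32 on hardware)
        out.append(acc * x_scale * w_scale)
    return out


W = [[0.0, -1.0, 0.0, 2.0], [0.5, 0.0, 0.0, 0.25]]
x = [1.0, 2.0, 3.0, 4.0]
print(sparse_int8_linear(x, W))  # close to the float result [6.0, 1.5]
```

Quantizing after pruning leaves the zeros at exactly 0. The 2:4 pattern therefore survives into the int8 weights, which is what lets the two techniques compose.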

[Tracker] General feature requests for torchao

This issue tracks outstanding feature requests for torchao. If you'd like a specific feature to be added to torchao, please comment directly here.

Quantization Techniques (based on planned, new requests)

  • GPTQ
  • HQQ

DTypes

  • fp8
  • mx format

Sparsity APIs

  • int8 + 2:4 sparsity
  • fp8 + 2:4 sparsity

Kernels

  • kernel autotuner for dynamic quant
  • C++ extension for custom kernel - starting with PagedAttention (pytorch/pytorch#121465)
  • CUTLASS w4a8 kernel #64

cc @cpuhrsch

apply_dynamic_quant for vit_b_16

import torch
import torchvision.models.vision_transformer as models
from torch._inductor import config as inductorconfig
from torchao.quantization import apply_dynamic_quant

# Load a pretrained Vision Transformer (the same weights as pretrained=True)
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
model.eval().cuda().to(torch.bfloat16)

# Replace the nn.Linear modules with dynamically quantized linear modules
apply_dynamic_quant(model)

# inductor setting which improves torch.compile performance for quantized modules
inductorconfig.force_fuse_int_mm_with_mul = True

model = torch.compile(model, mode='max-autotune')

input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')
model(input_tensor)

causes a crash:

[...]
    self.out_proj.weight,
  File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1704, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___encoder_layers_encoder_layer_0_self_attention(*(FakeTensor(..., device='cuda:0', size=(1, 197, 768), dtype=torch.bfloat16,
           grad_fn=<NativeLayerNormBackward0>), FakeTensor(..., device='cuda:0', size=(1, 197, 768), dtype=torch.bfloat16,
           grad_fn=<NativeLayerNormBackward0>), FakeTensor(..., device='cuda:0', size=(1, 197, 768), dtype=torch.bfloat16,
           grad_fn=<NativeLayerNormBackward0>)), **{'need_weights': False}):
'DynamicallyPerAxisQuantizedLinear' object has no attribute 'weight'

from user code:
   File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 298, in forward
    x = self.encoder(x)
  File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 157, in forward
    return self.ln(self.layers(self.dropout(input)))
  File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 113, in forward
    x, _ = self.self_attention(x, x, x, need_weights=False)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

FloatQuantization subclass

As I was reviewing #223

I was reminded of this PR #214

And I'd be curious what range of floating-point formats we can express using just subclasses.

NF4Tensor uses 8 bits of memory

Hi, I've been playing around with QLoRA using the NF4Tensor class from this great library, but I noticed that the NF4 data type uses 8 bits of memory per parameter, where it should be using ~4.1 bits according to the paper. I verified this by initializing a single tensor:

>>> t4 = torchao.dtypes.nf4tensor.NF4Tensor.from_tensor(torch.rand([1024, 4096], dtype=torch.bfloat16), 64, 256)
>>> torch.cuda.memory_allocated(), torch.cuda.memory_reserved()
(0, 0)
>>> t4 = t4.cuda()
>>> torch.cuda.memory_allocated(), torch.cuda.memory_reserved()
(4194304, 20971520)

which is 4194304 / (1024 * 4096) * 8 = 8 bits per parameter

I was wondering if this is a bug, or if there is some intrinsic limitation here. Thanks!
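The ~4.1 figure comes from the QLoRA paper's accounting with double quantization. Each weight costs 4 bits. Each block of 64 weights adds one 8-bit quantized scale. Each group of 256 first-level scales adds one 32-bit second-level scale. The sketch below compares that with the measurement above. It assumes the `64, 256` arguments to `from_tensor` are those two block sizes. It does not explain why the allocation is larger.

```python
def nf4_expected_bits_per_param(block_size=64, scaler_block_size=256,
                                code_bits=4, scale_bits=8, second_scale_bits=32):
    """QLoRA-style accounting: 4-bit codes + 8-bit block scales + fp32 scales-of-scales."""
    return (code_bits
            + scale_bits / block_size
            + second_scale_bits / (block_size * scaler_block_size))


def measured_bits_per_param(bytes_allocated, num_params):
    return bytes_allocated * 8 / num_params


print(nf4_expected_bits_per_param())                  # 4.126953125, i.e. ~4.127
print(measured_bits_per_param(4194304, 1024 * 4096))  # 8.0
```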

Handle NonDynamicallyQuantizableLinear in smoothquant module

Traceback (most recent call last):
  File "C:/Program Files/JetBrains/PyCharm Community Edition 2023.2.1/plugins/python-ce/helpers/pydev/pydevd.py", line 1527, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2023.2.1\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:\code\foobar\scripts\quantize.py", line 37, in <module>
    swap_linear_with_smooth_fq_linear(model)
  File "C:\code\py-envs\foobar\lib\site-packages\torchao\quantization\smoothquant.py", line 219, in swap_linear_with_smooth_fq_linear
    swap_linear_with_smooth_fq_linear(child, skip_fqn_list, new_fqn, alpha)
  File "C:\code\py-envs\foobar\lib\site-packages\torchao\quantization\smoothquant.py", line 219, in swap_linear_with_smooth_fq_linear
    swap_linear_with_smooth_fq_linear(child, skip_fqn_list, new_fqn, alpha)
  File "C:\code\py-envs\foobar\lib\site-packages\torchao\quantization\smoothquant.py", line 219, in swap_linear_with_smooth_fq_linear
    swap_linear_with_smooth_fq_linear(child, skip_fqn_list, new_fqn, alpha)
  [Previous line repeated 1 more time]
  File "C:\code\py-envs\foobar\lib\site-packages\torchao\quantization\smoothquant.py", line 215, in swap_linear_with_smooth_fq_linear
    target_cls = source_cls_to_target_cls[type(child)]
KeyError: <class 'torch.nn.modules.linear.NonDynamicallyQuantizableLinear'>
python-BaseException

Expected: NonDynamicallyQuantizableLinear layer is skipped (possibly with a warning), or properly handled.
Actual: exception.

It sounds like HDCharles was planning on fixing this more generally: pytorch/pytorch#58969
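Both expected behaviours can be sketched without torch. `torch.nn.modules.linear.NonDynamicallyQuantizableLinear` subclasses `torch.nn.Linear`. One option is therefore to walk the class's MRO instead of indexing `source_cls_to_target_cls` by exact type. The other option is to warn and skip when no entry matches. The helper names are hypothetical. Note that `nn.MultiheadAttention` reads `out_proj.weight` directly, as the ViT traceback above shows. Skipping is therefore the safer default unless the replacement module still exposes `weight`.

```python
import warnings


def lookup_target_cls(source_cls_to_target_cls, source_cls):
    """Find a mapping entry for source_cls or its nearest base class; None if there is none."""
    for cls in source_cls.__mro__:  # source_cls itself first, then its bases in order
        if cls in source_cls_to_target_cls:
            return source_cls_to_target_cls[cls]
    return None


def target_cls_or_skip(source_cls_to_target_cls, child, allow_subclasses=False):
    """Exact-type lookup by default; optionally fall back to base classes. Warn on a miss."""
    if allow_subclasses:
        target = lookup_target_cls(source_cls_to_target_cls, type(child))
    else:
        target = source_cls_to_target_cls.get(type(child))
    if target is None:
        warnings.warn(f"skipping {type(child).__name__}: no smoothquant replacement registered")
    return target
```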

Semi-Structured Sparsity unsupported for Windows

Running
py pytorch/benchmarks/sparse/benchmark_semi_structured_sparsity.py --mode nvidia-fixed-k --dtype bf16 --backend cutlass
(from #174)

results in
RuntimeError: _sparse_semi_structured_linear: CUTLASS not supported

@jcaip believes it's an issue with Windows, and that the best workaround would be to dual-boot Linux (which I'll try today!)

Full output:

PS C:\Users\phili\dev> 
Started benchmark: nvidia-fixed-k | dtype: bf16
  0%|                                                                                                             | 0/18 [00:00<?, ?it/s]C:\Users\phili\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\sparse\semi_structured.py:111: UserWarning: The PyTorch API of SparseSemiStructuredTensor is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.sparse module for further information about the project.
  warnings.warn(
Traceback (most recent call last):
  File "C:\Users\phili\dev\pytorch\benchmarks\sparse\benchmark_semi_structured_sparsity.py", line 247, in <module>
    df = pd.DataFrame.from_records(results)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\phili\AppData\Local\Programs\Python\Python312\Lib\site-packages\pandas\core\frame.py", line 2450, in from_records
    first_row = next(data)
                ^^^^^^^^^^
  File "C:\Users\phili\dev\pytorch\benchmarks\sparse\benchmark_semi_structured_sparsity.py", line 220, in <genexpr>
    eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
  File "C:\Users\phili\dev\pytorch\benchmarks\sparse\benchmark_semi_structured_sparsity.py", line 123, in test_tensor
    sparse_output = torch.mm(sA, B)
                    ^^^^^^^^^^^^^^^
  File "C:\Users\phili\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\sparse\semi_structured.py", line 199, in __torch_dispatch__
    return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\phili\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\sparse\_semi_structured_ops.py", line 115, in semi_sparse_mm
    res = A._mm(B_padded)
          ^^^^^^^^^^^^^^^
  File "C:\Users\phili\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\sparse\semi_structured.py", line 439, in _mm
    res = torch._sparse_semi_structured_linear(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: _sparse_semi_structured_linear: CUTLASS not supported
  0%|                                                                                                             | 0/18 [00:07<?, ?it/s]

Custom CUDA extensions

We'd like to make it really easy for people to add custom CUDA extensions to ao, and there are a few pieces of work we need to do to get there:

  • Get an example of a custom cuda extension working #135
  • Add a tutorial for a custom cuda extension #186
  • Make missing cuda toolkit error louder and more obvious #186
  • Ensure cuda extensions are built using manylinux - thanks @seemethere

Follow up work in a separate issue

  • Make an example without premium runners - you can build cuda extensions without a cuda machine per @malfet
  • Add a useful kernel people should be using like paged attention
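Here is a sketch of the "louder error" item. It assumes the build only needs to know where the toolkit lives. The check looks at `CUDA_HOME`/`CUDA_PATH` first, then falls back to `nvcc` on `PATH`, and otherwise raises immediately with instructions. `find_cuda_home` is a hypothetical helper, not torchao's setup code. The `env` and `which` parameters exist only so it can be tested without a GPU.

```python
import os
import shutil


def find_cuda_home(env=None, which=shutil.which):
    """Locate the CUDA toolkit or fail early with an actionable message."""
    env = os.environ if env is None else env
    # 1. An explicit environment variable wins
    cuda_home = env.get("CUDA_HOME") or env.get("CUDA_PATH")
    if cuda_home:
        return cuda_home
    # 2. Otherwise infer it from nvcc on PATH (<cuda_home>/bin/nvcc)
    nvcc = which("nvcc")
    if nvcc:
        return os.path.dirname(os.path.dirname(nvcc))
    # 3. Nothing found: say so loudly instead of failing later inside the build
    raise RuntimeError(
        "CUDA toolkit not found: set CUDA_HOME (or CUDA_PATH) or put nvcc on PATH "
        "before building the CUDA extensions."
    )
```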

[Tracker] Outstanding Issues and WIP Features for version 0.1

This issue tracks outstanding issues for a torchao 0.1 release

  • New Functionality

    • Test compatibility with PyTorch 2.2 and 2.3rc1 (@cpuhrsch)
    • Fix tests marked as flaky (@cpuhrsch)
    • int4, int8 weight only quantization support (only need one of the paths to work)
      • path 1: int4, int8 weight quantization subclass API works with TorchTune (@jerryzh168), blocked by tensor subclass save load
      • path 2: int4, int8 weight quantization module swap API works with TorchTune (@jerryzh168), WIP
    • Add GPTQuantizer workflow for 4-bit weight quantization (W4A16) for GPU that works for gpt-fast (and executorch) (@jerryzh168, @HDCharles)
    • Add workflow for 4-bit weight, 8-bit activation quantization (W4A8) with/without GPTQ for executorch (@jerryzh168)
      • without GPTQ path is working, still verifying the GPTQ path
    • NF4 Dtype that works for QLoRA in TorchTune (@cpuhrsch)
    • Fix API so it works with LoRACompatibleLinear
    • Allow apply_quant_api()
      • it currently looks for the children of the module and so doesn't do anything
  • Tutorials/BE

    • Using/Writing a quantization technique using torchao (@jerryzh168)
    • Using kernels written in torchao with PyTorch
    • Replace Int8WeightOnlyQuantizedLinearWeight and Int8DynamicallyQuantizedLinearWeight with a single class
    • Reconsider using class method for Int8DynamicallyQuantizedLinearWeight.from_float
    • Remove / guard catch all forward args, kwargs for module swap API
    • Land Tutorial pytorch/tutorials#2730
  • If time permits (or v0.2)

    • Enable test_8da4w_quantize for 2.4 @jerryzh168
    • 4-bit quantization CPU perf numbers
    • Feature parity between module swap api and subclass api
    • Align smoothquant api with others
      • Add high level auto quant API for int8 dynamic and weight-only quantization with benchmarks (@HDCharles)

[Question] MBU in automated CI?

Hi folks, thanks for the great work.

With #135 merged, vLLM could see benefit from torch.compile backend given compiler-native integration with PagedAttention kernels.

Is there an easy way to see what the latest/nightly MBU is for torch compile on say, H100 / Llama3 70B?

Also interested in cold start compile time

cc @msaroufim
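MBU (model bandwidth utilization) is usually computed as achieved memory bandwidth divided by the hardware's peak bandwidth. Achieved bandwidth is approximately the bytes of weights plus KV cache read per decoding step, times tokens per second. The minimal sketch below uses hypothetical numbers. It does not describe any existing CI job.

```python
def model_bandwidth_utilization(num_params, bytes_per_param, tokens_per_second,
                                peak_bandwidth_bytes_per_s, kv_cache_bytes=0):
    """MBU = (bytes read per generated token * tokens/s) / peak memory bandwidth."""
    bytes_per_token = num_params * bytes_per_param + kv_cache_bytes
    return bytes_per_token * tokens_per_second / peak_bandwidth_bytes_per_s


# Hypothetical: 7B params in bf16 (2 bytes), batch size 1, 100 tokens/s, 2 TB/s peak
print(model_bandwidth_utilization(7e9, 2, 100, 2e12))  # ≈ 0.7
```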

Sparsity OSS colab tracker

We're putting together a loose RFC for our general plans that should be out shortly.

However we know that we want to work with the researchers / OSS to land advanced pruning algorithms into torchao.
These pruning algorithms should extend the Sparsifier class found in torch.ao.pruning.

We're also interested in potentially adding additional fast sparse kernels into torchao. This requires some additional discussion, in particular if we want to land these kernels as is for eager mode support, or to try and generate these kernels with triton.

[RFC] Plans for sparsity

Summary

Sparsity, like quantization, offers increased model performance at the expense of some model quality. However, it is not as widely used / researched as a technique, despite offering similar performance benefits. With the recent explosion in model sizes in GenAI, and with quantization pushing 1-bit limits, there has been renewed interest in sparsity, specifically for GPU backend sparsity patterns.

The parallel nature of GPU backends makes accelerating unstructured sparsity difficult. However, there exist specific sparsity patterns (block-wise, semi-structured) that are more amenable to acceleration on GPUs. Over the last year, we’ve integrated these fast sparse kernels into PyTorch Core, so that all users can show up to a with just a few lines of code:

Our goal for torchao.sparsity is to drive research / adoption of these GPU sparsity patterns.

We feel that the main problem current researchers / users face is fragmentation. Researchers rightfully aim to show end-to-end results, but this means a lot of time is spent figuring out how to integrate with PyTorch and implementation questions like: When should I mask? When/how should I store the compressed representation? Do I want in-place or out-of-place mask updates? How can I call sparse matmul instead of dense?

We hope to change that by providing tutorials and APIs for both sparse kernels (tensor subclassing) and pruning algorithms (torch.ao.pruning.Sparsifier) that users can extend. We feel like the above problems can be solved once, by torchao, letting researchers focus on pushing sparse kernel performance or more accurate pruning algorithms.

We're also hoping to create a new extension point by releasing the workflows we have designed with xFormers that enable accelerated sparse training, not just sparse inference.
As such, we plan on launching torchao.sparsity with the following features in v0.2:

However, we’d like feedback from the community to set the longer-term vision for sparsity. Also, feel free to chime in with any other thoughts you want to share!


Pruning Algorithms

We plan to host a set of OSS pruning algorithms in torchao. These pruning algorithms should extend the torch.ao.pruning.BaseSparsifier class, like WandaSparsifier. We welcome community contributions for pruning algorithms, provided they extend the BaseSparsifier.

Open Questions:

  • Are there changes that need to be made to Sparsifier?
    • Pruning uses parameterizations ( FakeSparsity ), should we switch to MaskedTensor?
    • Global mask updates are difficult to support, is this something researchers care about?
  • We are designing this with first-time users of sparsity in mind, not researchers. Does that resonate with the community?
  • What pruning algorithms are interesting for the community? RigL seems like it would be good to have, any others?

Recipes / Benchmarks

We have often found pruning to be very model specific, with little generalization across domains. As such we hope to land sparse training recipes for specific models / datasets, showing how different pruning algorithms can be used. We are specifically interested in recipes that compose with quantization.

Additionally, we hope that these benchmark numbers can help first-time users of sparsity better understand the tradeoffs involved and encourage researchers to contribute SOTA pruning algorithms.

Open Questions:

  • We plan to focus on vision models for now, should we focus on LLMs?
  • What benchmarks / datasets are interesting to the community? It looks like ViT on ImageNet is the most common architecture.
  • Does the community feel there is value in having a suite of sparse microbenchmarks for the different sparsity patterns or just E2E results?
  • Does the community feel that there is value in having a suite of different pruning (accuracy) benchmarks, for something like unstructured pruning, for comparisons sake?
  • If yes to the above two questions, are there specific combinations that are interesting?

Accelerated Sparse Training

While much work has been done on sparsity for inference, sparsity for training has remained much more challenging. Thanks to the work done by xFormers, we’ve upstreamed fast runtime semi-structured sparsification kernels into PyTorch Core, which allow for prune -> compress -> sparse_mm to happen faster than dense matmul. We also aim to release an example of accelerated sparse training for the OSS community to extend.

  • Does the community want us to focus more on fast sparse training, or on pruning workflows that start from an existing trained model?
  • Should we extend this with sparse M:N kernels? This would allow more flexible sparsity patterns (and potentially better accuracy) at the expense of some performance.
  • The masking mechanism is different from the torch.ao.pruning masking mechanism (FakeSparsity). Should we unify the two?

Performant Sparse Kernels

There are additional sparsity patterns that may be supported on GPUs, which would require additional fast sparse kernels. We hope that torchao can be a staging ground for these kernels. We plan to upstream these kernels to Core as we see fit, depending on adoption.

Some initial options are:

  • Block sparse + 2:4 sparse kernel - Combining 2:4 sparsity with block sparsity for maximum speedups.
  • SHFL-BW kernel - These kernels add a row-wise permutation before block sparsity, to allow for a more flexible sparsity pattern. It would be interesting to see if this shuffle could be used for M:N sparsity as well.
  • M:N / Sparse fan-in kernel - These are similar to 2:4 sparse kernels, but generalized to N:M. While they do not offer the same hardware acceleration as 2:4 sparsity, you can still get memory speedups by sending a compressed representation.
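To put rough numbers on the memory argument, the sketch below estimates the storage of an N:M-sparse weight kept as values plus per-value position metadata, relative to dense storage. It is a back-of-the-envelope model that assumes ceil(log2(M))-bit position indices and ignores padding, alignment and any kernel-specific metadata layout; the function names are hypothetical.

```python
import math


def dense_bytes(rows: int, cols: int, value_bits: int = 16) -> float:
    """Storage for a dense weight matrix."""
    return rows * cols * value_bits / 8


def nm_sparse_bytes(rows: int, cols: int, n: int, m: int, value_bits: int = 16) -> float:
    """Storage for an N:M-sparse weight: n kept values per group of m, each with a
    ceil(log2(m))-bit position index (no padding or alignment)."""
    if cols % m or not 0 < n <= m:
        raise ValueError("need cols divisible by m and 0 < n <= m")
    kept = rows * cols // m * n
    index_bits = math.ceil(math.log2(m))
    return kept * (value_bits + index_bits) / 8


if __name__ == "__main__":
    for n, m in [(2, 4), (4, 8), (2, 8)]:
        ratio = nm_sparse_bytes(4096, 4096, n, m) / dense_bytes(4096, 4096)
        print(f"{n}:{m} -> {ratio:.4f} of dense bf16 storage")
```

Under these assumptions a bf16 2:4 weight needs 9/16 of the dense bytes, and a sparser pattern such as 2:8 needs under a third, which is where memory-bound speedups can come from even without 2:4-style compute acceleration.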

Open Questions:

  • What about load-as-sparse, compute-as-dense kernels?
  • What about other backends like COO CPU kernels for unstructured sparsity? We believe that we should focus on these M:N / block-sparse GPU patterns in particular.
  • Can we generate these kernels with torch.compile rather than hand writing them?

cc @supriyar @cpuhrsch @msaroufim @pytorch-labs/team-superblock @danthe3rd @mklasby @ngc92 @hgyhungry

Building torchao from source installs unnecessary torch and nvidia packages every time

To reproduce

conda create -n test_ao python=3.10
conda activate test_ao
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install git+https://github.com/pytorch/ao -v

Output (the packages are already cached in this case):

Using pip 24.0 from /home/ubuntu/miniconda3/envs/test_ao/lib/python3.10/site-packages/pip (python 3.10)
Collecting git+https://github.com/pytorch/ao
  Cloning https://github.com/pytorch/ao to /tmp/pip-req-build-bcuh0mqg
  Running command git version
  git version 2.34.1
  Running command git clone --filter=blob:none https://github.com/pytorch/ao /tmp/pip-req-build-bcuh0mqg
  Cloning into '/tmp/pip-req-build-bcuh0mqg'...
  Running command git rev-parse HEAD
  b91b6be24afd1220331790ff0866f5b091165cd5
  Resolved https://github.com/pytorch/ao to commit b91b6be24afd1220331790ff0866f5b091165cd5
  Running command git rev-parse HEAD
  b91b6be24afd1220331790ff0866f5b091165cd5
  Running command pip subprocess to install build dependencies
  Collecting setuptools
    Using cached setuptools-69.5.1-py3-none-any.whl.metadata (6.2 kB)
  Collecting wheel
    Using cached wheel-0.43.0-py3-none-any.whl.metadata (2.2 kB)
  Collecting ninja
    Using cached ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl.metadata (5.3 kB)
  Collecting torch
    Using cached torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
  Collecting filelock (from torch)
    Using cached filelock-3.14.0-py3-none-any.whl.metadata (2.8 kB)
  Collecting typing-extensions>=4.8.0 (from torch)
    Using cached typing_extensions-4.11.0-py3-none-any.whl.metadata (3.0 kB)
  Collecting sympy (from torch)
    Using cached sympy-1.12-py3-none-any.whl.metadata (12 kB)
  Collecting networkx (from torch)
    Using cached networkx-3.3-py3-none-any.whl.metadata (5.1 kB)
  Collecting jinja2 (from torch)
    Using cached jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)
  Collecting fsspec (from torch)
    Using cached fsspec-2024.3.1-py3-none-any.whl.metadata (6.8 kB)
  Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
    Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
  Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
    Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
  Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
    Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
  Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
    Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
  Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
    Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
  Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
    Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
  Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
    Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
  Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)
    Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
  Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)
    Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
  Collecting nvidia-nccl-cu12==2.20.5 (from torch)
    Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)
  Collecting nvidia-nvtx-cu12==12.1.105 (from torch)
    Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)
  Collecting triton==2.3.0 (from torch)
    Using cached triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
  Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)
    Using cached nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
  Collecting MarkupSafe>=2.0 (from jinja2->torch)
    Using cached MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
  Collecting mpmath>=0.19 (from sympy->torch)
    Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
  Using cached setuptools-69.5.1-py3-none-any.whl (894 kB)
  Using cached wheel-0.43.0-py3-none-any.whl (65 kB)
  Using cached ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)
  Using cached torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl (779.1 MB)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
  Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)
  Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)
  Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)
  Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
  Using cached triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (168.1 MB)
  Using cached typing_extensions-4.11.0-py3-none-any.whl (34 kB)
  Using cached filelock-3.14.0-py3-none-any.whl (12 kB)
  Using cached fsspec-2024.3.1-py3-none-any.whl (171 kB)
  Using cached jinja2-3.1.4-py3-none-any.whl (133 kB)
  Using cached networkx-3.3-py3-none-any.whl (1.7 MB)
  Using cached sympy-1.12-py3-none-any.whl (5.7 MB)
  Using cached MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (25 kB)
  Using cached mpmath-1.3.0-py3-none-any.whl (536 kB)
  Using cached nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)
  Installing collected packages: ninja, mpmath, wheel, typing-extensions, sympy, setuptools, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, networkx, MarkupSafe, fsspec, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jinja2, nvidia-cusolver-cu12, torch
  Successfully installed MarkupSafe-2.1.5 filelock-3.14.0 fsspec-2024.3.1 jinja2-3.1.4 mpmath-1.3.0 networkx-3.3 ninja-1.11.1.1 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.4.127 nvidia-nvtx-cu12-12.1.105 setuptools-69.5.1 sympy-1.12 torch-2.3.0 triton-2.3.0 typing-extensions-4.11.0 wheel-0.43.0
  Installing build dependencies ... done

On my machine, it takes 30s just to install the cached packages, and this happens every time I re-install torchao. During development it's quite annoying to pay this extra 30s every time I need to re-compile the CUDA/C++ code (pip install -e . only picks up changes to Python code).
If this is the first time installing torchao from source, there is extra time spent downloading the packages (which are huge).

The culprit seems to be torch being listed in build-system.requires in pyproject.toml. Perhaps this is a limitation of pip not being able to recognize torch installed from conda? During the build process, it's also not clear whether it uses the existing torch (from conda) or the pip torch, which may cause issues if the two versions mismatch (I'm using the latest version, so issues may not arise).
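A commonly used workaround, which sidesteps rather than fixes the build-system.requires behavior, is pip's --no-build-isolation flag: pip then builds in the current environment instead of a fresh isolated one, so the torch already installed via conda is used and nothing from build-system.requires is downloaded (those build dependencies, e.g. setuptools, wheel and ninja, must then already be installed). The helper below only assembles such a command; its name, and the choice to enable the flag only when torch is importable, are assumptions for illustration.

```python
import importlib.util
import sys
from typing import List


def torchao_install_cmd(source: str = "git+https://github.com/pytorch/ao",
                        editable: bool = False) -> List[str]:
    """Build a pip command that reuses an already-installed torch for the build."""
    cmd = [sys.executable, "-m", "pip", "install", "-v"]
    # Only skip build isolation when torch is importable here; otherwise the
    # build would fail because torch is missing from the build environment.
    if importlib.util.find_spec("torch") is not None:
        cmd.append("--no-build-isolation")
    if editable:
        cmd.append("-e")
    cmd.append(source)
    return cmd


if __name__ == "__main__":
    print(" ".join(torchao_install_cmd()))
```

From a local clone, torchao_install_cmd(".", editable=True) yields the equivalent of pip install -v --no-build-isolation -e . when torch is already present.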

[RFC] Plans for torchao

Summary

Last year, we released pytorch-labs/torchao to provide acceleration of Generative AI models using native PyTorch techniques. Torchao added support for running quantization on GPUs, including int8 dynamic quantization (W8A8) and weight-only quantization (int8 and int4), composable with torch.compile. Combined, the APIs launched in torchao were able to power SOTA generative AI models across multiple modalities: Segment Anything, Stable Diffusion, and Llama.
The results were showcased in these blog posts:
https://pytorch.org/blog/accelerating-generative-ai/,
https://pytorch.org/blog/accelerating-generative-ai-2/,
https://pytorch.org/blog/accelerating-generative-ai-3/

The goal of our investment in torchao is to accelerate Generative AI using native PyTorch features, while ensuring composability with torch.compile.

In 2024, we plan to adopt the following strategy for the development of torchao:

  • We will launch torchao with the most important quantization techniques for LLMs and other GenAI models via a simple UX. Examples: GPTQ, AWQ, int8 dynamic quantization.
  • We will stay on top of SOTA kernels within these spaces through to PTC, and commit CPU/GPU kernels ourselves as necessary. Torchao will host a limited set of performant kernels for server (CPU/GPU) and executorch, with a clear recommendation on how to integrate and run inference on these backends.
  • Torchao will host non-standard dtypes, implemented via tensor subclasses. Examples: NF4, any4, MX4.
  • Following the PyTorch design principles, torchao's offerings will be usable and simple, including setup, dependencies, and API surfaces.
  • We will actively engage with the community and researchers to contribute new quantization techniques in native PyTorch code, and with developers to author performant kernels for these techniques in torchao for different backends. An example would be upstreaming the kernels built by the CUDA_MODE community into torchao.
  • As the code matures, and based on community demand, we will upstream techniques/kernels into PyTorch Core.

Let’s dive deeper into some of the coverage areas mentioned above.

Emerging dtypes

Dtypes like NF4, MX4, and groupwise-quantized int4 are used to implement various optimization techniques in models. Last year, we posted a plan on how we wish to support these dtypes in PyTorch. In torchao, we will host tensor-subclass-based implementations of dtypes. Existing examples include uint4 and NF4, which users can use for their own quantization techniques, or whose implementation they can override to support other useful dtypes.
Moreover, users don’t need to write Triton or CUDA kernels for their custom dtypes. The implementation can be in Python, and torch.compile will take care of generating performant kernels under the hood.
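To make the storage side of a sub-byte dtype concrete, the dependency-free sketch below packs two unsigned 4-bit values into each byte and unpacks them again. A uint4 tensor subclass needs some packing like this underneath, plus op implementations (not shown) that dispatch on it. The class name and the low-nibble-first layout are assumptions for illustration, not torchao's actual uint4 layout.

```python
from typing import Iterable, List


class PackedUInt4:
    """Hypothetical container: unsigned 4-bit integers stored two per byte,
    element 2k in the low nibble and element 2k+1 in the high nibble."""

    def __init__(self, data: bytes, length: int):
        self.data = data
        self.length = length  # number of logical 4-bit elements

    @classmethod
    def pack(cls, values: Iterable[int]) -> "PackedUInt4":
        vals = list(values)
        if any(not 0 <= v <= 15 for v in vals):
            raise ValueError("uint4 values must be in [0, 15]")
        padded = vals + [0] * (len(vals) % 2)  # pad odd lengths with a zero nibble
        data = bytes(lo | (hi << 4) for lo, hi in zip(padded[0::2], padded[1::2]))
        return cls(data, len(vals))

    def unpack(self) -> List[int]:
        out: List[int] = []
        for b in self.data:
            out.append(b & 0x0F)         # low nibble
            out.append((b >> 4) & 0x0F)  # high nibble
        return out[: self.length]

    @property
    def nbytes(self) -> int:
        return len(self.data)


if __name__ == "__main__":
    p = PackedUInt4.pack([1, 15, 7])
    print(p.data.hex(), p.nbytes, p.unpack())
```

Three values fit in two bytes here, versus three bytes for uint8, which is the 2x storage saving a 4-bit dtype is meant to deliver.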

Quantization techniques

Quantization can be applied to weights only, or to both weights and activations. For batch size 1 (memory-bandwidth bound), LLM quantization techniques typically use weight-only quantization. For larger batch sizes, longer context lengths, or throughput-bound models in general, quantizing the activations is also beneficial. Quantization, however, impacts model accuracy, and researchers have published techniques to mitigate this impact; these currently exist externally, as one repository per technique.
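The batch-size-1 argument can be made concrete with back-of-the-envelope arithmetic. When decoding one token at a time, each weight is read from memory once per generated token, so memory bandwidth divided by weight bytes gives an upper bound on tokens per second. The sketch assumes a 7B-parameter model, roughly 2 TB/s of memory bandwidth, and 32 bits of scale/zero-point metadata per group of 128 int4 weights; it ignores activations, the KV cache, and kernel efficiency, and the function names are hypothetical.

```python
from typing import Optional


def weight_bytes(n_params: int, bits: int, group_size: Optional[int] = None,
                 group_overhead_bits: int = 32) -> float:
    """Bytes to store the weights; grouped formats add per-group metadata
    (assumed here to be a 16-bit scale plus a 16-bit zero point)."""
    total_bits = n_params * bits
    if group_size:
        total_bits += (n_params // group_size) * group_overhead_bits
    return total_bits / 8


def decode_tokens_per_s_upper_bound(n_bytes: float, bandwidth_bytes_per_s: float) -> float:
    """At batch size 1, every weight is read once per generated token."""
    return bandwidth_bytes_per_s / n_bytes


if __name__ == "__main__":
    n_params = 7_000_000_000  # assumed 7B-parameter model
    bandwidth = 2e12          # assumed ~2 TB/s of memory bandwidth
    for label, nbytes in [("bf16", weight_bytes(n_params, 16)),
                          ("int4, group 128", weight_bytes(n_params, 4, group_size=128))]:
        bound = decode_tokens_per_s_upper_bound(nbytes, bandwidth)
        print(f"{label}: {nbytes / 1e9:.2f} GB, <= {bound:.0f} tokens/s")
```

Under these assumptions int4 weights cut the bytes read per token by roughly 3.8x, which is why weight-only quantization pays off in the memory-bound regime even though the matmul itself typically still runs in higher precision.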

In torchao, we plan to support the following classes of techniques using PyTorch, made available via a simple UX and following the one-file-per-technique principle.

LLM weight only quantization techniques

Post training quantization
The two most popular techniques externally are GPTQ and AWQ, available via AutoGPTQ and AutoAWQ, which include the technique as well as performant kernels for faster quantization ops.
To that end, we will start by re-implementing GPTQ and AWQ in torchao using PyTorch, via a simple, intuitive UX that supports saving/loading quantized models while realizing the memory savings on disk. Some open questions we need to address here include:
  • How much VRAM will be required for different quantization techniques?
  • How do we convert to/from weights quantized for different backends? (CPU and GPU currently use different weight packing formats.)
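For reference, the baseline that GPTQ and AWQ improve upon is plain round-to-nearest (RTN) group-wise quantization. The sketch below quantizes a weight row to unsigned 4-bit codes with one scale and zero point per group. It is not GPTQ (which additionally compensates rounding error using second-order information) or AWQ (which rescales salient weight channels based on activation statistics), and the function names and default group size of 128 are illustrative assumptions. It also leaves the packing-format question open: the integer codes would still need to be laid out in whichever format a given CPU or GPU kernel expects.

```python
from typing import List, Tuple


def quantize_groupwise_uint4(row: List[float], group_size: int = 128
                             ) -> Tuple[List[int], List[float], List[int]]:
    """Round-to-nearest asymmetric 4-bit quantization with one
    (scale, zero point) pair per group of group_size weights."""
    if len(row) % group_size:
        raise ValueError("row length must be a multiple of group_size")
    q, scales, zeros = [], [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        lo, hi = min(min(group), 0.0), max(max(group), 0.0)  # keep 0 representable
        scale = (hi - lo) / 15 or 1.0  # fall back to 1.0 for all-zero groups
        zero = max(0, min(15, round(-lo / scale)))
        q.extend(max(0, min(15, round(w / scale) + zero)) for w in group)
        scales.append(scale)
        zeros.append(zero)
    return q, scales, zeros


def dequantize_groupwise_uint4(q: List[int], scales: List[float], zeros: List[int],
                               group_size: int = 128) -> List[float]:
    """Map each 4-bit code back to (code - zero) * scale of its group."""
    return [(v - zeros[i // group_size]) * scales[i // group_size]
            for i, v in enumerate(q)]


if __name__ == "__main__":
    row = [-5.0, 0.0, 5.0, 10.0, 0.0, 30.0, 3.0, 7.0]
    q, s, z = quantize_groupwise_uint4(row, group_size=4)
    print(q, s, z)
    print(dequantize_groupwise_uint4(q, s, z, group_size=4))
```

Smaller groups track the local weight range more closely and reduce rounding error, at the cost of more scale/zero-point metadata per weight.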

In the future, as more interesting and cutting edge techniques are introduced, researchers can directly implement them in torchao or our team can re-implement them in PyTorch.

Weight and activation quantization techniques

Post training quantization
We’ve already implemented W8A8 quantization via the int_mm kernel in core. This has shown speedups on models like SAM and SDXL without any impact on model accuracy, and can be turned on via a simple one-line UX implemented via module swap or tensor subclass.
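The arithmetic behind W8A8 with an integer matmul can be sketched in plain Python. Weights get one symmetric int8 scale per output channel (computable ahead of time), activations get one symmetric int8 scale per token (computed at runtime, hence "dynamic"), products are accumulated as integers, and the result is rescaled by both scales. This is an illustration under those assumptions, not torchao's implementation or the int_mm kernel; the function names are hypothetical. Note that the per-token quantization and the final rescale are extra work around the matmul, which is the overhead that can dominate for small layers.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_rows_int8(rows: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-row int8: q = round(x / s) with s = max|x| / 127."""
    qs, scales = [], []
    for row in rows:
        s = max(abs(v) for v in row) / 127 or 1.0  # guard against all-zero rows
        qs.append([max(-127, min(127, round(v / s))) for v in row])
        scales.append(s)
    return qs, scales


def w8a8_linear(x: Matrix, w: Matrix) -> Matrix:
    """y = x @ w.T for x: [tokens, in_features], w: [out_features, in_features]."""
    qx, sx = quantize_rows_int8(x)  # per-token scales, computed at runtime (dynamic)
    qw, sw = quantize_rows_int8(w)  # per-output-channel scales, computable once offline
    y = []
    for i, xq in enumerate(qx):
        row = []
        for j, wq in enumerate(qw):
            acc = sum(a * b for a, b in zip(xq, wq))  # integer dot product (int32 on GPU)
            row.append(acc * sx[i] * sw[j])            # rescale back to floating point
        y.append(row)
    return y


if __name__ == "__main__":
    x = [[1.0, -2.0, 0.5]]
    w = [[0.5, 0.25, -1.0], [1.0, 1.0, 1.0]]
    print(w8a8_linear(x, w))  # close to the float result [[-0.5, -0.5]]
```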

However, the challenge here is that some smaller layer shapes might not benefit from quantization due to the overhead of quantizing and dequantizing the activation tensors. Users can either statically skip quantizing these layers or use a higher-level API that figures out which layers are sensitive to quantization. We plan to provide such a higher-level API via the auto quantizer, which applies quantization to the layers that stand to benefit the most, delivering the benefits of quantization without users having to worry too much about which configs to use.
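One way to picture the layer-selection idea is a small benchmarking loop: for each layer, measure the unquantized path and every candidate quantized path on representative inputs, and keep the fastest, so small layers dominated by quantize/dequantize overhead stay in high precision. In the sketch below, select_per_layer, the option labels, and the bench callable are hypothetical; torchao's autoquant has its own interface and selection logic.

```python
from typing import Callable, Dict, Iterable


def select_per_layer(layer_names: Iterable[str], candidates: Iterable[str],
                     bench: Callable[[str, str], float],
                     baseline: str = "none") -> Dict[str, str]:
    """For each layer, keep the option with the lowest measured latency.
    The unquantized baseline is listed first, so it wins ties."""
    options = [baseline, *candidates]
    choice = {}
    for name in layer_names:
        times = {opt: bench(name, opt) for opt in options}
        choice[name] = min(times, key=times.get)
    return choice


if __name__ == "__main__":
    # Made-up latencies (ms) standing in for real measurements.
    fake_ms = {
        ("small_proj", "none"): 0.010, ("small_proj", "int8_dynamic"): 0.014,
        ("big_mlp", "none"): 0.900, ("big_mlp", "int8_dynamic"): 0.550,
    }
    print(select_per_layer(["small_proj", "big_mlp"], ["int8_dynamic"],
                           lambda layer, opt: fake_ms[(layer, opt)]))
```

Preferring the baseline on ties is a deliberate choice: quantizing a layer only makes sense when it measurably helps.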

Quantization aware training
Techniques here require fine-tuning, to tune the model so as to reduce the accuracy impact of quantization. Recent research like LLM-QAT is promising, showing that we can go down to W4A8 and a 4-bit KV cache for LLMs. Moreover, newer lower-bit techniques like AQLM and Quip# also include a fine-tuning component to improve model accuracy.

We will include the APIs and workflow to enable users to do QAT on LLMs, starting with implementing the LLM-QAT paper in torchao and further extending it to support other dtypes like MX4.
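QAT commonly inserts "fake quantization" into the forward pass: values are quantized and immediately dequantized, so training sees the rounding error, while the backward pass treats rounding as the identity (the straight-through estimator). The pure-Python sketch below shows both halves for a symmetric signed 4-bit grid with a fixed scale. The function names, the fixed scale, and zeroing gradients for clamped values are illustrative choices, not the exact setup of the LLM-QAT paper or of torchao.

```python
from typing import List, Tuple

QMIN, QMAX = -8, 7  # signed 4-bit integer range


def fake_quant_forward(x: List[float], scale: float) -> Tuple[List[float], List[bool]]:
    """Quantize then immediately dequantize; also record which inputs were in range."""
    out, in_range = [], []
    for v in x:
        q = round(v / scale)
        in_range.append(QMIN <= q <= QMAX)
        out.append(max(QMIN, min(QMAX, q)) * scale)
    return out, in_range


def fake_quant_backward(grad_out: List[float], in_range: List[bool]) -> List[float]:
    """Straight-through estimator: treat rounding as identity, but zero the
    gradient for inputs that were clamped to the grid boundary."""
    return [g if ok else 0.0 for g, ok in zip(grad_out, in_range)]


if __name__ == "__main__":
    y, mask = fake_quant_forward([0.3, -2.5, 5.0], scale=0.25)
    print(y, fake_quant_backward([1.0, 1.0, 1.0], mask))
```

Because the forward output already carries the quantization error, the weights learned this way tend to degrade less when they are later quantized for real.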

Optimized kernels

Kernels
Optimized kernels are key to making models run faster during inference. Today, core already has performant kernels like int_mm and 4-bit weight quantization kernels for CPU (via Intel) and GPU (via tinygemm). torchao will host performant kernels that work with different backends, with a guide on how to plug these kernels into PyTorch models via the custom ops API. These kernels will compose with torch.compile, provided the user writes a meta kernel implementation for them. For executorch, the expectation is that if the user provides a kernel that works with executorch, it should also work in eager mode.

We will also directly engage with the community, to upstream their performant kernels into torchao.

Autotuner

To use any CUDA kernel efficiently, we'll need to pick the right kernel hyperparameters. The same is true for eager-mode kernels. A kernel autotuner will help here. We expect that the auto quantizer, along with the kernel autotuner, will make int8 dynamic quantization and int8/int4 weight-only quantization more usable and performant. A WIP example of what this might look like can be found here.
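In its simplest form, a kernel autotuner times each candidate configuration for a given problem shape and caches the winner, so the search cost is paid once per shape. The sketch below uses time.perf_counter; the Autotuner class, the run callable, and the block-size configs are hypothetical stand-ins for a real kernel launch.

```python
import time
from typing import Callable, Dict, Hashable, List, Sequence, Tuple

Shape = Tuple[int, ...]


class Autotuner:
    """Time every candidate config once per problem shape and cache the fastest."""

    def __init__(self, configs: Sequence[Hashable],
                 run: Callable[[Hashable, Shape], None],
                 warmup: int = 1, iters: int = 5):
        self.configs: List[Hashable] = list(configs)
        self.run = run
        self.warmup, self.iters = warmup, iters
        self.cache: Dict[Shape, Hashable] = {}

    def _time(self, config: Hashable, shape: Shape) -> float:
        for _ in range(self.warmup):  # exclude one-off setup costs
            self.run(config, shape)
        start = time.perf_counter()
        for _ in range(self.iters):
            self.run(config, shape)
        return (time.perf_counter() - start) / self.iters

    def best_config(self, shape: Shape) -> Hashable:
        if shape not in self.cache:  # search only on a cache miss
            self.cache[shape] = min(self.configs, key=lambda c: self._time(c, shape))
        return self.cache[shape]


if __name__ == "__main__":
    def fake_kernel(block_size: int, shape: Shape) -> None:
        # Hypothetical stand-in: pretend only block_size 32 is fast.
        if block_size != 32:
            time.sleep(0.002)

    tuner = Autotuner([16, 32, 64], fake_kernel)
    print(tuner.best_config((1024, 1024)))
```

For asynchronous GPU kernels, a real timer would also need to synchronize the device before reading the clock (e.g. torch.cuda.synchronize()) or use device-side events; otherwise it mostly measures launch overhead.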

Release engineering

Shipping optimized, custom kernels requires extensibility mechanisms and release channels. We have custom operator support that integrates broadly, but our release mechanism might need to be optimized. It can be quite difficult to ship custom binaries across a broad range of operating systems and accelerators.

Conversion to-from popular model formats

We can add a conversion util from popular model storage formats like gguf into PyTorch’s state_dict format. This will enable users to take a pre-existing quantized model from llama.cpp and have it run via PyTorch eager mode for desktop cpu/gpu and executorch for on-device cases. We’ll share more details here soon.

Pruning

In addition to quantization, we’ve seen promising results with sparsity as well on GPUs. We will share more updates on what torchao will host for the space of sparsity/pruning in the near future.

We'd love to hear any feedback or questions from the OSS community on this RFC. Thank you!

cc @msaroufim @cpuhrsch @jerryzh168 @HDCharles @andrewor14 @jcaip @jisaacso

[RFC] Plans for LLM QAT

Following the recent success of the LLM-QAT paper, our high-level goal is to provide a PyTorch-native workflow for LLM quantization-aware training (QAT) that leverages the quantization primitives and kernels provided by torchao, which is planned to become the de facto OSS library for AO techniques and kernels in PyTorch across different platforms (#47). We also hope to eventually integrate with TorchTune, a recently open-sourced library for fine-tuning and experimenting with LLMs, to provide an end-to-end flow that supports both fine-tuning and QAT.

Workstream 1: Edge Devices

Executorch provides a mechanism for quantizing Llama2 using post-training quantization (PTQ) techniques such as GPTQ, and lowering it to backends like XNNPACK. The main goal of this workstream is to provide a QAT drop-in replacement for GPTQ with superior accuracy, starting with Llama2 7B and the following quantization/training configurations:

  • Linear weights: 4-bit per channel grouped symmetric static quantization
  • Linear activations: 8-bit per token symmetric dynamic quantization
  • Not “data-free”: use the original dataset, unlike in the LLM-QAT paper

We plan to adopt the same eager mode quantization implementation used by the PTQ flow. In the future, if we decide to experiment with static quantization for activations for example, then we can explore using the PT2 Export QAT flow.

Workstream 2: Explore new quantization methods

This workstream is largely backend-agnostic; our goal is to motivate the backends (mobile or server CPU/GPU) to build the relevant kernels once we have demonstrated the initial success of a particular quantization configuration. There is a large design space we can experiment with, summarized below. The suggested quantization and training techniques are primarily motivated by the LLM-QAT paper, but also by ongoing developments across the industry.

We can start with the following dimensions:

  • KV-cache quantization: 4- or 8-bit KV-cache quantization can alleviate throughput bottlenecks with long sequences, and this has been shown (in the LLM-QAT paper) to hurt QAT a lot less than PTQ in terms of accuracy.
  • Custom dtypes: The latest Hopper and upcoming Blackwell GPU generations no longer support int4 tensor cores, and so int4 kernels may not be as performant as other 4-bit dtypes. For example, both nf4 and MX4 promise higher fidelity than any a priori fixed quantization like int4. Experimenting with newer dtypes in QAT may lead to further accuracy gains.
  • Lower bit-widths: 2- or 3-bit weight quantization can help further lower memory footprint and speed up inference. There have been PTQ attempts at such bit-widths (e.g. Quip#, AQLM), but QAT has the potential to further mitigate the accuracy degradation.
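The KV-cache point can be quantified with simple arithmetic: the cache stores one key and one value vector per layer, per KV head, and per token, so it grows linearly with sequence length and batch size. The sketch assumes Llama2-7B-like dimensions (32 layers, 32 KV heads, head dimension 128), ignores the scales a quantized cache would also need, and uses a hypothetical function name.

```python
def kv_cache_bytes(seq_len: int, batch: int = 1, bits: int = 16,
                   n_layers: int = 32, n_kv_heads: int = 32, head_dim: int = 128) -> int:
    """One K and one V vector per layer, per KV head, per token (no scale overhead)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bits // 8


if __name__ == "__main__":
    for bits in (16, 8, 4):
        gib = kv_cache_bytes(seq_len=4096, bits=bits) / 2**30
        print(f"{bits:>2}-bit KV cache, 4096 tokens: {gib:.2f} GiB")
```

Under these assumptions a single 4096-token sequence needs 2 GiB of bf16 KV cache and 0.5 GiB at 4 bits, and the gap scales with batch size.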

Workstream 3: Server GPU Inference

This extends the recent gpt-fast efforts to quantize Llama to QAT. An important goal here is to reuse the same quantization primitives as Workstream 1 to unify the two flows as much as possible. We can start with the following quantization configurations:

  • Int4 weight-only quantization. This was the focus for Llama2 last half, targeting batch-size-1 local chat agent use cases. This particular workload is memory bound, not compute bound, when run on GPUs, so quantizing the activations here may not be particularly beneficial. For QAT, we can perform the same weight-only quantization for better accuracy.
  • Int4 weight quantization + int8 activation dynamic quantization, similar to Workstream 1. One advantage here is we will have numerical baselines from the ExecuTorch workstream to compare against. However, as explained above, it may not make sense for Llama2 batch size 1 use cases for GPUs, so this configuration may be more suitable for larger batch sizes or other more compute bound models. The plan here is to be able to leverage ongoing efforts to provide mixed 4-bit / 8-bit GEMM in cutlass: NVIDIA/cutlass#1413.
  • MX4 weight + activation quantization. Please see the previous section under Custom dtypes for more details.

Docs Revamp

Just listing out all the issues I'm seeing with our docs; feel free to pick something up and fix it.

  • In the main README when we talk about features we should link to usage instructions and code not papers @msaroufim
  • We don't have an NF4 tutorial @drisspg
  • We don't have a wanda tutorial @jcaip
  • For sparsity, we mention tons of algorithms but should suggest a simple one for people to start with @jcaip
  • Our main goals are performance, composability with torch.compile and FSDP, and easy packaging for wide reach @msaroufim
  • Mention HQQ, GaLore and prototype folder somewhere in main docs @msaroufim
  • A doc on how to register a new custom op for both C++ and Triton @msaroufim
  • AOTInductor and no-Python-overhead tutorial @jerryzh168
  • Update the autoquant tutorial to work OOB with gpt-fast; it should be a copy-pastable snippet, or use some other model on HF @HDCharles
  • SmoothQuant tutorial is placeholder code; it needs an actual runnable snippet @HDCharles
  • We don't really articulate the benefits of tensor subclasses anywhere @drisspg ?
  • Mention tinygemm @msaroufim

Run semi-structured sparse benchmarks on consumer hardware

2:4 sparsity is only supported on Ampere and newer GPUs. We've only run benchmarks on A100s, but Phil (@philipbutler) has access to consumer GPUs that could also take advantage of sparse acceleration.

Steps to get numbers:

  1. Install the PyTorch pip nightlies from here
  2. Verify that your consumer GPU supports semi-structured sparsity:
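import torch
from torch.sparse import to_sparse_semi_structured
# The input follows the 2:4 pattern (two zeros in every group of four);
# this is expected to raise an error on GPUs without semi-structured sparsity support.
to_sparse_semi_structured(torch.Tensor([0, 0, 1, 1]).tile((256, 64)).half().cuda())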
  3. Clone PyTorch and get the benchmark script:
  4. Run the benchmarks. For now, let's check whether the nvidia-fixed-mn / nvidia-fixed-k benchmarks still show speedups.
python benchmarks/sparse/benchmark_semi_structured_sparsity.py  --mode nvidia-fixed-k --dtype bfloat16 --backend cutlass
python benchmarks/sparse/benchmark_semi_structured_sparsity.py  --mode nvidia-fixed-mn --dtype bfloat16 --backend cutlass

Afterwards, it would be great to get benchmarks for the ViT-B shapes found here: https://github.com/pytorch/ao/blob/main/benchmarks/sam_vit_b_shapes.csv
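When preparing weights for these benchmarks or for the support check in step 2, it can help to confirm that a matrix actually follows the 2:4 pattern (at most two nonzeros in every group of four consecutive elements along a row). A dependency-free check, with a hypothetical function name, might look like this:

```python
from typing import Sequence


def is_2_4_sparse(rows: Sequence[Sequence[float]]) -> bool:
    """True if every group of 4 consecutive entries in each row has at most 2 nonzeros."""
    for row in rows:
        if len(row) % 4:
            return False  # rows must split evenly into groups of four
        for start in range(0, len(row), 4):
            if sum(v != 0 for v in row[start:start + 4]) > 2:
                return False
    return True


if __name__ == "__main__":
    print(is_2_4_sparse([[0, 0, 1, 1] * 64] * 4))  # tiled 2:4 pattern
    print(is_2_4_sparse([[1.0] * 8]))              # all ones: not 2:4 sparse
```

For example, an all-ones matrix fails this check, while the tiled [0, 0, 1, 1] pattern passes.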
