jerryzh168 / ao

This project forked from pytorch/ao

torchao: PyTorch Architecture Optimization (AO). A repository to host AO techniques and performant kernels that work with PyTorch.

License: BSD 3-Clause "New" or "Revised" License

Python 99.33% Shell 0.06% Cuda 0.58% C++ 0.03%

torchao: PyTorch Architecture Optimization

This repository is currently under heavy development. If you have suggestions on the API or use cases you'd like to see covered, please open an issue.

Introduction

torchao is a PyTorch library for quantization and sparsity.

Get Started

Installation

torchao makes liberal use of several new features in PyTorch, so we recommend using it with the current nightly or the latest stable version of PyTorch.

Stable Release

pip install torchao

Nightly Release

pip install torchao-nightly

From source

git clone https://github.com/pytorch/ao
cd ao
pip install .

Quantization

import torch
import torch._inductor.config
import torchao

# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization; the example input is run through the model
# to benchmark candidate quantization strategies for each linear layer
torchao.autoquant(model, input)

# compile the model to recover performance
model = torch.compile(model, mode='max-autotune')
model(input)
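One technique in the quantization family is int8 weight-only quantization. The pure-Python sketch below illustrates the idea only; it is not torchao's implementation, and all function names are hypothetical. Each weight row is stored as int8 values plus a single float scale, while activations stay in floating point.

```python
def quantize_per_row_int8(weight):
    """Return (q, scales): q holds ints in [-127, 127], one scale per row."""
    q, scales = [], []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127 if max_abs > 0 else 1.0
        # clamp is a safety net; rounding already keeps values in range
        q.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q, scales


def linear_weight_only_int8(x, q, scales):
    """y = x @ W^T, applying each row's scale to the int8 dot product."""
    return [sum(xi * qi for xi, qi in zip(x, q_row)) * s
            for q_row, s in zip(q, scales)]


def linear_float(x, weight):
    """Reference full-precision y = x @ W^T."""
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]


weight = [[0.5, -1.25, 2.0], [0.01, 0.02, -0.03]]
x = [1.0, 2.0, 3.0]
q, scales = quantize_per_row_int8(weight)
y = linear_weight_only_int8(x, q, scales)
y_ref = linear_float(x, weight)
print(q)  # [[32, -79, 127], [42, 85, -127]]
print(y, y_ref)
```

Each output's error is at most half a quantization step times the sum of |x|. Per-row scales matter for rows like the second one: its small weights get their own fine-grained scale instead of sharing the coarse scale of the first row.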

Sparsity

import torch
from torch.sparse import to_sparse_semi_structured
from torch.ao.pruning import WeightNormSparsifier

# bfloat16 CUDA model (semi-structured sparse kernels need an Ampere or newer GPU)
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)

# Accuracy: Finding a sparse subnetwork
sparse_config = []
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        sparse_config.append({"tensor_fqn": f"{name}.weight"})

# 2:4 pattern: zero 2 out of every 4 consecutive weights in each row
sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                  sparse_block_shape=(1, 4),
                                  zeros_per_block=2)

# attach FakeSparsity
sparsifier.prepare(model, sparse_config)
sparsifier.step()
sparsifier.squash_mask()
# now we have a dense model with sparse weights

# Performance: Accelerated sparse inference
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))

# run inference with the compressed weights
with torch.inference_mode():
    output = model(torch.randn(64, 64, dtype=torch.bfloat16, device='cuda'))
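The sparsifier settings above describe the 2:4 semi-structured pattern. With sparsity_level=1.0, every 1x4 block of a weight row receives zeros_per_block=2 zeros, and for single elements the smallest norms are simply the smallest magnitudes. The plain-Python sketch below reproduces that selection rule under this reading and shows why the pattern saves memory: only the 2 surviving values per block and their 2-bit positions need to be stored. It is an illustration, not the layout used by PyTorch's CUTLASS or cuSPARSELt kernels, and the function names are hypothetical.

```python
def prune_and_compress_2_4(row):
    """Keep the 2 largest-magnitude values in each block of 4.

    Returns (dense, values, indices): the pruned row in dense form, the kept
    values, and each kept value's 2-bit position (0..3) inside its block.
    """
    assert len(row) % 4 == 0, "row length must be a multiple of 4"
    dense, values, indices = [], [], []
    for start in range(0, len(row), 4):
        block = row[start:start + 4]
        by_magnitude = sorted(range(4), key=lambda j: abs(block[j]), reverse=True)
        keep = sorted(by_magnitude[:2])  # positions in ascending order
        dense.extend(block[j] if j in keep else 0.0 for j in range(4))
        values.extend(block[j] for j in keep)
        indices.extend(keep)
    return dense, values, indices


def decompress_2_4(values, indices):
    """Rebuild the dense row from kept values and their block positions."""
    dense = []
    for k in range(0, len(values), 2):
        block = [0.0] * 4
        for v, j in zip(values[k:k + 2], indices[k:k + 2]):
            block[j] = v
        dense.extend(block)
    return dense


def storage_bits_per_element(value_bits):
    """Per block of 4: 2 kept values plus a 2-bit index for each."""
    return (2 * value_bits + 2 * 2) / 4


row = [0.1, -0.9, 0.4, 0.05, 1.0, 0.2, -0.3, 0.0]
dense, values, indices = prune_and_compress_2_4(row)
print(dense)                                     # [0.0, -0.9, 0.4, 0.0, 1.0, 0.0, -0.3, 0.0]
print(decompress_2_4(values, indices) == dense)  # True
print(storage_bits_per_element(16) / 16)         # 0.5625
```

Under this simple accounting, a bf16 weight costs 9 bits per original element, or 56.25% of its dense size; real kernel formats may arrange or pad the metadata differently.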

To learn more and try out our APIs, check out the API examples in

Supported Features

  1. Quantization algorithms
  2. Sparsity algorithms such as Wanda that help improve the accuracy of sparse networks
  3. Support for lower precision dtypes such as
  4. Bleeding-edge kernels: experimental kernels without backward-compatibility guarantees
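Wanda (Pruning by Weights and activations) scores each weight by its magnitude times the L2 norm of the matching input feature over a small calibration set, |W_ij| * ||X_j||_2, and removes the lowest-scoring weights within each output row. The sketch below implements that published metric in plain Python; it is not torchao's implementation, and the function and argument names are hypothetical.

```python
import math


def wanda_prune(weight, calib_inputs, sparsity):
    """Zero the lowest-scoring fraction of weights in each output row.

    weight: list of rows (out_features x in_features)
    calib_inputs: list of input vectors, one per calibration token
    """
    n_in = len(weight[0])
    # L2 norm of each input feature across all calibration tokens
    feat_norm = [math.sqrt(sum(x[j] ** 2 for x in calib_inputs)) for j in range(n_in)]
    n_prune = int(n_in * sparsity)
    pruned = []
    for row in weight:
        # Wanda score: |W_ij| * ||X_j||_2, compared only within this row
        scores = [abs(w) * feat_norm[j] for j, w in enumerate(row)]
        drop = set(sorted(range(n_in), key=lambda j: scores[j])[:n_prune])
        pruned.append([0.0 if j in drop else w for j, w in enumerate(row)])
    return pruned


weight = [[1.0, 0.5, -0.2, 0.8]]
calib = [[0.1, 4.0, 3.0, 0.1], [0.1, 3.0, 4.0, 0.1]]
print(wanda_prune(weight, calib, sparsity=0.5))  # [[0.0, 0.5, -0.2, 0.0]]
```

Plain magnitude pruning would drop 0.5 and -0.2 here. Because those weights see large activations, Wanda keeps them and drops 1.0 and 0.8 instead, whose inputs are nearly zero.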

Our Goals

  • Composability with torch.compile: We rely heavily on torch.compile to write pure PyTorch code and codegen efficient kernels. There are, however, limits to what a compiler can do, so we don't shy away from writing our own custom CUDA/Triton kernels.
  • Composability with FSDP: The new support for FSDP per-parameter sharding means engineers and researchers alike can experiment with different quantization and distributed strategies concurrently.
  • Performance: We measure our performance on every commit using an A10G. We also regularly run performance benchmarks on the torchbench suite.
  • Heterogeneous Hardware: Efficient kernels that can run on CPU/GPU-based servers (with torch.compile) and mobile backends (with ExecuTorch).
  • Packaging kernels should be easy: We support custom CUDA and Triton extensions, so you can focus on writing your kernels and we'll ensure that they work on most operating systems and devices.

Integrations

torchao has been integrated with other libraries including

  • torchtune leverages our 8-bit and 4-bit weight-only quantization techniques, with optional support for GPTQ.
  • ExecuTorch leverages our GPTQ implementation for both 8da4w (int8 dynamic activation with int4 weight) and int4 weight-only quantization.
  • HQQ leverages our int4mm kernel for low-latency inference.
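The 8da4w scheme combines int8 activations, quantized dynamically (the scale is computed at runtime from each token), with int4 weights quantized once ahead of time, each group of weights carrying its own scale. The pure-Python sketch below shows the arithmetic for a single token, assuming symmetric quantization and a hypothetical group size of 4; it is not torchao's or ExecuTorch's code, and the function names are hypothetical.

```python
def quantize_symmetric(values, n_bits):
    """Symmetric quantization of a list to signed n_bits ints with one scale."""
    qmax = 2 ** (n_bits - 1) - 1  # 7 for int4, 127 for int8
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale


def quantize_weight_int4_groupwise(weight, group_size):
    """Done once, offline: each row is split into groups with their own scale."""
    return [[quantize_symmetric(row[s:s + group_size], 4)
             for s in range(0, len(row), group_size)]
            for row in weight]


def linear_8da4w(x, qweight, group_size):
    """y = x @ W^T with an int8 dynamic activation and int4 group-wise weights."""
    # "dynamic": the activation scale is computed from x at runtime
    qx, sx = quantize_symmetric(x, 8)
    out = []
    for row_groups in qweight:
        acc = 0.0
        for g, (qw, sw) in enumerate(row_groups):
            qx_g = qx[g * group_size:(g + 1) * group_size]
            # integer dot product per group, rescaled by both scales
            acc += sum(a * b for a, b in zip(qx_g, qw)) * sx * sw
        out.append(acc)
    return out


weight = [[0.3, -0.7, 0.1, 0.9, -0.2, 0.05, 0.6, -0.4]]
x = [1.0, -0.5, 0.25, 2.0, 0.0, 1.5, -1.0, 0.5]
qweight = quantize_weight_int4_groupwise(weight, group_size=4)
y = linear_8da4w(x, qweight, group_size=4)
y_ref = [sum(a * b for a, b in zip(x, row)) for row in weight]
print(y, y_ref)
```

Smaller groups track the local weight range more closely at the cost of storing more scales; int4 leaves only 15 levels per group, so this trade-off matters more than it does for int8.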

Success stories

Our kernels have been used to achieve SOTA inference performance on

License

torchao is released under the BSD 3-Clause license.

Contributors

cpuhrsch, jerryzh168, hdcharles, msaroufim, supriyar, jcaip, svekars, andrewor14, drisspg, rohan-varma, weifengpy, jeromeku, aakashapoorv, manuelcandales, larryliu0820, xia-weiwen, dependabot[bot], usingtcnower, leslie-fang-intel
