Flashy Linear Attention

Forward and backward Triton kernels implementing linear attention in the style of Flash Attention. Also includes a kernel where the query and key tensors have a larger head dimension than the value tensor, intended for use with Taylor Series Linear Attention as introduced here.
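
In the notation of the Use section below, the fused kernel returns only the unnormalized causal numerator; the denominator is computed in plain PyTorch. Roughly (assuming any feature map has already been applied to q and k):

\mathrm{num}_t = \sum_{s \le t} (q_t^{\top} k_s)\, v_s, \qquad \mathrm{out}_t = \frac{\mathrm{num}_t}{\sum_{s \le t} q_t^{\top} k_s + \epsilon}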

Installation

git clone https://github.com/fattorib/flashy_linear_attention
cd flashy_linear_attention
pip install .

Note: As of January 2024, Triton 2.2.0 offers better performance than Triton 2.1.0, which comes pre-installed with PyTorch 2.1.x. After installing flashlinear, you can manually upgrade Triton with:

pip install -U triton==2.2.0

Pip might complain that PyTorch doesn't support Triton 2.2.0, but it does.

Use

A full linear attention implementation would include denominator scaling. There aren't huge speed/memory gains to be had from fusing this into the kernel (and it also makes the fused backward pass more complicated), so we leave it to PyTorch:

import torch

from flashlinear import LinearAttention

def triton_linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, eps: float = 1e-5):
    # The fused Triton kernel returns the unnormalized numerator.
    qkv_out = LinearAttention.apply(q, k, v)
    # Denominator: running sum of q . k over the causal prefix, computed in PyTorch.
    denom = (q * k.cumsum(-2)).sum(dim=-1, keepdim=True) + eps
    return qkv_out / denom
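
For example, with hypothetical (batch, heads, seq_len, head_dim) tensors (the shapes, dtype, and device below are illustrative guesses, not requirements stated by the repo):

import torch

# Illustrative shapes only; replace with your model's dimensions.
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

out = triton_linear_attention(q, k, v)  # same shape as v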

If you want to perform linear attention where the queries and keys have a larger dimensionality than the values, you can use LinearAttentionSmallVHD:

import torch

from flashlinear import LinearAttentionSmallVHD

def triton_linear_attention_small_vhd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, eps: float = 1e-5):
    # Same as above, but the kernel allows q/k head dimensions larger than v's.
    qkv_out = LinearAttentionSmallVHD.apply(q, k, v)
    denom = (q * k.cumsum(-2)).sum(dim=-1, keepdim=True) + eps
    return qkv_out / denom
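
Here the query/key head dimension is larger than the value head dimension; the shapes below are purely illustrative:

import torch

# q/k head dim (64) larger than v head dim (16) -- values chosen for illustration only.
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 16, device="cuda", dtype=torch.float16)

out = triton_linear_attention_small_vhd(q, k, v)  # shape matches v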

Benchmarks

Compared against naive linear attention wrapped with torch.compile, and against CausalDotProduct from fast-transformers.
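
For reference, a naive baseline of the kind benchmarked here looks roughly like this (a sketch, not necessarily the exact code used to produce the numbers below):

import torch

def naive_linear_attention(q, k, v, eps=1e-5):
    # Materializes the full (seq, seq) causal score matrix: O(n^2) memory.
    scores = torch.einsum("...qd,...kd->...qk", q, k).tril()
    num = torch.einsum("...qk,...kd->...qd", scores, v)
    denom = scores.sum(dim=-1, keepdim=True) + eps
    return num / denom

naive_compiled = torch.compile(naive_linear_attention)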

Forward: [benchmark plot]

Forward+Backward: [benchmark plot]

Further optimization and tuning are needed to reach optimal performance, but the kernels are fast enough for now.

Tests

pytest flashlinear/test_*

Todos

  • Use triton.autotune instead of hand-picked configs (see the sketch below).
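
A minimal sketch of what this could look like; the config values and kernel signature are illustrative, not the actual kernels in this repo:

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64}, num_warps=4),
        triton.Config({"BLOCK_M": 128}, num_warps=8),
    ],
    key=["seq_len", "head_dim"],  # re-tune whenever these change
)
@triton.jit
def _linear_attention_fwd(Q, K, V, Out, seq_len, head_dim, BLOCK_M: tl.constexpr):
    ...  # kernel body unchanged; BLOCK_M is now picked by the autotuner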

References

@misc{katharopoulos2020transformers,
      title={Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention}, 
      author={Angelos Katharopoulos and Apoorv Vyas and Nikolaos Pappas and François Fleuret},
      year={2020},
      eprint={2006.16236},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{dao2022flashattention,
      title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, 
      author={Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher Ré},
      year={2022},
      eprint={2205.14135},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{buckman2024,
      title={Linear Transformers Are Faster After All},
      author={Jacob Buckman and Carles Gelada},
      year={2024},
      publisher={Manifest AI}
}
