FlashAttention (IPU)

Poplar implementation of FlashAttention for IPU

Quickstart

# Tested on Poplar SDK 3.3.0+7857, Ubuntu 20.04, Python 3.8, torch 2.0.1
python -m pip install git+ssh://git@github.com/graphcore-research/flash-attention-ipu.git

Demo

nanoGPT example

Usage

import torch

from flash_attention_ipu import flash_attention_qkv_packed

# For user-controlled chunking on IPU
class ChunkedAttention(torch.nn.Module):
  def __init__(self, num_chunks_q, num_chunks_kv):
    super().__init__()
    self.num_chunks_q = num_chunks_q
    self.num_chunks_kv = num_chunks_kv

  def forward(self, qkv):
    # Flatten batch and head dims: qkv -> (3, batch * heads, seq_len, head_dim)
    return flash_attention_qkv_packed(
      qkv.reshape(3, -1, *qkv.shape[-2:]),
      num_chunks_q=self.num_chunks_q,
      num_chunks_kv=self.num_chunks_kv
    )
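
A minimal, hypothetical invocation (shapes and chunk counts are illustrative; the packed (3, batch, heads, seq_len, head_dim) layout is assumed from the reshape above):

qkv = torch.randn(3, 4, 8, 2048, 64)  # assumed packed layout
attn = ChunkedAttention(num_chunks_q=4, num_chunks_kv=4)
out = attn(qkv)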

# For automated chunking on IPU
import flash_attention_ipu.auto
import torch.nn.functional as F

# flash_attention_ipu.auto overrides F.scaled_dot_product_attention
class SDPAttention(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, q, k, v):
    return F.scaled_dot_product_attention(
      q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
    )
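
With flash_attention_ipu.auto imported, the module above (and any existing code that calls F.scaled_dot_product_attention) dispatches to the IPU kernel without further changes. A hypothetical call, with illustrative shapes:

q, k, v = (torch.randn(8, 2048, 64) for _ in range(3))
out = SDPAttention()(q, k, v)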

Development

git clone git@github.com:graphcore-research/flash-attention-ipu.git
cd flash-attention-ipu
make

# Optional: run unit tests
./build/tests

Background

FlashAttention solves a key bottleneck of dot-product attention on GPUs: reading and writing the attention matrix between high-bandwidth memory (HBM) and on-chip SRAM.

This becomes particularly problematic when training large models on long sequences, as backpropagation requires either (1) storing the attention matrix for each layer, which can quickly exceed available GPU memory, or (2) recomputing the attention matrix, which dominates FLOPs once sequences are long enough.
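
To make option (1) concrete, a rough calculation under assumed shapes (batch size, head count, and sequence length are illustrative):

# Illustrative: one fp16 attention matrix of shape (batch, heads, seq, seq)
batch, heads, seq, bytes_per_elem = 8, 16, 8192, 2
gib = batch * heads * seq * seq * bytes_per_elem / 2**30
print(f"{gib:.0f} GiB per layer")  # 16 GiB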

FlashAttention overcomes this bottleneck by chunking the query, key, and value tensors along the sequence dimension and computing the attention matrix in chunks using an online softmax algorithm.

For small enough chunks, the attention matrix never needs to be written to or read from HBM, and all intermediate tensors fit in SRAM. The result is an algorithm for dot-product attention that is both memory-efficient and IO-efficient.
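
For illustration only, here is a minimal PyTorch sketch of the online-softmax recurrence (a reference implementation, not this repository's Poplar kernel; the function name and chunk size are invented for the example):

import torch

def chunked_attention(q, k, v, chunk=128):
  # q, k, v: (seq_len, head_dim). Processes keys/values chunk by chunk,
  # keeping a running row-max m and row-sum l so the softmax normalisation
  # can be corrected as each new chunk arrives.
  scale = q.shape[-1] ** -0.5
  out = torch.zeros_like(q)
  m = torch.full((q.shape[0], 1), float("-inf"))  # running row max
  l = torch.zeros(q.shape[0], 1)                  # running softmax denominator
  for start in range(0, k.shape[0], chunk):
    kc, vc = k[start:start + chunk], v[start:start + chunk]
    s = (q @ kc.T) * scale                        # scores for this chunk
    m_new = torch.maximum(m, s.max(-1, keepdim=True).values)
    p = torch.exp(s - m_new)                      # chunk softmax numerator
    alpha = torch.exp(m - m_new)                  # rescale factor for old state
    l = alpha * l + p.sum(-1, keepdim=True)
    out = alpha * out + p @ vc
    m = m_new
  return out / l

Chunking the queries as well (as flash_attention_qkv_packed does via num_chunks_q) runs this loop independently for each query chunk.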

What relevance does this have for IPUs where the whole model is in SRAM?

A Graphcore IPU chip has about 900 MB of SRAM split across 1472 tiles. Each tile can communicate with all the others via an all-to-all exchange, which makes it possible for operations such as large matrix multiplications to get close to the IPU's peak throughput (around 350 TFLOPS).

Assuming every tile is being used for computation when the entire model fits in IPU SRAM, performance is limited by how much data needs to be exchanged across tiles. As such, a good FlashAttention implementation for the IPU minimises both memory usage and data exchange across tiles.

This initial implementation aims to keep memory consumption low using dynamic slicing and outlined graphs, and to keep exchange reasonably small using off-the-shelf tile mappings of tensors. We leave more customised tile mappings and further memory-usage improvements to future releases.

License

Copyright (c) 2023 Graphcore Ltd. Licensed under the MIT License.

Includes the third-party Catch2 submodule for unit testing, licensed under the BSL-1.0 License.
