lucidrains / mixture-of-attention

Some personal experiments around routing tokens to different autoregressive attention, akin to mixture-of-experts

License: MIT License

Python 100.00%
artificial-intelligence attention-mechanisms deep-learning mixture-of-experts routed-attention

Mixture-of-Attention

I learned from a researcher friend that this was tried, unsuccessfully, in Switch Transformers, but I'll give it a go anyway, bringing in some lessons from recent papers like CoLT5.

In my opinion, the CoLT5 paper basically demonstrates mixture of attention already for 2 experts. This just has to be generalized to more than 2 experts, and to the autoregressive case. The local attention branch would just be a special case of one expert with fixed routing. Since attention cost grows quadratically with the number of tokens attended, routing only half the tokens would lead to a savings of 4x. If I can show even ~4 experts being better than a single attention, that should be a win.
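As a rough sanity check on the 4x figure, here is a back-of-the-envelope sketch. It only counts query-key dot products, and it assumes both queries and key / values are routed down to half the sequence. The helper name `attention_pairs` is made up for illustration, and router overhead and the projections are ignored.

```python
def attention_pairs(num_queries: int, num_keys: int) -> int:
    """Count the query-key dot products in one dense attention call."""
    return num_queries * num_keys

seq_len = 1024

# full attention: every token attends to every token
full = attention_pairs(seq_len, seq_len)

# routed attention (assumption): only half the tokens are kept
# on both the query side and the key / value side
routed = attention_pairs(seq_len // 2, seq_len // 2)

savings = full / routed
print(f"full: {full}, routed: {routed}, savings: {savings:.1f}x")
```

The savings per expert shrink if only one side (queries or key / values) is routed, and they have to be weighed against running several experts.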

Appreciation

  • Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research

  • einops for making tensor manipulation fun and easy

Install

$ pip install mixture-of-attention

Usage

import torch
from mixture_of_attention import MixtureOfAttention

mixture_of_attn = MixtureOfAttention(
    dim = 512,
    dim_context = 256,
    num_routed_queries = 16,
    num_routed_key_values = 16,
    num_experts = 2,
    dim_head = 64,
    heads = 8
)

x = torch.randn(1, 1024, 512)
mask = torch.ones((1, 1024)).bool()

context = torch.randn(1, 512, 256)
context_mask = torch.ones((1, 512)).bool()

mixture_of_attn(x, context = context, mask = mask) # (1, 1024, 512)
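To give a feel for what "routing" a fixed number of tokens means, here is a minimal sketch of plain top-k selection, assuming a single learned routing vector per expert. This is not the router used inside `MixtureOfAttention`; the function `route_topk` and the variable `routing_token` are hypothetical names for illustration only.

```python
import torch

def route_topk(tokens, routing_token, num_routed):
    """
    Pick the `num_routed` highest scoring tokens per batch element.
    tokens:        (batch, seq, dim)
    routing_token: (dim,)  hypothetical learned routing vector
    """
    # score each token by its dot product with the routing vector
    scores = tokens @ routing_token                        # (batch, seq)

    # keep the highest scoring positions
    top_scores, top_idx = scores.topk(num_routed, dim=-1)  # (batch, num_routed)

    # gather the selected token embeddings
    gather_idx = top_idx.unsqueeze(-1).expand(-1, -1, tokens.shape[-1])
    routed = tokens.gather(1, gather_idx)                  # (batch, num_routed, dim)

    return routed, top_idx, top_scores

torch.manual_seed(0)

x = torch.randn(1, 1024, 512)
routing_token = torch.randn(512)

routed, idx, scores = route_topk(x, routing_token, num_routed=16)
print(routed.shape)  # torch.Size([1, 16, 512])
```

Each expert would then attend only over its own small set of routed queries and key / values, which is where the savings come from. Hard top-k is not differentiable with respect to the selection itself, so a real router needs some way of passing gradients to the scores.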

Autoregressive flavor

import torch
from mixture_of_attention import MixtureOfAutoregressiveAttention

mixture_of_attn = MixtureOfAutoregressiveAttention(
    dim = 512,
    local_attn_window_size = 64,       # local attention window size
    routed_window_size = None,         # will be set to the same as local_attn_window_size if None. ideally less than or equal to local attention window size for full receptive field
    num_routed_queries = 12,
    num_routed_key_values = 12,
    num_experts = 2,
    dim_head = 64,
    heads = 8
)

x = torch.randn(1, 1023, 512)

out = mixture_of_attn(x) # (1, 1023, 512)

Todo

  • allow for local attention to be automatically included, either for grouped attention, or use LocalMHA from local-attention repository in parallel, weighted properly

  • make it work for autoregressive

  • try dynamic routing tokens, using projection of masked mean-pooled queries

  • try out https://arxiv.org/abs/2210.05144

Citations

@inproceedings{Ainslie2023CoLT5FL,
    title   = {CoLT5: Faster Long-Range Transformers with Conditional Computation},
    author  = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Onta{\~n}{\'o}n and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},
    year    = {2023}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Wright2015CoordinateDA,
    title   = {Coordinate descent algorithms},
    author  = {Stephen J. Wright},
    journal = {Mathematical Programming},
    year    = {2015},
    volume  = {151},
    pages   = {3-34}
}
@article{Schmitzer2016StabilizedSS,
    title   = {Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems},
    author  = {Bernhard Schmitzer},
    journal = {ArXiv},
    year    = {2016},
    volume  = {abs/1610.06519}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}
