mamba.py's Introduction

mamba.py 🐍 : a simple parallel scan implementation

A straightforward implementation of Mamba in PyTorch with a simple parallel scan implementation, offering a major speedup over a sequential implementation. It combines readability with good performance.

[Figure: speed comparison]

This repo contains simple, readable code implementing the Mamba architecture in pure PyTorch. Its primary goal is educational.

The repo is organized as follows :

  • pscan.py : a PyTorch implementation of Blelloch's parallel scan
  • mamba.py : the Mamba model, as described in the paper. It is numerically equivalent (initialization, forward and backward pass).
  • mamba_lm.py : encapsulates a Mamba model in order to use it as a language model
  • 📁 docs : a folder containing annotated explanations about the code, focusing on the parallel scan
  • 📁 examples : two examples of how to use the Mamba model.
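To make the idea behind pscan.py concrete, here is a minimal pure-Python sketch. It is not the repo's code. It assumes the recurrence being scanned has the form h[t] = a[t] * h[t-1] + b[t] with h[-1] = 0, and that the sequence length is a power of two. The names `combine`, `sequential_scan` and `blelloch_scan` are hypothetical. Each time step is treated as an affine map, and composing affine maps is associative, which is what allows an up-sweep/down-sweep (Blelloch-style) scan to replace the sequential loop.

```python
def combine(left, right):
    """Compose two affine steps h -> a*h + b; `left` is applied first."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(a, b):
    """Reference loop: h[t] = a[t] * h[t-1] + b[t], starting from h[-1] = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def blelloch_scan(a, b):
    """Same result as sequential_scan, computed with an up-sweep and a down-sweep."""
    n = len(a)
    assert n > 0 and n & (n - 1) == 0, "this sketch assumes a power-of-two length"
    elems = list(zip(a, b))
    tree = list(elems)

    # Up-sweep: build partial compositions over blocks of size 2, 4, 8, ...
    step = 1
    while step < n:
        for i in range(2 * step - 1, n, 2 * step):
            tree[i] = combine(tree[i - step], tree[i])
        step *= 2

    # Down-sweep: push prefixes back down, starting from the identity map.
    tree[n - 1] = (1.0, 0.0)
    step = n // 2
    while step >= 1:
        for i in range(2 * step - 1, n, 2 * step):
            left_block = tree[i - step]
            tree[i - step] = tree[i]
            # Order matters: the incoming prefix comes first, then the left block.
            tree[i] = combine(tree[i], left_block)
        step //= 2

    # tree[t] is now the composition of all steps before t (exclusive scan).
    # Applying step t on top and reading the offset gives h[t].
    return [combine(prefix, elem)[1] for prefix, elem in zip(tree, elems)]
```

Within each level, the inner for-loop touches disjoint positions, so on parallel hardware a whole level could run at once. That leaves on the order of log2(L) dependent levels instead of L sequential steps.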

Usage

The most basic usage is to use the Mamba object (mamba.py), which implements a simple Mamba model given a configuration. No embedding, no head : input is (B, L, D) and output is (B, L, D) as well.

import torch
from mamba import Mamba, MambaConfig

config = MambaConfig(d_model=16, n_layers=2)
model = Mamba(config)

B, L, D = 2, 64, 16
x = torch.randn(B, L, D)
y = model(x)

assert y.shape == x.shape

The class MambaLM (mamba_lm.py) builds on the Mamba object and offers a classic API for language models. It can be used as follows :

import torch
from mamba_lm import MambaLM, MambaLMConfig

config = MambaLMConfig(d_model=16, n_layers=4, vocab_size=32000)
model = MambaLM(config)

x = torch.randint(high=32000, size=(16, 64)) # (B, L) token ids
logits = model(x) # (B, L, vocab_size)

It simply encapsulates a Mamba object with an embedding layer, a final normalization and a language modeling head.

Examples

There are two basic examples available:

  • example_llm.ipynb : load a Mamba model with pretrained weights (from 130M to 2.8B from HuggingFace)
  • example_e2e_training.ipynb : an end-to-end training example where a Mamba model is employed as a world model for a simple 3x3 grid game (training is not complete; the model should be larger).

Sources and where to learn more

TODOs

  • docs
  • a step function, used for (auto-regressive) inference.
  • unfold the for-loops in pscan.py to achieve better performance (see François Fleuret's pscan), although this will sacrifice a bit of readability
  • write a reverse parallel scan specifically for the backward pass. (For now, we have to flip the array before and after the scan).
  • use torch.compile(). As far as I have tested, it doesn't work for now: it seems unhappy with the custom PScan autograd function. This needs investigation (see PR#1).
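The flip workaround mentioned in the TODO list can be illustrated in plain Python. This is a sketch under assumptions, not the code in pscan.py: it assumes the backward pass needs a right-to-left recurrence of the form r[t] = c[t] * r[t+1] + d[t] with r[T] = 0, and the function names are hypothetical. Flipping the inputs, running an ordinary left-to-right scan, and flipping the result gives the same values as a dedicated reverse scan, at the cost of the extra flips.

```python
def forward_scan(a, b):
    """Left-to-right recurrence h[t] = a[t] * h[t-1] + b[t], with h[-1] = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def reverse_scan_via_flip(c, d):
    """Right-to-left recurrence computed by flip -> forward scan -> flip."""
    flipped_result = forward_scan(c[::-1], d[::-1])
    return flipped_result[::-1]


def reverse_scan_direct(c, d):
    """What a dedicated reverse scan would compute, written as a plain loop."""
    r, out = 0.0, [0.0] * len(c)
    for t in range(len(c) - 1, -1, -1):
        r = c[t] * r + d[t]
        out[t] = r
    return out
```

A dedicated reverse scan would walk the indices in the opposite direction inside the scan itself, which avoids materializing the flipped copies.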

Contributors

  • alxndrtl
