
lucidrains / x-transformers


A simple but complete full-attention transformer with a set of promising experimental features from various papers

License: MIT License

Python 100.00%
artificial-intelligence deep-learning attention-mechanism transformers


x-transformers


A concise but fully-featured transformer, complete with a set of promising experimental features from various papers.

Install

$ pip install x-transformers

Usage

Full encoder / decoder

import torch
from x_transformers import XTransformer

model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    tie_token_emb = True      # tie embeddings of encoder and decoder
)

src = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 256, (1, 1024))

loss = model(src, tgt, mask = src_mask) # scalar loss
loss.backward()

Decoder-only (GPT-like)

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()

model(x) # (1, 1024, 20000)

GPT3 would be approximately the following (but you wouldn't be able to run it anyway)

gpt3 = TransformerWrapper(
    num_tokens = 50000,
    max_seq_len = 2048,
    attn_layers = Decoder(
        dim = 12288,
        depth = 96,
        heads = 96,
        attn_dim_head = 128
    )
).cuda()

Encoder-only (BERT-like)

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()
mask = torch.ones_like(x).bool()

model(x, mask = mask) # (1, 1024, 20000)

State of the art image classification (SimpleViT)

import torch
from x_transformers import ViTransformerWrapper, Encoder

model = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
    )
)

img = torch.randn(1, 3, 256, 256)
model(img) # (1, 1000)

Image -> caption

import torch
from x_transformers import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder

encoder = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True
    )
)

img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))

encoded = encoder(img, return_embeddings = True)
decoder(caption, context = encoded) # (1, 1024, 20000)

PaLI, state of the art language-vision model

import torch
from x_transformers import ViTransformerWrapper, XTransformer, Encoder

# PaLI is composed of
# 1. vision transformer (ViTransformerWrapper) +
# 2. encoder-decoder transformer (XTransformer)

vit = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

pali = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024
)

# training data

img = torch.randn(1, 3, 256, 256)               # images
prompt = torch.randint(0, 256, (1, 1024))       # prompt
prompt_mask = torch.ones(1, 1024).bool()        # prompt text mask
output_text = torch.randint(0, 256, (1, 1024))  # target output text

# train

img_embeds = vit(
    img,
    return_embeddings = True
)

loss = pali(
    prompt,
    output_text,
    mask = prompt_mask,
    src_prepend_embeds = img_embeds             # will prepend image embeddings to encoder text embeddings before attention
)

loss.backward()

# do the above for many steps on a 17B parameter model
# attention is all you need

Dropouts

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    emb_dropout = 0.1,         # dropout after embedding
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        layer_dropout = 0.1,   # stochastic depth - dropout entire layer
        attn_dropout = 0.1,    # dropout post-attention
        ff_dropout = 0.1       # feedforward dropout
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)

Features

Flash Attention

What originally started off as a short paper from Markus Rabe culminated in a practical fused attention CUDA kernel, named Flash Attention, by Tri Dao.

The technique processes the attention matrix in tiles, only keeping track of the running softmax and exponentiated weighted sums. By recomputing on the backwards pass in a tiled fashion, one is able to keep the memory linear with respect to sequence length. This allows a lot of recent models to be able to reach for longer context lengths without worrying about the memory bottleneck.

Other engineering decisions made by Tri Dao led to its enormous success, namely minimizing HBM accesses so that both the forwards and backwards outperform naive attention. In other words, flash attention is not only more memory efficient, but faster as well, making it a necessity for training transformers.

MetaAI has recently added the ability to use Tri Dao's CUDA kernel through the scaled_dot_product_attention function in PyTorch 2.0. (They also have a mem_efficient attention, which is identical to the flash attention design, except that the tiles are traversed differently.)

Llama was trained using Flash Attention. The only reason to avoid it is if you require operating on the attention matrix (dynamic positional bias, talking heads, residual attention).

You can use it in this repository by setting attn_flash to True and enjoy the immediate memory savings and increase in speed.

ex.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True # just set this to True if you have PyTorch 2.0 installed
    )
)
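To make the tiling idea concrete, here is a minimal sketch in plain Python (no CUDA, no PyTorch) of the running softmax for a single query vector. It is not the Flash Attention kernel itself: the function names and the tile_size parameter are illustrative assumptions. It only shows that tracking a running maximum, a running denominator and a running weighted sum over key / value tiles reproduces ordinary softmax attention.

```python
import math

def naive_attention(q, keys, values):
    # reference: compute every score at once, then softmax and weight the values
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def tiled_attention(q, keys, values, tile_size=2):
    # running statistics: largest score so far, softmax denominator, weighted value sum
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), tile_size):
        k_tile = keys[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        # rescale everything accumulated under the previous maximum
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Only one tile of scores is alive at any time, which is where the linear memory comes from. The real kernel applies the same bookkeeping to blocks of queries and recomputes the tiles on the backwards pass.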

Augmenting Self-attention with Persistent Memory

https://arxiv.org/abs/1907.01470

Proposes adding learned memory key / values prior to attention. They were able to remove feedforwards altogether and attain similar performance to the original transformers. I have found that keeping the feedforwards and adding the memory key / values leads to even better performance.

from x_transformers import Decoder, Encoder

enc = Encoder(
    dim = 512,
    depth = 6,
    heads = 8,
    attn_num_mem_kv = 16 # 16 memory key / values
)

Memory Transformers

https://arxiv.org/abs/2006.11527

Proposes adding learned tokens, akin to CLS tokens, named memory tokens, that are passed through the attention layers alongside the input tokens. This setting is compatible with both encoder and decoder training.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20, # 20 memory tokens
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

Update: MetaAI researchers have found that adding memory tokens (they call them register tokens) alleviates outliers (which are now suspected to be a pathology of attention networks being unable to attend to nothing).

Transformers Without Tears

https://arxiv.org/abs/1910.05895

They experimented with alternatives to layer normalization and found one that is both effective and simpler. Researchers have shared with me that this leads to faster convergence.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_scalenorm = True # set to True to use for all layers
    )
)

You can also use the l2 normalized embeddings proposed as part of fixnorm. I have found it leads to improved convergence when paired with small initialization (proposed by BlinkDL). The small initialization will be taken care of as long as l2norm_embed is set to True.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    l2norm_embed = True,    # set this to True for l2 normalized embedding + small init
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

Along the same lines of l2 normalized embeddings, Huggingface's 175B parameter BLOOM also places a layernorm right after the embeddings and just before the tokens enter the attention layers. This was corroborated by Yandex's 100B parameter YaLM to stabilize training.

It is recommended you have either l2norm_embed or post_emb_norm set to True but not both, as they probably serve the same purpose.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    post_emb_norm = True,    # set this to True to layernorm summed token + pos embeddings
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

Root Mean Square Layer Normalization

https://arxiv.org/abs/1910.07467

The authors propose to replace layer normalization with a simpler alternative, without mean centering and the learned bias. An investigative paper found this to be the best performing normalization variant. It was also used in Deepmind's latest large language models, Retro and Gopher.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_rmsnorm = True # set to true to use for all layers
    )
)

July 2023: A linear attention paper has experiments showing that removing the learned multiplicative gamma led to no performance degradation. This simplifies the RMS normalization to a satisfying l2norm(x) * sqrt(dim).

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_simple_rmsnorm = True # set to true to use for all layers
    )
)
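As a quick check of the l2norm(x) * sqrt(dim) identity, here is a minimal plain Python sketch. The function names and the eps value are illustrative assumptions and are not taken from the library.

```python
import math

def rmsnorm(x, eps=1e-8):
    # divide by the root mean square of the features: no mean centering, no bias, no gamma
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / (rms + eps) for v in x]

def l2norm_times_sqrt_dim(x, eps=1e-8):
    # unit-normalize the vector, then scale it by sqrt(dim)
    norm = math.sqrt(sum(v * v for v in x))
    scale = math.sqrt(len(x))
    return [v / (norm + eps) * scale for v in x]
```

The two agree up to the eps term because the root mean square is just the l2 norm divided by sqrt(dim).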

GLU Variants Improve Transformer

https://arxiv.org/abs/2002.05202

A Noam Shazeer paper that explores gating in the feedforward, finding that simple gating with GELU leads to significant improvements. This variant also showed up in the latest mT5 architecture. You should always turn this on (I may eventually turn it on by default).

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_glu = True # set to true to use for all feedforwards
    )
)

The PaLM language model also chose to use the Swish GLU variant. You can turn this on by setting two flags.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_swish = True, # set this to True
        ff_glu = True    # set to true to use for all feedforwards
    )
)
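For readers who want to see what the gating amounts to, here is a minimal plain Python sketch of a gated feedforward on a single token vector, following the form (activation(x W_gate) * (x W_value)) W_out from the GLU variants paper. The helper names and weight layout are illustrative assumptions, biases are omitted, and this is not the library's implementation.

```python
import math

def gelu(x):
    # exact GELU using the error function
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def swish(x):
    # swish / SiLU: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def matvec(x, w):
    # x has length d_in, w is a list of d_in rows with d_out columns each
    return [sum(xi * row[j] for xi, row in zip(x, w)) for j in range(len(w[0]))]

def glu_feedforward(x, w_gate, w_value, w_out, activation=gelu):
    gate = [activation(g) for g in matvec(x, w_gate)]   # the gating branch
    value = matvec(x, w_value)                          # the linear branch
    hidden = [g * v for g, v in zip(gate, value)]       # elementwise gating
    return matvec(hidden, w_out)
```

Passing activation=swish gives the Swish GLU variant mentioned above. The default corresponds to gating with GELU.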

No Bias in Feedforward

Starting with PaLM, there began a trend to remove biases from the transformer altogether. Boris Dayma has run a number of experiments that showed removing biases from feedforwards led to increased throughput without any loss of accuracy. This was corroborated by yet another paper investigating transformer architecture variants.

You can turn off the feedforward bias as follows

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_no_bias = True  # set this to True
    )
)

ReLU²

https://arxiv.org/abs/2109.08668

This paper used neural architecture search and found an activation, Relu Squared, that is both simpler and performs better than GELU, in the autoregressive language model setting. I have confirmed this in my independent experiments. However, if one were using the GLU variant from above, GELU still performs better. Pending further corroboration.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_relu_squared = True
    )
)

Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection

https://arxiv.org/abs/1912.11637

This paper proposes an efficient way to sparsify attention by zeroing all dot-product query/key values not within the top k values. They show that this cheap method is as effective as other, more expensive operations like sparsemax or entmax15. This technique comes with the cost of an extra hyperparameter (the top k values to keep). The paper recommends a value of k = 8.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_sparse_topk = 8 # keep only the top 8 values before attention (softmax)
    )
)
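A minimal plain Python sketch of the top-k selection for one row of attention scores follows. The function name is an illustrative assumption, and ties at the threshold are all kept, which may differ from how the library breaks ties.

```python
import math

def sparse_topk_softmax(scores, k):
    # keep the k largest scores; mask everything else to -inf so it gets zero weight
    if k < len(scores):
        threshold = sorted(scores, reverse=True)[k - 1]
        scores = [s if s >= threshold else float("-inf") for s in scores]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

When k is at least the number of keys, this falls back to the ordinary dense softmax.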

Talking-Heads Attention

https://arxiv.org/abs/2003.02436

A Noam Shazeer paper that proposes mixing information between heads pre and post attention (softmax). This comes with the cost of extra memory and compute.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_talking_heads = True  # turn on information exchange between attention heads
    )
)

One Write-Head Is All You Need

https://arxiv.org/abs/1911.02150

Yet another Noam Shazeer paper (he's a legend) that proposes to only have one head for the key / values, but multi-headed queries. This paper was largely ignored for a while, but recently validated at scale in AlphaCode as well as PaLM. It has the property of being memory efficient when decoding extremely large language models. You can use it with one keyword argument as shown below.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_one_kv_head = True
    )
)

This has been further generalized in a recent paper to allow for groups of query heads to attend to a single key / value head. You can use this by specifying attn_kv_heads.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_kv_heads = 2 # say you want 4 query heads to attend to 1 key / value head
    )
)
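To illustrate the grouping, here is a small sketch of how query heads could map onto shared key / value heads. It assumes consecutive query heads share one key / value head. The function name is hypothetical and the library's internal ordering may differ.

```python
def kv_head_for_query_head(query_head, heads, kv_heads):
    # each key / value head serves a contiguous group of heads // kv_heads query heads
    assert heads % kv_heads == 0, "heads must be divisible by kv_heads"
    group_size = heads // kv_heads
    return query_head // group_size

# with heads = 8 and attn_kv_heads = 2, four query heads share each key / value head
mapping = [kv_head_for_query_head(h, heads=8, kv_heads=2) for h in range(8)]
```

Setting kv_heads to 1 recovers the single key / value head case, and setting it equal to heads recovers ordinary multi-head attention.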

Attention on Attention for Image Captioning

https://arxiv.org/abs/1908.06954

This paper proposes to add a gated linear unit at the end of the attention layer, further gated by the original queries. Although this is not widely used outside of visual question / answering, I suspect it should lead to improvements after seeing the success of the feedforward GLU variant.

Update: After some experimentation, I found this variant actually performs worse, but if it were to be modified to not concatenate the queries before gating, it performs much better. That is what we will be using in this repository.

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_on_attn = True  # gate output of attention layer, by queries
    )
)

Intra-attention Gating on Values

Alphafold2 had a peculiar variant of attention where they gate the aggregated values with the input, presumably to have the block have more control over the update.

A quick test shows a small but noticeable improvement, on about the same order as attention on attention.

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_gate_values = True  # gate aggregated values with the input
    )
)

Improving Transformer Models by Reordering their Sublayers

https://arxiv.org/abs/1911.03864

This paper proposes to break from the normal fixed pattern of alternating attention and feedforwards, and instead to have blocks of only attention at the beginning followed by blocks of feedforwards at the end. This was further corroborated by a paper by Nvidia that reduces the number of attention layers to one third of the feedforwards without loss in performance.

The amount of interleaving is controlled by a "sandwich coefficient", which they found to be optimal at a value of 6.

You can experiment with this feature as shown below

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        sandwich_coef = 6  # interleave attention and feedforwards with sandwich coefficient of 6
    )
)
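The ordering can be sketched as a tuple of layer types, with 'a' for attention and 'f' for feedforward. This follows the pattern described in the paper (k attention layers, then depth - k interleaved pairs, then k feedforward layers). The function name is an illustrative assumption and the library's internal construction is not confirmed here.

```python
def sandwich_layout(depth, sandwich_coef):
    # sandwich_coef attention layers first, the usual interleaving in the middle,
    # and sandwich_coef feedforward layers at the end
    assert 0 <= sandwich_coef <= depth
    middle = ("a", "f") * (depth - sandwich_coef)
    return ("a",) * sandwich_coef + middle + ("f",) * sandwich_coef
```

Under this pattern, a depth of 6 with a coefficient of 6 gives six attention layers followed by six feedforward layers, while a coefficient of 0 gives the standard alternating transformer.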

Understanding and Improving Transformer From a Multi-Particle Dynamic System Point of View

https://arxiv.org/abs/1906.02762

The authors propose to view the success of transformers from a dynamical systems point of view, and then propose an improvement based on the mathematics of that POV. Specifically, they propose to place the attention layer in between two feedforward layers. This was adopted by a paper using transformers for speech recognition, the Conformer.

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        macaron = True  # use macaron configuration
    )
)

T5's Simplified Relative Positional Encoding

https://arxiv.org/abs/1910.10683

T5 is one of the most successful encoder / decoder transformer architectures trained to date. They invented a new simplified relative positional encoding based on learned bias values that are added to the attention matrix pre-softmax. This bias is shared and injected into each attention layer. I have decided to include this because it offers a cheap way to have relative positional encoding (superior to absolute positional), and I have read papers that suggest having positional encoding added to each layer (vs only before the first) is beneficial.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rel_pos_bias = True  # adds relative positional bias to all attention layers, a la T5
    )
)

Residual Attention

https://arxiv.org/abs/2012.11747

This paper from Google proposes residualizing the pre-attention scores across all layers. At the cost of no extra parameters, they show improvement on top of regular attention networks. If you turn on this setting, be aware that the best results in the paper used post-normalization, in which case a learning rate warmup will be needed. The authors also reported that they could use a higher learning rate and get even better gains in the same number of steps. (In the paper they use 2e-4 vs 1e-4 for the vanilla transformer.)

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        pre_norm = False,       # in the paper, residual attention had best results with post-layernorm
        residual_attn = True    # add residual attention
    )
)

I also tried residualizing cross attention and may have noticed an improvement in convergence. You can try it by setting the cross_residual_attn keyword to True.

import torch
from x_transformers import XTransformer

model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    dec_cross_residual_attn = True     # residualize cross attention
)

Transformer-XL recurrence

You can also do Transformer-XL recurrence, by simply passing in a max_mem_len in the TransformerWrapper class, and then making sure your Decoder has rel_pos_bias (or rotary_pos_emb) set to True.

Then, you can retrieve the memories at each step with the return_mems keyword and pass it to the next iteration.

import torch
from x_transformers import TransformerWrapper, Decoder

model_xl = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    max_mem_len = 2048,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rel_pos_bias = True
    )
)

seg1 = torch.randint(0, 20000, (1, 512))
seg2 = torch.randint(0, 20000, (1, 512))
seg3 = torch.randint(0, 20000, (1, 512))

logits1, mems1  = model_xl(seg1, return_mems = True)
logits2, mems2  = model_xl(seg2, mems = mems1, return_mems = True)
logits3, mems3  = model_xl(seg3, mems = mems2, return_mems = True)

Setting up the logic for training and sampling from Transformer-XL can be a bit overwhelming. This repository offers a simple wrapper that should make this easy, with the XLAutoregressiveWrapper.

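from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper

# pass in the above model_xl

xl_wrapper = XLAutoregressiveWrapper(model_xl)

seg = torch.randint(0, 20000, (1, 4096))  # sequence exceeding max length, automatically segmented and memory managed (kept on the same device as model_xl)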

loss = xl_wrapper(seg)
loss.backward()

# then, after much training

prime = seg[:, :1024]   # if prime exceeds max length, memory will be caught up before generating

generated = xl_wrapper.generate(prime, 4096)  # (1, 4096)

Enhanced recurrence

This paper proposes a simple technique to enhance the range of Transformer-XL. They simply route the memory segment of a layer to the layer below it, for the next recurrent step. You can enable this by setting shift_mem_down = 1. You can also shift down an arbitrary number of layers by setting this value to > 1.

import torch
from x_transformers import TransformerWrapper, Decoder

model_xl = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    max_mem_len = 2048,
    shift_mem_down = 1,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True
    )
)

seg1 = torch.randint(0, 20000, (1, 512))
seg2 = torch.randint(0, 20000, (1, 512))
seg3 = torch.randint(0, 20000, (1, 512))

logits1, mems1  = model_xl(seg1, return_mems = True)
logits2, mems2  = model_xl(seg2, mems = mems1, return_mems = True) # mems1 of layer N are automatically routed to the layer N-1

Gated residual

https://arxiv.org/abs/1910.06764

The authors propose gating the residual connections in the transformer network and demonstrate increased stability and performance for Transformer-XL in a variety of reinforcement learning tasks.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    max_mem_len = 2048,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 16,
        gate_residual = True
    )
)

Rotary Positional Embeddings

Developed in Beijing, this new technique quickly gained interest in the NLP circles. In short, it allows you to endow the transformer with relative positional embeddings at the cost of no learned parameters. You apply a rotary operation to the queries and keys prior to their dot product in attention. The big idea is injecting positions through rotations.

It is highly recommended that you have this turned on whenever you are working on an ordered sequence.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True  # turns on rotary positional embeddings
    )
)
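The relative-position property can be checked with a minimal plain Python sketch. It rotates adjacent feature pairs by a position-dependent angle. The pairing layout, the base of 10000 and the function names are illustrative assumptions and may not match the library's internals.

```python
import math

def rotate(vec, position, base=10000.0):
    # rotate each consecutive pair of features by an angle proportional to the position
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        theta = position * base ** (-i / dim)
        x, y = vec[i], vec[i + 1]
        out.append(x * math.cos(theta) - y * math.sin(theta))
        out.append(x * math.sin(theta) + y * math.cos(theta))
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.8, 2.0]
k = [1.5, 0.4, -0.7, 0.9]

# both pairs of positions are 2 apart, so the attention score should be the same
score_near = dot(rotate(q, 3), rotate(k, 1))
score_far = dot(rotate(q, 10), rotate(k, 8))
```

Because the rotation of the query and the rotation of the key compose into a single rotation by the position difference, the score depends only on how far apart the two tokens are, and no learned parameters are involved.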

Update (12/2022): Rotary embedding has since been hugely successful, widely adopted in many large language models, including the largest in the world, PaLM. However, it has been uncovered in the ALiBi paper that rotary embeddings do not extrapolate well to longer sequence lengths. This was recently addressed in a Microsoft research paper. They propose a way to unobtrusively add the same decay as in ALiBi, and found that this resolves the extrapolation problem. You can use it in this repository by setting rotary_xpos = True. Like ALiBi, it would enforce the attention to be local. You can set the receptive field with the rotary_xpos_scale_base value, which defaults to 512.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_xpos = True   # modified rotary to extrapolate well beyond length at which it was trained
    )
)

Dynamic Positional Bias

This technique has its roots in the field of vision transformers, where researchers are trying to have relative positions generalize to larger resolutions (without having to retrain the entire network). It was used in two recent papers, CrossFormer, as well as SwinV2.

Charles Foster first tried this for a language model, and found that it works. Later on Eric Engelhart produced experimental results that show the same type of extrapolation holds, even for 1d sequences.

Eric trained at sequence lengths of 128, and showed that it generalized well to 1024. In addition, he showed that linear positions were better than log (used in SwinV2), for language.

[Figures: linear distances; log distances; negative control (sinusoidal)]

More of Eric's experimental results can be found here

You can use this type of relative position if you wish to train at smaller sequence lengths and have it generalize to longer ones, for both autoregressive and bidirectional models.

Update: First place RNA folding using dynamic positional bias

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        dynamic_pos_bias = True,                # set this to True
        dynamic_pos_bias_log_distance = False   # whether to use log distance, as in SwinV2
    )
)

ALiBi Positional Embedding

This paper proposes to simply apply a static linear bias to the attention matrix. The authors show this is not only effective as a relative positional encoding, but also allows the attention net to extrapolate to greater sequence lengths than what it was trained on, for autoregressive language models.

This repository also offers a bidirectional variant (nonsymmetric), proposed by the authors here. However, this is untested. If you need bidirectional length extrapolation, the safest option would be Dynamic Positional Bias.

Update: It may be that ALiBi enforces a strong local attention across the heads, and may hinder it from attending at distances greater than 1k. To avoid any issues with global message passing, I've decided to introduce another hyperparameter alibi_num_heads, so one can specify fewer heads for the ALiBi bias.

Update: There are reports that ALiBi outperforms Rotary embeddings for pretraining and downstream fine-tuning.

Update: A new paper shows that using no positional embedding can length extrapolate even better than explicit ones.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        alibi_pos_bias = True, # turns on ALiBi positional embedding
        alibi_num_heads = 4    # only use ALiBi for 4 out of the 8 heads, so other 4 heads can still attend far distances
    )
)
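The static bias can be sketched in a few lines of plain Python. The slopes follow the geometric sequence from the ALiBi paper for a power-of-two number of heads. The function names are illustrative assumptions, and the symmetric abs(i - j) distance is used here for simplicity (in the causal case only j <= i matters).

```python
def alibi_slopes(num_heads):
    # geometric sequence starting at 2^(-8 / num_heads); assumes num_heads is a power of two
    start = 2 ** (-8 / num_heads)
    return [start ** (h + 1) for h in range(num_heads)]

def alibi_bias(slope, seq_len):
    # bias added to the pre-softmax score of query i and key j: further away means more negative
    return [[-slope * abs(i - j) for j in range(seq_len)] for i in range(seq_len)]

slopes = alibi_slopes(8)          # one slope per head
bias = alibi_bias(slopes[0], 4)   # bias matrix for the head with the steepest slope
```

Heads with a small slope are penalized less for distance, which is why restricting the bias to a subset of heads (alibi_num_heads) leaves the remaining heads free to attend globally.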

Shifted Tokens

An independent researcher has found that shifting a subset of the feature dimension along the sequence dimension by 1 token helps with convergence (Time-mixing). I have tested this for the autoregressive case and can confirm that it leads to greatly improved convergence. This also lines up with the results of some papers in the vision domain.

To use it, simply set shift_tokens = 1 (or to whatever number of shifts you desire). The feature dimension will be divided by shift_tokens + 1 and then each chunk will be shifted by an amount in [0, shift_tokens] respectively.

Update: New experiments by @sdtblck suggest this may only work for character-level training.

Update: After more experiments, it seems that in the context of BPE encoding, with rotary turned on, there is no benefit to shifting. For character-level training, shifting may still improve a tiny bit.

Update: When using BPE encoded tokens, it seems that a shift of 2 will bottleneck the dimensions (divided by 5). It is recommended you always do a shift of 1, unless you are working at the character level.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        shift_tokens = 1
    )
)
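A minimal plain Python sketch of the shifting for the causal case follows. The function name is hypothetical, it assumes the feature dimension divides evenly into shift_tokens + 1 chunks, and it pads with zeros where no earlier token exists.

```python
def shift_features(seq, num_shifts):
    # seq: list of token feature vectors with shape (seq_len, dim)
    dim = len(seq[0])
    segments = num_shifts + 1
    assert dim % segments == 0, "this sketch assumes the feature dimension divides evenly"
    chunk = dim // segments
    shifted = []
    for t in range(len(seq)):
        row = []
        for d in range(dim):
            offset = d // chunk     # the chunk index doubles as the shift amount
            source = t - offset     # pull this feature from an earlier token
            row.append(seq[source][d] if source >= 0 else 0.0)
        shifted.append(row)
    return shifted

tokens = [[1, 10], [2, 20], [3, 30]]
mixed = shift_features(tokens, num_shifts=1)
```

With one shift, the first half of each token's features stays in place and the second half comes from the previous token, so every position sees a blend of itself and its predecessor before attention or the feedforward.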

If you want finer control over how much is shifted per block (whether attention or feedforward), simply pass in a tuple of size that is equal to the number of layers.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        shift_tokens = (1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0) # 12 blocks, attention and feedforward alternating, with progressively less shifting
    )
)

Sandwich Norm

This technique first made an appearance in the CogView paper, a Chinese version of the famous text-to-image transformer DALL-E. They propose, when using pre-layernorm, to add an extra layernorm to all the branch outputs. I have found this to be very effective for a number of projects, when facing instability during training.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        sandwich_norm = True # set this to True
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
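For intuition, the sandwich arrangement can be sketched as a residual block with a layernorm on both sides of the branch. This is a simplified illustration, not the library's code; the class name `SandwichResidual` is hypothetical and a single linear layer stands in for the attention or feedforward branch.

```python
import torch
from torch import nn

class SandwichResidual(nn.Module):
    def __init__(self, dim, branch):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim)    # the usual pre-layernorm
        self.branch = branch                 # attention or feedforward
        self.post_norm = nn.LayerNorm(dim)   # the extra norm on the branch output

    def forward(self, x):
        return x + self.post_norm(self.branch(self.pre_norm(x)))

block = SandwichResidual(512, nn.Linear(512, 512))

x = torch.randn(1, 16, 512)
out = block(x)
```

Because the branch output is normalized before it is added back, its magnitude stays bounded no matter what the branch produces, which is one plausible reason it helps with instability.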

ResiDual

This Microsoft paper proposes yet another normalization configuration, combining both pre and post layernorm. They claim this hybridization reduces representation collapse (known to be an issue with pre-layernorm with increasing depth), while maintaining stability and reducing vanishing gradients (issues with post-layernorm). Initial experiments on my end show it to work no worse than pre-layernorm or sandwich norm. More study needed by the public to see if this is actually a winning technique.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        resi_dual = True,               # set this to True
        resi_dual_scale = 0.1           # in appendix, they said on fp16 the prenorm residual is prone to overflow. they claim by scaling it at each layer by a factor, it would prevent the overflow, and keep results the same (as layernorms are invariant to scaling of the input)
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
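The following is a rough sketch of the dual residual scheme as I understand it from the paper, written in plain PyTorch with small MLPs standing in for the attention and feedforward branches. The class name `ResiDualStack` is hypothetical, and the library's actual implementation may differ in details. It keeps a post-layernorm stream that feeds each branch, plus an unnormalized stream that accumulates the (optionally scaled) branch outputs and is only normalized once at the end.

```python
import torch
from torch import nn

class ResiDualStack(nn.Module):
    def __init__(self, dim, depth, scale = 1.):
        super().__init__()
        self.scale = scale
        self.branches = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
            for _ in range(depth)
        ])
        self.post_norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(depth)])
        self.final_norm = nn.LayerNorm(dim)

    def forward(self, x):
        post_stream = x                 # post-layernorm stream, input to every branch
        dual_stream = x * self.scale    # unnormalized stream, scaled to avoid overflow

        for branch, norm in zip(self.branches, self.post_norms):
            out = branch(post_stream)
            post_stream = norm(post_stream + out)
            dual_stream = dual_stream + out * self.scale

        # the accumulated stream is normalized only once, at the very end
        return post_stream + self.final_norm(dual_stream)

torch.manual_seed(0)

full = ResiDualStack(dim = 16, depth = 3, scale = 1.)
scaled = ResiDualStack(dim = 16, depth = 3, scale = 0.1)
scaled.load_state_dict(full.state_dict())   # same weights, different scale

x = torch.randn(2, 8, 16)
out_full, out_scaled = full(x), scaled(x)
```

The two outputs agree up to the small effect of the layernorm epsilon, which illustrates why scaling the accumulated residual should not change the results.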

Normformer

This paper uncovers an issue with pre-norm transformers where gradients are mismatched between the early and later layers. They propose 4 changes, of which I will be offering 3.

The first change is to offer per head scaling after aggregating the values in attention. My experiments show a slight improvement in convergence.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_head_scale = True  # set this to True
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
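As a minimal sketch of what per head scaling amounts to (assuming one learned scalar per head, initialized to 1, with the hypothetical module name `HeadScale`), the aggregated values are simply multiplied by a broadcast parameter before the heads are merged:

```python
import torch
from torch import nn

class HeadScale(nn.Module):
    def __init__(self, heads):
        super().__init__()
        # one learned scalar per head, broadcast over batch, sequence and features
        self.scale = nn.Parameter(torch.ones(1, heads, 1, 1))

    def forward(self, attn_out):
        # attn_out: (batch, heads, seq_len, dim_head), the aggregated values
        return attn_out * self.scale

head_scale = HeadScale(heads = 8)

attn_out = torch.randn(1, 8, 1024, 64)
scaled = head_scale(attn_out)
```

At initialization this is the identity, so it only adds `heads` parameters per attention layer.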

The second change is an extra layernorm right after the activation in the feedforward. I have also verified a slight improvement, at the cost of extra compute.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_post_act_ln = True # set this to True
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
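Structurally, the change is a single extra layernorm between the activation and the output projection. The sketch below assumes a plain GELU feedforward with an expansion factor of 4; the class name `FeedForwardPostActNorm` is hypothetical and this is not the library's own module.

```python
import torch
from torch import nn

class FeedForwardPostActNorm(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        inner_dim = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.LayerNorm(inner_dim),   # the extra norm, right after the activation
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x):
        return self.net(x)

ff = FeedForwardPostActNorm(dim = 512)

x = torch.randn(1, 16, 512)
out = ff(x)
```

The norm operates on the expanded inner dimension, which is where the extra compute comes from.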

For the residual scaling, you simply have to set scale_residual = True. I have noticed slight improvements, but occasional instability as well, so use with caution.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        scale_residual = True # set this to True
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
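A minimal sketch of residual scaling, assuming a learned per-dimension scale on the skip connection that is initialized to 1 (the class name `ScaledResidual` is hypothetical):

```python
import torch
from torch import nn

class ScaledResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learned scale applied to the residual (skip) path only
        self.residual_scale = nn.Parameter(torch.ones(dim))

    def forward(self, branch_out, residual):
        return branch_out + residual * self.residual_scale

add_residual = ScaledResidual(dim = 512)

residual = torch.randn(1, 16, 512)
branch_out = torch.randn(1, 16, 512)
out = add_residual(branch_out, residual)
```

At initialization this reduces to an ordinary residual connection; training can then reweight how much of the skip path each feature carries forward.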

The last change is a layernorm right after the outwards projection in attention. This is actually identical to the sandwich norm proposed by the CogView paper, so you can use this by simply setting sandwich_norm = True, although it would also add it to the feedforward layer.

Cosine Sim Attention

This paper proposes to l2 normalize the queries and keys along the head dimension before the dot product (cosine similarity), with the additional change of the scale being learned rather than static. The normalization prevents the attention operation from overflowing, and removes any need for numerical stability measures prior to softmax. Both are perennial problems when training transformers.

This was validated at scale recently by the training of a 3B parameter vision transformer. The SwinV2 paper also proposes to change the pre-layernorm to a post-layernorm for further stability.

I have validated that this works just as well as dot product attention in an autoregressive setting, if one were to initialize the temperature as proposed in the QK-norm paper (as a function of the sequence length).

This flavor of attention also has a connection to sparse distributed memory. [youtube talk]

Update: I have discovered a way to remove the learned temperature altogether, by grouping the feature dimension and doing l2-normalization on each group. This allows the queries and keys to have a similarity that is upper bounded by the number of groups. A group size of 8 or 16 was sufficient in my tests. Decided to name this technique "Grouped QK Normalization". The drawback is that I believe an attention head dimension of 32 is too small to use this tactic (a dimension often used in vision).

Update 2: Tero Karras has successfully used cosine sim attention in a new paper.

You can use it as follows

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_qk_norm = True,       # set this to True
        attn_qk_norm_groups = 8    # number of groups in the feature dimension for l2norm, similarity scores will be bounded between [-group, group]. determines how sharp the attention can be
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
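The bound on the similarity scores can be checked with a small sketch of grouped l2 normalization. The helper name `grouped_l2norm` is chosen for illustration and the code is a simplified stand-in rather than the library's implementation.

```python
import torch
import torch.nn.functional as F

def grouped_l2norm(t, groups = 1):
    # t: (..., dim), with dim divisible by groups
    *lead, dim = t.shape
    assert dim % groups == 0, 'feature dimension must be divisible by groups'

    # l2 normalize each group of features independently
    t = t.reshape(*lead, groups, dim // groups)
    t = F.normalize(t, p = 2, dim = -1)
    return t.reshape(*lead, dim)

groups = 8

q = grouped_l2norm(torch.randn(1, 8, 32, 64), groups = groups)
k = grouped_l2norm(torch.randn(1, 8, 32, 64), groups = groups)

# each group contributes a cosine similarity in [-1, 1], so the sum lies in [-groups, groups]
sim = torch.einsum('bhid,bhjd->bhij', q, k)
```

More groups widen the range of the logits, which is what lets the softmax become sharper without a learned temperature.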

Another update: Simply scaling the cosine similarity (group of 1) with a fixed constant (10) may work too

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_qk_norm = True,       # set to True
        attn_qk_norm_scale = 10    # new scale on the similarity, with groups of 1
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
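For reference, cosine similarity attention with a fixed scale can be sketched in a few lines. This is an illustrative, unoptimized version under the assumption of a single group and a causal mask; the function name `cosine_sim_attention` is hypothetical.

```python
import torch
import torch.nn.functional as F

def cosine_sim_attention(q, k, v, scale = 10., causal = True):
    # q, k, v: (batch, heads, seq_len, dim_head)
    q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)

    # similarities are cosines in [-1, 1], multiplied by a fixed constant
    sim = torch.einsum('bhid,bhjd->bhij', q, k) * scale

    if causal:
        i, j = sim.shape[-2:]
        causal_mask = torch.ones(i, j, device = sim.device).triu(j - i + 1).bool()
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)
    return torch.einsum('bhij,bhjd->bhid', attn, v)

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

out = cosine_sim_attention(q, k, v)
```

Since the logits are bounded by the scale, they cannot overflow, whatever the magnitude of the raw queries and keys.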

QK RMSNorm

Update: Google Brain has proven out something similar to cosine sim attention in a 22B parameter model. In their papers, they have analysis showing that the normalization resulted in not only extra stability, but also better results in the end (due to less need to adjust learning rate when increasing parameter count).

We are nearing the point of wiping out a source of transformer training instability with one simple intervention, in my opinion. The only slight difference in the paper is that they still have a learned scale across the feature dimension (per use of rmsnorm). Not sure how critical this is, but just to make sure we don't miss anything, I will include this here. You can use this by setting attn_qk_norm_dim_scale = True on the attention layers, in addition to attn_qk_norm = True.

Update: Counterpoint from Tim Dettmers

Update 2: Counter to Tim's assertion that outliers are needed, and potentially even some solutions

Update 3: Used by 8B parameter LLM successfully

Update 4: a MetaAI group found that they can alleviate outliers by adding register tokens, also known as memory tokens from earlier literature (Burtsev et al). Perhaps what should be tried next is see if qk norm can be improved in the presence of memory tokens.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_qk_norm = True,
        attn_qk_norm_dim_scale = True # set this to True, in addition to `attn_qk_norm = True`
    )
)

x = torch.randint(0, 256, (1, 1024))
model(x)
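A rough sketch of what a learned scale across the feature dimension adds on top of plain l2 normalization is given below. The module name `QKNormWithDimScale` and the parameter shapes are assumptions for illustration; the library's internals may be laid out differently.

```python
import torch
import torch.nn.functional as F
from torch import nn

class QKNormWithDimScale(nn.Module):
    def __init__(self, heads, dim_head):
        super().__init__()
        # learned per-head, per-feature scales, initialized to 1
        self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
        self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))

    def forward(self, q, k):
        # q, k: (batch, heads, seq_len, dim_head)
        q = F.normalize(q, dim = -1) * self.q_scale
        k = F.normalize(k, dim = -1) * self.k_scale
        return q, k

qk_norm = QKNormWithDimScale(heads = 8, dim_head = 64)

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)

q_normed, k_normed = qk_norm(q, k)
```

At initialization the queries and keys are unit vectors, exactly as in cosine sim attention; the scales then let the network re-weight individual features.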

Turning off absolute positional embedding

A number of papers have hinted that causal transformers (Decoder) can learn absolute positions in the absence of added embeddings of any sort. This was recently thoroughly investigated here. You can turn off the absolute positional embedding by setting use_abs_pos_emb = False in the TransformerWrapper

Given PaLM, the trend going forward may be to forgo absolute positional embedding (again, for causal transformers only), and add relative positional embeddings with RoPE, ALiBi, etc.

Update: This paper shows that in the absence of any engineered absolute or relative positional embeddings, decoders can generate implicit positions, and even length generalize better than solutions of the past. They were unaware of dynamic positional bias, however.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    use_abs_pos_emb = False,   # set this to False
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)

Forgetful Causal Mask

This paper shows convincing results that one can combine masking (from masked language modeling) with autoregressive training, leading to significantly better results.

You can use this by setting the mask_prob on the AutoregressiveWrapper class

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

model = AutoregressiveWrapper(
    model,
    mask_prob = 0.15  # in paper, they use 15%, same as BERT
).cuda()

# mock data

x = torch.randint(0, 20000, (1, 1024)).cuda()

# derive cross entropy loss, masking all taken care of

loss = model(x)
loss.backward()
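As a sketch of the masking step, one can draw a random subset of past tokens to hide from attention on every training step, while always keeping the first token visible. The helper name `forgetful_key_mask` is hypothetical, and keeping the first token is an assumption made here so that every position has at least one key to attend to.

```python
import torch

def forgetful_key_mask(batch, seq_len, mask_prob = 0.15):
    # returns a boolean mask, True = token may be attended to, False = hidden
    num_mask = min(int(seq_len * mask_prob), seq_len - 1)

    keep = torch.ones(batch, seq_len, dtype = torch.bool)
    if num_mask == 0:
        return keep

    rand = torch.rand(batch, seq_len)
    rand[:, 0] = -1.   # lowest possible score, so the first token is never selected

    # hide the num_mask tokens with the highest random scores in each row
    indices = rand.topk(num_mask, dim = -1).indices
    hide = torch.zeros(batch, seq_len).scatter(1, indices, 1.).bool()
    return keep & ~hide

keep = forgetful_key_mask(batch = 2, seq_len = 1024, mask_prob = 0.15)
```

This mask would be combined with the usual causal mask, so the model still predicts every next token but with a randomly thinned view of its past.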

Miscellaneous

Cross Attention

import torch
from x_transformers import Encoder, CrossAttender

enc = Encoder(dim = 512, depth = 6)
model = CrossAttender(dim = 512, depth = 6)

nodes = torch.randn(1, 1, 512)
node_masks = torch.ones(1, 1).bool()

neighbors = torch.randn(1, 5, 512)
neighbor_masks = torch.ones(1, 5).bool()

encoded_neighbors = enc(neighbors, mask = neighbor_masks)
model(nodes, context = encoded_neighbors, mask = node_masks, context_mask = neighbor_masks) # (1, 1, 512)

Continuous Embeddings

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

x = torch.randn((1, 1024, 32))

model(x) # (1, 1024, 100)

You can also easily train a transformer that accepts continuous values autoregressively, using the same scheme that was applied successfully in this paper

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder
from x_transformers import ContinuousAutoregressiveWrapper

model = ContinuousTransformerWrapper(
    dim_in = 777,
    dim_out = 777,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

# wrap it with the continuous autoregressive wrapper

model = ContinuousAutoregressiveWrapper(model)

# mock data

x = torch.randn((1, 1024, 777))
mask = torch.ones(1, 1024).bool()

# train on a lot of data above

loss = model(x, mask = mask)
loss.backward()

# then generate

start_emb = torch.randn(1, 777)
generated = model.generate(start_emb, 17) # (17, 777)
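Conceptually, continuous autoregressive training regresses each next vector from the ones before it. The sketch below shows one plausible form of that objective using mean squared error; the function name `continuous_autoregressive_loss` is hypothetical, a single linear layer stands in for the transformer, and the wrapper in the library may handle the loss and mask differently.

```python
import torch
import torch.nn.functional as F
from torch import nn

def continuous_autoregressive_loss(net, x, mask = None):
    # x: (batch, seq_len, dim), predict position t + 1 from positions up to t
    inp, target = x[:, :-1], x[:, 1:]
    pred = net(inp)

    loss = F.mse_loss(pred, target, reduction = 'none')

    if mask is not None:
        # only count positions whose target is a real (unpadded) element
        loss = loss[mask[:, 1:]]

    return loss.mean()

net = nn.Linear(777, 777)   # stand-in for a causal model mapping (b, n, 777) -> (b, n, 777)

x = torch.randn(1, 1024, 777)
mask = torch.ones(1, 1024).bool()

loss = continuous_autoregressive_loss(net, x, mask)
loss.backward()
```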

xVal - Continuous and Discrete

This is promising work that resulted from the collaboration across many institutes (collectively known as Polymathic AI). They found that by offering a continuously scaled number token to the transformer, the transformer was able to generalize arithmetic and forecasting tasks better than the alternative encoding schemes.

This is corroborated by some prior work

import torch

from x_transformers import (
    Decoder,
    XValTransformerWrapper,
    XValAutoregressiveWrapper
)

model = XValTransformerWrapper(
    num_tokens = 4,
    numerical_token_id = 3,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

# wrap it with the xval autoregressive wrapper

model = XValAutoregressiveWrapper(model)

# mock data

ids = torch.randint(0, 4, (1, 777))
nums = torch.randn(1, 777)
mask = torch.ones(1, 777).bool()

# train on a lot of data above

loss = model(ids, nums, mask = mask)
loss.backward()

# then generate

start_ids = torch.randint(0, 4, (1, 1))
start_nums = torch.randn(1, 1)

ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)

# (1, 17), (1, 17), (1, 17)

# discrete, continuous, mask for discrete / continuous
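The core of the encoding, as I understand it, is that a single dedicated number token has its embedding multiplied by the numeric value it stands for. A minimal sketch of that embedding step follows; the class name `XValEmbedding` is hypothetical and this is not the wrapper's actual code.

```python
import torch
from torch import nn

class XValEmbedding(nn.Module):
    def __init__(self, num_tokens, dim, numerical_token_id):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.numerical_token_id = numerical_token_id

    def forward(self, ids, nums):
        # ids: (batch, seq_len) token ids, nums: (batch, seq_len) numeric values
        emb = self.token_emb(ids)

        # scale only the number token's embedding, leave all other tokens untouched
        is_number = ids == self.numerical_token_id
        scale = torch.where(is_number, nums, torch.ones_like(nums))

        return emb * scale.unsqueeze(-1)

embed = XValEmbedding(num_tokens = 4, dim = 512, numerical_token_id = 3)

ids = torch.tensor([[0, 3, 3]])
nums = torch.tensor([[9., 2., -0.5]])

out = embed(ids, nums)   # (1, 3, 512)
```

Values attached to non-number tokens (the 9 above) are ignored, while the two number tokens share one embedding direction and differ only in magnitude and sign.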

Citations

@misc{vaswani2017attention,
    title   = {Attention Is All You Need},
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{DBLP:journals/corr/abs-1907-01470,
    author    = {Sainbayar Sukhbaatar and
               Edouard Grave and
               Guillaume Lample and
               Herv{\'{e}} J{\'{e}}gou and
               Armand Joulin},
    title     = {Augmenting Self-attention with Persistent Memory},
    journal   = {CoRR},
    volume    = {abs/1907.01470},
    year      = {2019},
    url       = {http://arxiv.org/abs/1907.01470}
}
@article{1910.05895,
    author  = {Toan Q. Nguyen and Julian Salazar},
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    year    = {2019},
    eprint  = {arXiv:1910.05895},
    doi     = {10.5281/zenodo.3525484},
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}    
}
@inproceedings{Zoph2022STMoEDS,
    title   = {ST-MoE: Designing Stable and Transferable Sparse Expert Models},
    author  = {Barret Zoph and Irwan Bello and Sameer Kumar and Nan Du and Yanping Huang and Jeff Dean and Noam M. Shazeer and William Fedus},
    year    = {2022}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}
@misc{burtsev2020memory,
    title   = {Memory Transformer}, 
    author  = {Mikhail S. Burtsev and Grigory V. Sapunov},
    year    = {2020},
    eprint  = {2006.11527},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{zhao2019explicit,
    title   = {Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection}, 
    author  = {Guangxiang Zhao and Junyang Lin and Zhiyuan Zhang and Xuancheng Ren and Qi Su and Xu Sun},
    year    = {2019},
    eprint  = {1912.11637},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{correia2019adaptively,
    title   = {Adaptively Sparse Transformers},
    author  = {Gonçalo M. Correia and Vlad Niculae and André F. T. Martins},
    year    = {2019},
    eprint  = {1909.00015},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{shazeer2020talkingheads,
    title   = {Talking-Heads Attention}, 
    author  = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    year    = {2020},
    eprint  = {2003.02436},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{press2020improving,
    title   = {Improving Transformer Models by Reordering their Sublayers}, 
    author  = {Ofir Press and Noah A. Smith and Omer Levy},
    year    = {2020},
    eprint  = {1911.03864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{lu2019understanding,
    title   = {Understanding and Improving Transformer From a Multi-Particle Dynamic System Point of View}, 
    author  = {Yiping Lu and Zhuohan Li and Di He and Zhiqing Sun and Bin Dong and Tao Qin and Liwei Wang and Tie-Yan Liu},
    year    = {2019},
    eprint  = {1906.02762},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{ke2020rethinking,
    title     = {Rethinking Positional Encoding in Language Pre-training},
    author    = {Guolin Ke and Di He and Tie-Yan Liu},
    year      = {2020},
    eprint    = {2006.15595},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{dosovitskiy2020image,
    title   = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author  = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
    year    = {2020},
    eprint  = {2010.11929},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{huang2019attention,
    title   = {Attention on Attention for Image Captioning},
    author  = {Lun Huang and Wenmin Wang and Jie Chen and Xiao-Yong Wei},
    year    = {2019},
    eprint  = {1908.06954},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{raffel2020exploring,
    title   = {Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer}, 
    author  = {Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},
    year    = {2020},
    eprint  = {1910.10683},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{martins-etal-2020-sparse,
    title   = "Sparse Text Generation",
    author  = "Martins, Pedro Henrique  and
        Marinho, Zita  and
        Martins, Andr{\'e} F. T.",
    booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
    month   = nov,
    year    = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url     = "https://www.aclweb.org/anthology/2020.emnlp-main.348"
}
@misc{he2020realformer,
    title   = {RealFormer: Transformer Likes Residual Attention},
    author  = {Ruining He and Anirudh Ravula and Bhargav Kanagal and Joshua Ainslie},
    year    = {2020},
    eprint  = {2012.11747},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{carion2020endtoend,
    title   = {End-to-End Object Detection with Transformers},
    author  = {Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
    year    = {2020},
    eprint  = {2005.12872},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{press2021ALiBi,
    title   = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
    year    = {2021},
    url     = {https://ofir.io/train_short_test_long.pdf}
}
@misc{parisotto2019stabilizing,
    title     = {Stabilizing Transformers for Reinforcement Learning},
    author    = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
    year      = {2019},
    eprint    = {1910.06764},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{narang2021transformer,
    title       = {Do Transformer Modifications Transfer Across Implementations and Applications?},
    author      = {Sharan Narang and Hyung Won Chung and Yi Tay and William Fedus and Thibault Fevry and Michael Matena and Karishma Malkan and Noah Fiedel and Noam Shazeer and Zhenzhong Lan and Yanqi Zhou and Wei Li and Nan Ding and Jake Marcus and Adam Roberts and Colin Raffel},
    year        = {2021},
    eprint      = {2102.11972},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{zhang2019root,
    title   = {Root Mean Square Layer Normalization},
    author  = {Biao Zhang and Rico Sennrich},
    year    = {2019},
    eprint  = {1910.07467},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{Qin2023ScalingTT,
    title   = {Scaling TransNormer to 175 Billion Parameters},
    author  = {Zhen Qin and Dong Li and Weigao Sun and Weixuan Sun and Xuyang Shen and Xiaodong Han and Yunshen Wei and Baohong Lv and Fei Yuan and Xiao Luo and Y. Qiao and Yiran Zhong},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:260203124}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{Chen2023ExtendingCW,
    title   = {Extending Context Window of Large Language Models via Positional Interpolation},
    author  = {Shouyuan Chen and Sherman Wong and Liangjian Chen and Yuandong Tian},
    year    = {2023}
}
@inproceedings{Sun2022ALT,
  title     = {A Length-Extrapolatable Transformer},
  author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
  year      = {2022}
}
@Article{AlphaFold2021,
    author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
    journal = {Nature},
    title   = {Highly accurate protein structure prediction with {AlphaFold}},
    year    = {2021},
    doi     = {10.1038/s41586-021-03819-2},
    note    = {(Accelerated article preview)},
}
@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = {aug},
    year         = {2021},
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}
@misc{csordás2021devil,
    title   = {The Devil is in the Detail: Simple Tricks Improve Systematic Generalization of Transformers},
    author  = {Róbert Csordás and Kazuki Irie and Jürgen Schmidhuber},
    year    = {2021},
    eprint  = {2108.12284},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{so2021primer,
    title   = {Primer: Searching for Efficient Transformers for Language Modeling}, 
    author  = {David R. So and Wojciech Mańke and Hanxiao Liu and Zihang Dai and Noam Shazeer and Quoc V. Le},
    year    = {2021},
    eprint  = {2109.08668},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer}, 
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}
@misc{henry2020querykey,
    title   = {Query-Key Normalization for Transformers},
    author  = {Alex Henry and Prudhvi Raj Dachapally and Shubham Pawar and Yuxuan Chen},
    year    = {2020},
    eprint  = {2010.04245},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Haviv2022TransformerLM,
    title   = {Transformer Language Models without Positional Encodings Still Learn Positional Information},
    author  = {Adi Haviv and Ori Ram and Ofir Press and Peter Izsak and Omer Levy},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.16634}
}
@article{chowdhery2022PaLM,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Chowdhery, Aakanksha et al},
    year    = {2022}
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam M. Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@article{Ainslie2023GQATG,
    title   = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
    author  = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebr{\'o}n and Sumit K. Sanghai},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.13245},
    url     = {https://api.semanticscholar.org/CorpusID:258833177}
}
@misc{schlag2020enhancing,
    title   = {Enhancing the Transformer with explicit relational encoding for math problem solving},
    author  = {Imanol Schlag and Paul Smolensky and Roland Fernandez and Nebojsa Jojic and J{\"u}rgen Schmidhuber and Jianfeng Gao},
    year    = {2020},
    url     = {https://openreview.net/forum?id=B1xfElrKPr}
}
@article{Liu2022FCMFC,
    title   = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners},
    author  = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.13432}
}
@inproceedings{Huang2016DeepNW,
    title   = {Deep Networks with Stochastic Depth},
    author  = {Gao Huang and Yu Sun and Zhuang Liu and Daniel Sedra and Kilian Q. Weinberger},
    booktitle = {European Conference on Computer Vision},
    year    = {2016}
}
@inproceedings{Hua2022TransformerQI,
    title   = {Transformer Quality in Linear Time},
    author  = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
    booktitle = {International Conference on Machine Learning},
    year    = {2022}
}
@article{Chang2022MaskGITMG,
    title   = {MaskGIT: Masked Generative Image Transformer},
    author  = {Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11305-11315}
}
@article{Lezama2022ImprovedMI,
    title   = {Improved Masked Image Generation with Token-Critic},
    author  = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.04439}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Xie2023ResiDualTW,
  title     = {ResiDual: Transformer with Dual Residual Connections},
  author    = {Shufang Xie and Huishuai Zhang and Junliang Guo and Xu Tan and Jiang Bian and Hany Hassan Awadalla and Arul Menezes and Tao Qin and Rui Yan},
  journal   = {ArXiv},
  year      = {2023},
  volume    = {abs/2304.14802}
}
@inproceedings{Dehghani2023ScalingVT,
    title   = {Scaling Vision Transformers to 22 Billion Parameters},
    author  = {Mostafa Dehghani and Josip Djolonga and Basil Mustafa and Piotr Padlewski and Jonathan Heek and Justin Gilmer and Andreas Steiner and Mathilde Caron and Robert Geirhos and Ibrahim M. Alabdulmohsin and Rodolphe Jenatton and Lucas Beyer and Michael Tschannen and Anurag Arnab and Xiao Wang and Carlos Riquelme and Matthias Minderer and Joan Puigcerver and Utku Evci and Manoj Kumar and Sjoerd van Steenkiste and Gamaleldin F. Elsayed and Aravindh Mahendran and Fisher Yu and Avital Oliver and Fantine Huot and Jasmijn Bastings and Mark Collier and Alexey A. Gritsenko and Vighnesh Birodkar and Cristina Nader Vasconcelos and Yi Tay and Thomas Mensink and Alexander Kolesnikov and Filip Paveti{\'c} and Dustin Tran and Thomas Kipf and Mario Lu{\v{c}}i{\'c} and Xiaohua Zhai and Daniel Keysers and Jeremiah Harmsen and Neil Houlsby},
    year    = {2023}
}
@article{Beyer2022BetterPV,
    title   = {Better plain ViT baselines for ImageNet-1k},
    author  = {Lucas Beyer and Xiaohua Zhai and Alexander Kolesnikov},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.01580}
}
@article{Kazemnejad2023TheIO,
    title   = {The Impact of Positional Encoding on Length Generalization in Transformers},
    author  = {Amirhossein Kazemnejad and Inkit Padhi and Karthikeyan Natesan Ramamurthy and Payel Das and Siva Reddy},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.19466}
}
@misc{bloc97-2023,
    title   = {NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.},
    author  = {/u/bloc97},
    url     = {https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/}
}
@article{Lan2019ALBERTAL,
    title   = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author  = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1909.11942},
    url     = {https://api.semanticscholar.org/CorpusID:202888986}
}
@inproceedings{Li2022ContrastiveDO,
    title   = {Contrastive Decoding: Open-ended Text Generation as Optimization},
    author  = {Xiang Lisa Li and Ari Holtzman and Daniel Fried and Percy Liang and Jason Eisner and Tatsunori Hashimoto and Luke Zettlemoyer and Mike Lewis},
    booktitle = {Annual Meeting of the Association for Computational Linguistics},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:253157949}
}
@inproceedings{OBrien2023ContrastiveDI,
    title   = {Contrastive Decoding Improves Reasoning in Large Language Models},
    author  = {Sean O'Brien and Mike Lewis},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:261884427}
}
@inproceedings{Darcet2023VisionTN,
    title   = {Vision Transformers Need Registers},
    author  = {Timoth{\'e}e Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263134283}
}
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}
@inproceedings{Golkar2023xValAC,
    title   = {xVal: A Continuous Number Encoding for Large Language Models},
    author  = {Siavash Golkar and Mariel Pettee and Michael Eickenberg and Alberto Bietti and M. Cranmer and G{\'e}raud Krawezik and Francois Lanusse and Michael McCabe and Ruben Ohana and Liam Parker and Bruno R{\'e}galdo-Saint Blancard and Tiberiu Teşileanu and Kyunghyun Cho and Shirley Ho},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263622222}
}
@article{Rafailov2023DirectPO,
    title   = {Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
    author  = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.18290},
    url     = {https://api.semanticscholar.org/CorpusID:258959321}
}

solve intelligence... then use that to solve everything else. - Demis Hassabis

x-transformers's People

Contributors

adrian-spataru, anthonyzhou-1, apage43, chogamy, cifkao, frederikfab, gurvindersingh, hadaev8, ilya16, jbcdnr, jstjohn, kcarnold, lucidrains, ncoop57, notprime, pfeatherstone, ramesharvind, stas-sl, taemincho, tmphex, wangcongcong123

x-transformers's Issues

Any tips for speeding up generation?

Because of the autoregressive nature of Transformers, I know that they are fairly slow when generating new sequences from scratch. Do you have any tips or tricks for faster inference, or any plans to add some of the tricks that avoid full recomputation, like the ones used by Huggingface? https://huggingface.co/blog/accelerated-inference

Thank you very much for your amazing work!

`rotary_pos_emb = True` causes an exception to be raised when the model is pickled.

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

model = ContinuousTransformerWrapper(
    max_seq_len=128,
    attn_layers = Encoder(
        dim = 32,
        depth = 2,
        heads = 1,
        rotary_pos_emb = True # This line is the problem.
    )
)

with open("qwe.nnreg",'wb') as f: torch.save(model,f)
Traceback (most recent call last):
  File "c:/Users/Marko/Source/Repos/The Spiral Language/Spiral Compilation Tests/cython_experiments/ui_holdem8 (transformers)/script1.py", line 15, in <module>
    with open("qwe.nnreg",'wb') as f: torch.save(model,f)
  File "C:\Users\Marko\anaconda3\lib\site-packages\torch\serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "C:\Users\Marko\anaconda3\lib\site-packages\torch\serialization.py", line 484, in _save
    pickler.dump(obj)
AttributeError: Can't pickle local object 'always.<locals>.inner'

I'll have to skip using the rotary embedding until this is resolved. Without the highlighted line the pickling works fine.
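
The traceback names a function defined inside another function (`always.<locals>.inner`), and pickle cannot serialize such local objects because it looks functions up by their importable name. A minimal sketch of the failure and of a class-based replacement follows. The names `always_closure`, `Always`, and `can_pickle` are hypothetical and only illustrate the pattern; this is not the library's actual code or fix.

```python
import pickle


def always_closure(value):
    # Mirrors the pattern named in the traceback: a function defined inside
    # another function. Pickle cannot look such a local object up by name.
    def inner(*args, **kwargs):
        return value
    return inner


class Always:
    """Hypothetical stand-in: a class-based constant function."""

    def __init__(self, value):
        self.value = value

    def __call__(self, *args, **kwargs):
        return self.value


def can_pickle(obj):
    """Return True if pickle.dumps(obj) succeeds."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


closure_ok = can_pickle(always_closure(None))
print("closure picklable:", closure_ok)
# Instances of a class defined at module level are looked up by class name,
# so this is expected to succeed when the file is run as a script.
print("class instance picklable:", can_pickle(Always(None)))
```

A workaround that avoids pickling the module object at all is to save only the weights with `torch.save(model.state_dict(), f)` and rebuild the model from its constructor arguments when loading.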

How to make inference fast (by adding caching of key / values)

How can we implement the below caching technique in the code?

That would be awesome.
What I have tried, to speed up inference in my custom implementations of autoregressive self-attention, is caching the output of the self-attention at timestep T. Then, at timestep T+1, I pass the full keys/values but only the last element of the query sequence, take the output, and concatenate it with the cache. That way each query can attend to the full previous sequence, but we don't need to compute attention for all the previous queries when we only need the output at T+1.
It looks something like this:
[image: screenshot, 2021-03-12 17:01:32]
But I only achieved a 3x speedup 🤔

I actually needed to perform autoregressive inference on a very large dataset, and it was taking more than 1 day even with the above speedup. I am currently doing some weird custom stuff, keeping the Transformer attention layers but replacing the self-attention layers with LSTMs, which are way faster at generating sequences token by token, and with that I achieve the 10x speedup that I needed.

Originally posted by @pabloppp in #21 (comment)
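
A minimal sketch of the caching idea for a single attention layer, in plain PyTorch. It caches the projected keys and values (rather than layer outputs) and, at each step, attends from the newest token only; the result matches recomputing full causal attention. All names here (`full_causal_attention`, `cached_step`, the tuple-shaped `cache`) are assumptions for illustration, not the x-transformers API, and a multi-layer model would need one such cache per layer.

```python
import torch

torch.manual_seed(0)
dim = 16
to_q = torch.nn.Linear(dim, dim, bias=False)
to_k = torch.nn.Linear(dim, dim, bias=False)
to_v = torch.nn.Linear(dim, dim, bias=False)


def full_causal_attention(x):
    """Recompute attention for every position; x is (batch, n, dim)."""
    q, k, v = to_q(x), to_k(x), to_v(x)
    dots = q @ k.transpose(-1, -2) * dim ** -0.5
    n = x.shape[1]
    future = torch.ones(n, n).triu(1).bool()  # True above the diagonal
    dots = dots.masked_fill(future, float("-inf"))
    return dots.softmax(dim=-1) @ v


def cached_step(x_new, cache):
    """Attend from the newest token only; x_new is (batch, 1, dim)."""
    k_new, v_new = to_k(x_new), to_v(x_new)
    if cache is None:
        k, v = k_new, v_new
    else:
        # Append the new key/value to everything computed so far.
        k = torch.cat((cache[0], k_new), dim=1)
        v = torch.cat((cache[1], v_new), dim=1)
    q = to_q(x_new)
    # No causal mask needed: the cache only ever holds past positions.
    attn = (q @ k.transpose(-1, -2) * dim ** -0.5).softmax(dim=-1)
    return attn @ v, (k, v)


with torch.no_grad():
    x = torch.randn(1, 6, dim)
    reference = full_causal_attention(x)
    cache, steps = None, []
    for t in range(x.shape[1]):
        out, cache = cached_step(x[:, t:t + 1], cache)
        steps.append(out)
    incremental = torch.cat(steps, dim=1)
```

Each step costs one query against `t` keys instead of `t` queries against `t` keys, which is where the saving comes from; the feedforward and projection work per new token is unchanged.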

No return_embeddings in ViTransformerWrapper

For the image -> caption example in the README:

import torch
from x_transformers import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder

encoder = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True
    )
)

img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))

encoded = encoder(img, return_embeddings = True)  # this keyword is rejected
decoder(caption, context = encoded) # (1, 1024, 20000)

ViTransformerWrapper has no "return_embeddings" argument.

with `mask` output is `nan`

I am trying to run the image captioning example with a mask for the caption. Without the mask the forward pass runs properly, but with the mask I get NaN. You can reproduce this by running the following simple example.

import torch
from x_transformers import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder

encoder = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True
    )
)

img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))
mask = torch.ones_like(caption, dtype=torch.bool)
encoded = encoder(img)
decoder(caption, context = encoded, mask = mask) 

output:

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], grad_fn=<UnsafeViewBackward>)
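
The report does not identify the cause. As a general first check when masked attention returns NaN, it helps to know one common mechanism: if a row of attention scores ends up fully masked with `-inf`, softmax computes 0/0 for that row. The sketch below shows this in isolation and is an assumption about where to look, not a diagnosis of this issue; the tensors are toy values.

```python
import torch

scores = torch.zeros(2, 4)
keep = torch.tensor([[True, True, False, False],
                     [False, False, False, False]])  # second row keeps nothing

# Filling with -inf: a row with no kept entries becomes 0/0 inside softmax.
inf_filled = scores.masked_fill(~keep, float("-inf")).softmax(dim=-1)

# Filling with the most negative finite value keeps every row finite.
finite_fill = -torch.finfo(scores.dtype).max
finite_filled = scores.masked_fill(~keep, finite_fill).softmax(dim=-1)

print(torch.isnan(inf_filled).any(dim=-1))
print(torch.isnan(finite_filled).any(dim=-1))
```

Checking `torch.isnan` on the attention scores right after the mask is applied narrows down whether the mask combination or something earlier in the forward pass is responsible.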

Hopfield Nets for memory purpose in x-transformers?

Hi, this x-transformers repo has a lot of very useful features all in one place. I was wondering whether modern Hopfield networks could improve performance. An implementation is given here: https://github.com/ml-jku/hopfield-layers
However, I couldn't understand how to use it for memory purposes.
What are your views on it? Are modern Hopfield networks useful as associative memory nets, and if so, how should they be implemented? Just adding them as a lookup layer didn't give any noticeable performance improvement.

Pay Attention When Required

First, thanks for the great repo!

Here's a recent paper from NVIDIA: https://arxiv.org/pdf/2009.04534v2.pdf
Seems like a similar concept to Sandwich, but faster, simpler, and near identical perplexity.

Edit: Oh, I see you mention it already. Is there a parameter exposed for it already?

This was further corroborated by a paper by Nvidia that reduces the number of attention layers to be 1/3rd of the feedforwards without loss in performance.

Error with last update

  File "/usr/local/lib/python3.6/dist-packages/x_transformers/x_transformers.py", line 820, in __init__
    super().__init__(causal = False, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/x_transformers/x_transformers.py", line 732, in __init__
    residual_fn = Residual(dim)
TypeError: __init__() takes 1 positional argument but 2 were given

Different Inference Results?

I have trained a transformer model and noticed something strange.
For the same input, the final decoder output shape varies, and a few of the final decoded tokens differ between inference runs. Example:
First inference:
Output from Decoder torch.Size([1, 178]) |
tensor([[ 87, 267, 11, 417, 319, 333, 290, 286, 383, 280, 418, 353, 336, 404,
286, 290, 292, 542, 365, 364, 493, 445, 52, 53, 560, 505, 40, 41,
354, 400, 291, 319, 408, 269, 51, 32, 268, 11, 291, 319, 408, 269,
277, 505, 32, 23, 24, 268, 11, 271, 98, 292, 418, 286, 290, 301,
560, 291, 319, 408, 269, 41, 32, 889, 268, 11, 271, 98, 292, 418,
286, 290, 301, 334, 653, 291, 319, 408, 269, 52, 268, 11, 291, 280,
354, 280, 313, 269, 277, 268, 11, 291, 286, 325, 280, 334, 319, 418,
280, 610, 55, 522, 450, 74, 85, 88, 374, 820, 578, 780, 269, 51,
52, 53, 268, 11, 326, 505, 40, 41, 269, 277, 268, 11, 935, 748,
269, 51, 32, 268, 11, 748, 269, 39, 32, 23, 24, 268, 11, 328,
576, 733, 326, 748, 269, 41, 32, 277, 906, 268, 11, 328, 576, 733,
17, 740, 595, 421, 283, 569, 287, 748, 269, 52, 34, 277, 906, 268,
11, 417, 292, 325, 301, 505, 576, 733, 269, 2]], device='cuda:0')

Second inference for the same input:
Output from Decoder torch.Size([1, 183]) |
tensor([[ 87, 267, 11, 417, 319, 333, 290, 286, 383, 280, 418, 353, 336, 404,
286, 290, 292, 542, 365, 364, 493, 269, 51, 52, 53, 268, 11, 560,
505, 40, 41, 354, 400, 291, 319, 408, 269, 51, 32, 268, 11, 291,
319, 408, 269, 277, 505, 32, 23, 24, 268, 11, 271, 98, 292, 418,
286, 290, 301, 560, 291, 319, 408, 269, 41, 32, 889, 268, 11, 271,
98, 292, 418, 301, 334, 653, 370, 291, 319, 408, 269, 52, 268, 11,
291, 280, 354, 280, 313, 269, 277, 268, 11, 291, 286, 325, 280, 334,
319, 418, 280, 610, 55, 522, 450, 74, 85, 88, 374, 820, 578, 780,
269, 51, 52, 53, 268, 11, 326, 505, 40, 41, 269, 277, 268, 11,
935, 748, 269, 51, 32, 268, 11, 748, 269, 39, 32, 23, 24, 268,
11, 328, 576, 733, 326, 748, 269, 41, 32, 277, 906, 268, 11, 328,
576, 733, 17, 740, 595, 421, 283, 569, 287, 748, 269, 52, 268, 11,
610, 269, 277, 889, 268, 11, 417, 292, 325, 301, 505, 576, 733, 269,
2]], device='cuda:0')

Is this behaviour normal for the decoder?

Transformer-XL recurrence different from how it is presented in the paper

The current Transformer-XL implementation uses an attention length equal to the input segment length plus the memory length, while in the paper the attention length is presented as independent of both the input length and the memory length. This behavior is unwanted, since you can't benefit from the extended receptive field presented in figure 2. https://arxiv.org/pdf/1901.02860.pdf
A solution could be to add a further parameter to the model from which the attention mask is generated automatically. A snippet of how it could be implemented:

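if self.causal:
    i, j = dots.shape[-2:]
    r = torch.arange(i, device = device)
    # distance[q, k] = k - q: positive for future keys, negative for past keys
    distance = rearrange(r, 'j -> () () () j') - rearrange(r, 'i -> () () i ()')
    mask = distance > 0  # block future keys
    if self.att_len:
        # also block keys that lie att_len or more positions in the past
        mask_2 = distance <= -self.att_len
        mask = torch.logical_or(mask, mask_2)
        del mask_2
    # memory positions are prepended on the left and left unmasked
    mask = F.pad(mask, (j - i, 0), value = False)
    dots.masked_fill_(mask, mask_value)
    del mask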
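
A self-contained sketch of a windowed causal mask, under two stated assumptions: `True` means "blocked" (as with `masked_fill_` above), and the window should also limit how far back into the prepended memory a query can see. The function name `banded_causal_mask` and the parameter `att_len` are hypothetical, not part of the library.

```python
import torch


def banded_causal_mask(i, j, att_len=None):
    """Boolean (i, j) mask, True = blocked.

    i: number of query positions (the current segment)
    j: number of key positions, j >= i; the first j - i keys are memory
    att_len: if given, each query sees only its last `att_len` keys
    """
    q_pos = torch.arange(j - i, j).unsqueeze(-1)  # absolute query positions, (i, 1)
    k_pos = torch.arange(j).unsqueeze(0)          # absolute key positions, (1, j)
    distance = q_pos - k_pos                      # > 0 past, 0 self, < 0 future
    mask = distance < 0                           # block the future
    if att_len is not None:
        mask = mask | (distance >= att_len)       # block keys too far back
    return mask


# 3 queries, 5 keys (2 memory slots), window of 2 keys per query
m = banded_causal_mask(3, 5, att_len=2)
visible_per_query = (~m).sum(dim=-1)
```

With a fixed window each layer looks back the same number of positions, so stacking layers extends the effective receptive field, which is the effect the paper's figure illustrates.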

Understanding the key and value transmission from the encoder output to the decoder

Once again, thank you Phil for the amazing work and the time you put into it. I appreciate it! Three questions though:

  1. Do I understand correctly that the keys and values passed from the encoder output to every layer of the decoder are nothing but a single tensor of shape (batch_size, num_tokens, dim_embedding), and that it is only inside the decoder layers that this tensor is split into a key tensor and a value tensor by multiplication with the respective learnable weights? So there really is only one tensor passed to the decoder, as opposed to what many illustrations show, cf.

[image: transformer_decoding_2]

The split is performed anew in every decoder layer, correct?

  2. I am trying to understand the analogy to classical convnets, in which, after several layers of convolution and pooling, one usually takes the embeddings with reduced spatial dimension but increased feature/channel dimension and uses them for downstream tasks, cf. the VGG architecture.

Is this akin to the encoder part of the Vision Transformer?
The output of the encoder is a separate feature vector for every token/image patch, which seems to be different from the resulting embeddings of convnets...

  3. Talking again about ViT, could a decoder be thought of as a kind of upsampling/generative counterpart (cf. GANs) to a downsampling part, i.e. the encoder? I am trying to relate the encoder and decoder and their use cases to the more classical architectures in deep learning.

Thank you in advance!
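
For the first question, a minimal sketch of standard cross-attention may help: the encoder output is one tensor, and each decoder layer applies its own learned key and value projections to it. The `CrossAttention` class below is a simplified single-head illustration written for this thread, assuming the usual formulation; it is not the library's implementation.

```python
import torch
from torch import nn


class CrossAttention(nn.Module):
    """Single-head cross-attention with its own key/value projections."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

    def forward(self, x, context):
        q = self.to_q(x)        # queries come from the decoder stream
        k = self.to_k(context)  # keys and values are both derived
        v = self.to_v(context)  # from the same encoder output
        attn = (q @ k.transpose(-1, -2) * self.scale).softmax(dim=-1)
        return attn @ v


torch.manual_seed(0)
dim = 8
layers = nn.ModuleList([CrossAttention(dim) for _ in range(2)])
context = torch.randn(1, 5, dim)  # one encoder output, reused by every layer
x = torch.randn(1, 3, dim)        # decoder hidden states
for layer in layers:
    x = x + layer(x, context)     # each layer re-projects the same context
```

Because `to_k` and `to_v` are separate parameters in each layer, the same `context` tensor yields different keys and values at every depth.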

My Experience with X-Transformers

I have run some models in the past weeks. All of them being encoder-decoder transformers.
I am not sure where the right place is to write stuff like this, but I'll write it here for now.

Word of caution: my particular use case is not NLP, but it's a corpus with around 200M tokens and a vocab_size of 1k.

Transformers Without Tears
Researchers have shared with me this leads to faster convergence.

This did lead to faster convergence in the beginning, but performance was slightly worse. (Ran 2 Runs)

GLU Variants Improve Transformer

Took longer to converge and wasn't better (Ran 2 Runs)

Rezero Is All You Need

Didn't converge for me and became NaN after a while (Ran 2 Runs)

T5's Simplified Relative Positional Encoding

Converged quicker and was better, even when wrongly configured (I used max_distance 128 instead of 512, which is my max_seq_len).
For a seq_len of 512, a bucket_size of 64 was better than the default 32. (One Run each)

Talking-Heads Attention

Didn't notice anything for my use case (1 Run only)

Reason for doing partial rotary embedding?

In a recent commit only half the head vector is "rotated", could this improve the overall performance? Thanks!
4b395ab

-    self.rotary_pos_emb = RotaryEmbedding(dim_head) if rotary_pos_emb else always(None)
+    rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
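
To make the diff concrete: the added line resolves how many feature dimensions of each head get rotated (half the head dimension by default, but at least 32), and the remaining dimensions pass through untouched. The sketch below evaluates that expression and applies a generic rotary embedding to only the first `rot_dim` features. `resolve_rotary_dim`, `rotate_half`, and `apply_partial_rotary` are illustrative names written for this example, and the rotation is textbook RoPE rather than the library's exact code.

```python
import torch


def default(val, d):
    return val if val is not None else d


def resolve_rotary_dim(dim_head, rotary_emb_dim=None):
    # Same expression as the added line in the diff.
    return max(default(rotary_emb_dim, dim_head // 2), 32)


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_partial_rotary(t, rot_dim, base=10000):
    """Rotate the first `rot_dim` features of t (batch, seq, dim_head)."""
    seq_len = t.shape[-2]
    inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
    angles = torch.cat((angles, angles), dim=-1)         # (seq, rot_dim)
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]   # rotated / untouched
    t_rot = t_rot * angles.cos() + rotate_half(t_rot) * angles.sin()
    return torch.cat((t_rot, t_pass), dim=-1)


torch.manual_seed(0)
dim_head = 64
rot_dim = resolve_rotary_dim(dim_head)   # half of a 64-dim head
q = torch.randn(1, 4, dim_head)
q_rot = apply_partial_rotary(q, rot_dim)
```

With this expression a head of dimension 64 rotates 32 features, a head of 128 rotates 64, and a head of 32 still rotates all 32 because of the `max(..., 32)` floor.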

Since version 12.4 the first example in the readme gives an error: KeyError: 'emb_dropout'

import torch
from x_transformers import XTransformer

model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    tie_token_emb = True      # tie embeddings of encoder and decoder
)

src = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 256, (1, 1024))
tgt_mask = torch.ones_like(tgt).bool()

loss = model(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask) # (1, 1024, 512)
loss.backward()

Simple feature request: transformers for continuous inputs

I think it would be useful to add an option to the TransformerWrapper, or perhaps to make a new Wrapper type, that does not use embedding layers, so that the inputs are real-valued vectors. This would allow using x-transformers for tasks with continuous inputs. For example here they use transformers in that way, and this is also how transformers are usually applied to regression tasks.

Note that in this case one just talks about input and output dimensions, not the number of tokens.

Perhaps the cleanest way to do this is to make a new Wrapper type that works with continuous vector inputs, and then make TransformerWrapper use this ContinuousTransformerWrapper inside it, with the input dimension being the embedding dimension and the output dimension being num_tokens. Hope this makes sense!
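
A minimal sketch of the requested idea in plain PyTorch: a linear projection replaces the token embedding on the way in, and another linear projection replaces the logits head on the way out. The class name `ContinuousInputTransformer` and its arguments are hypothetical, and it uses `torch.nn.TransformerEncoder` as a stand-in for the attention layers so that the example is self-contained.

```python
import torch
from torch import nn


class ContinuousInputTransformer(nn.Module):
    """Hypothetical wrapper: linear projections replace token embeddings."""

    def __init__(self, dim_in, dim_out, dim=64, depth=2, heads=4, max_seq_len=128):
        super().__init__()
        self.project_in = nn.Linear(dim_in, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.layers = nn.TransformerEncoder(layer, num_layers=depth)
        self.project_out = nn.Linear(dim, dim_out)

    def forward(self, x):
        # x: (batch, seq_len, dim_in) real-valued vectors
        positions = torch.arange(x.shape[1], device=x.device)
        h = self.project_in(x) + self.pos_emb(positions)
        return self.project_out(self.layers(h))


model = ContinuousInputTransformer(dim_in=10, dim_out=3).eval()
with torch.no_grad():
    y = model(torch.randn(2, 16, 10))  # (2, 16, 3)
```

The only interface change compared with a token model is that the caller specifies `dim_in` and `dim_out` instead of `num_tokens`.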

Importing AttentionLayers

I want to have two encoders (not seq2seq), and it seems I can't use the default abstractions.
It would be nice to be able to import the AttentionLayers class from the library.

Scaling in normalization modules

Hi, thank you very much for publishing this awesome repository :) After studying the recent changes, I was confused with the introduction of self.scale = dim ** -0.5 into ScaleNorm and RMSNorm in this commit.

If I understand the code correctly, the modules multiply the normalized variable by dim ** 0.5 (these two lines divide the variable by dim ** -0.5). Since both the queries and the keys are multiplied like this, the attention matrix is effectively multiplied by dim, which goes against the usual practice to multiply it by dim ** -0.5 (as you also do in the code).

I believe the normalized variables shouldn't be multiplied like this, what is the reason behind scaling it? Thank you very much for your time.
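
The arithmetic in the question can be checked numerically. Assuming, as described above, that the module computes `x / (||x|| * dim ** -0.5)`, each output is the unit vector scaled by `dim ** 0.5`, so the dot product of two such outputs is `dim` times their cosine similarity. The helper `scaled_norm` is a stand-in written for this check, not the library's module (it omits the learned gain).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim = 64
q = torch.randn(dim)
k = torch.randn(dim)


def scaled_norm(x):
    # x / (||x|| * dim ** -0.5), i.e. the unit vector times dim ** 0.5
    return x / (x.norm() * dim ** -0.5)


unit_dot = torch.dot(F.normalize(q, dim=0), F.normalize(k, dim=0))
scaled_dot = torch.dot(scaled_norm(q), scaled_norm(k))

# Per-feature root-mean-square of the scaled output
rms = scaled_norm(q).pow(2).mean().sqrt()
```

After the usual `dim ** -0.5` factor on the attention logits, the net scale relative to cosine similarity is `dim ** 0.5`. The same scaling also means each normalized vector has unit root-mean-square per feature rather than unit length.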

Technical question

@lucidrains Hey bro!

I think I finally got it to work properly and it shows good results with music. I am really enjoying your creation. Very good job. Thanks.

@lucidrains

I do not know where to post questions on your GitHub so I am posting here.

I wanted to ask you if you know what "multiple-embedding" might be. Have you ever heard of such a thing?

From what I understand, it means that the transformer can accept several tokens at once as input. Ideally, each layer can accept tokens, so if you have 6 layers, you would therefore feed it 6 tokens of input.

It is sorta like distributed training as far as I understand but for one worker.

Does any of it make sense? Can you consider looking into it? I think it would be a great addition to your x-transformer. It would make it fast as hell and probably more capable as the model will be able to make connections between input tokens. I.e. in music, this would allow multiple instruments for example.

And a side question....what is the hype around the reformer? Yes, it is tiny and trains well, but it is not even close to GPT3 afaik, so I am really confused...Is it cuz it's Google or something?

Thanks.

Your time and responses will be very much appreciated.

Alex

how to update encoder-decoder model training parameters

Hi, I checked this example, in which the encoder and decoder are defined inside one model, and from the code I understood how the model's training forward/backward pass works.

How would the forward/backward pass work for, let's say, a separate encoder and a separate decoder (as in the image -> caption example you mentioned)?

Any starting point/example would be appreciated.

sequence length independent generation

Currently, generation requires passing a sequence length in order to generate sequences of that length, but in tasks such as summarization or translation one doesn't know the final sequence length in advance. As a workaround I am generating candidates with various lengths. Also, is it possible to add support for beam search, in addition to the current top_p/top_k methods?
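
One common way to make generation independent of a preset length is to stop as soon as an end-of-sequence token is produced, with the length acting only as an upper bound. The sketch below shows this with greedy decoding. `generate_until_eos`, `next_logits`, and the toy model are all hypothetical and stand in for a trained network; this is not the library's `generate` method.

```python
import torch


def generate_until_eos(next_logits, start_tokens, eos_token, max_len):
    """Greedy decoding that stops when `eos_token` is produced.

    `next_logits` is a hypothetical callable mapping a 1-D tensor of token
    ids to a 1-D tensor of logits for the next token.
    """
    tokens = start_tokens.clone()
    for _ in range(max_len):
        next_token = next_logits(tokens).argmax(dim=-1, keepdim=True)
        tokens = torch.cat((tokens, next_token))
        if next_token.item() == eos_token:
            break  # stop early instead of filling up to max_len
    return tokens


def toy_next_logits(tokens):
    # Toy stand-in for a trained model: always predicts (last token + 1) % 10.
    logits = torch.zeros(10)
    logits[(tokens[-1].item() + 1) % 10] = 1.0
    return logits


out = generate_until_eos(toy_next_logits, torch.tensor([3]), eos_token=7, max_len=50)
capped = generate_until_eos(toy_next_logits, torch.tensor([0]), eos_token=99, max_len=4)
```

For batched generation the same idea needs a per-sequence "finished" flag so that completed sequences are padded while the others continue.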

Memory Efficiency w.r.t Sequence Length

I am a bit of a noob when it comes to transformers. If I want to encode a batch of N sequences of maximum length L, my understanding is that I do something like this:

from x_transformers import Encoder, TransformerWrapper
seqs = ['aba','cb','abcab']
N = len(seqs)
L = max(len(seq) for seq in seqs)
C = 3
padded_seqs = get_padded_seqs(seqs) # N x L long tensor
mask = get_seq_mask(seqs) # N x L boolean tensor
encoder = TransformerWrapper(num_tokens=C,max_seq_len=L,attn_layers=Encoder())
embeddings = encoder(padded_seqs,mask=mask,return_embeddings=True)

In this transformer implementation, would there be a difference in memory usage if all of the sequences were of length L (i.e. all the mask values were True)?
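
The two helpers in the snippet are left undefined, so here is one possible sketch of them, together with the size of the attention score tensor. The vocabulary mapping and helper bodies are assumptions made for illustration. In a dense attention implementation the scores are stored for every query/key pair and the mask only overwrites values, so under that assumption the tensor shapes do not depend on what the mask contains.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

VOCAB = {"a": 0, "b": 1, "c": 2}  # assumed character-to-id mapping


def get_padded_seqs(seqs, pad_value=0):
    """Encode strings as token ids and right-pad to the longest one (N x L)."""
    encoded = [torch.tensor([VOCAB[ch] for ch in s]) for s in seqs]
    return pad_sequence(encoded, batch_first=True, padding_value=pad_value)


def get_seq_mask(seqs):
    """Boolean N x L tensor, True where a real token is present."""
    lengths = torch.tensor([len(s) for s in seqs])
    positions = torch.arange(int(lengths.max()))
    return positions.unsqueeze(0) < lengths.unsqueeze(1)


seqs = ["aba", "cb", "abcab"]
padded_seqs = get_padded_seqs(seqs)
mask = get_seq_mask(seqs)

# Dense attention scores have shape (N, heads, L, L) whatever the mask holds.
heads = 8
n, l = padded_seqs.shape
score_entries = n * heads * l * l
```

The padding id coincides with the id of `a` here, which is harmless only because the mask marks those positions as padding. Grouping sequences of similar length into the same batch is the usual way to reduce the memory spent on padded positions.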

onnx model?

ONNX conversion error:
the triu operator is not supported, in this line (for attention):
if self.causal:
    mask = torch.ones((i, j)).triu_(j - i + 1).bool()

Any help?
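
One possible workaround is to build the same causal mask without `triu`, by comparing row and column indices. The sketch below checks that the result matches the `triu`-based mask. The function name `causal_mask` is an assumption, and whether the resulting graph exports cleanly still depends on the PyTorch and ONNX opset versions in use.

```python
import torch


def causal_mask(i, j, device=None):
    """True where query row r may NOT see key column c, i.e. c > r + (j - i)."""
    rows = torch.arange(i, device=device).unsqueeze(-1)  # (i, 1)
    cols = torch.arange(j, device=device).unsqueeze(0)   # (1, j)
    return cols > rows + (j - i)


# Reference built with triu, for comparison only.
reference = torch.ones(4, 6).triu(6 - 4 + 1).bool()
same = torch.equal(causal_mask(4, 6), reference)
```

The comparison only uses `arange`, addition, and `>`, which are basic elementwise operations.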

Learned memory keyword change

I believe an earlier version may have used attn_num_mem_kv as the keyword argument for setting how many persistent memory vectors should be used. It looks like this is now num_mem_kv. The example in the readme has the older form.

AutoCast for mixed precision/fp16 fails?

I have tried to train the model using torch.cuda.amp.autocast(), but training doesn't seem to speed up, and memory usage remains the same as with fp32 training.
The model size also remains the same with or without autocast.
Can you help me understand what the reason could be?
I also tried huggingface Accelerate (https://github.com/huggingface/accelerate) but can't achieve mixed precision.
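
For reference, a minimal sketch of a mixed-precision training step with `torch.autocast` and a gradient scaler. One point worth knowing is that autocast changes the precision of selected operations, not of the stored parameters, so the weights (and therefore the saved model size) stay in float32. The tiny `nn.Linear` model is a placeholder, and the sketch falls back to CPU with bfloat16 when no GPU is present.

```python
import torch
from torch import nn

device_type = "cuda" if torch.cuda.is_available() else "cpu"
# float16 on GPU; bfloat16 is the usual choice for CPU autocast.
amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

model = nn.Linear(32, 32).to(device_type)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
# Loss scaling guards float16 gradients against underflow; disabled on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device_type == "cuda"))

x = torch.randn(8, 32, device=device_type)
with torch.autocast(device_type=device_type, dtype=amp_dtype):
    out = model(x)                    # the matmul runs in the lower precision
    loss = out.float().pow(2).mean()  # reduce in float32

optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

param_dtype = next(model.parameters()).dtype  # parameters are still float32
```

Any speedup depends on the hardware having fast half-precision units, and the memory saving comes mainly from activations, so small models or small batches may show little difference.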

[Question!] How to Inject Rotary Positional Embeddings in Linear Transformers

Hello Phil,

Would you mind explaining how to inject rotary positional embeddings into linear transformers?

import torch
from torch.nn import Module

from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
    EventDispatcherInstance
from ..events import EventDispatcher
from ..feature_maps import elu_feature_map


class LinearAttention(Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.
    Given the queries, keys and values as Q, K, V instead of computing
        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),
    we make use of a feature map function Φ(.) and perform the following
    computation
        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).
    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on the
    feature map, however, the complexity of the attention might be limited.
    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(LinearAttention, self).__init__()
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return V.contiguous()

Thanks!

Shared Embeddings

Sharing the token_emb between the Encoder and Decoder is not the default. A lot of transformers, like BART/T5, use a shared encoder/decoder embedding.

model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    enc_num_memory_tokens = 0,
    
)

model.decoder.token_emb = model.encoder.token_emb

Would this be enough?

Furthermore, the Encoder/Decoder example in the README doesn't work out of the box; it also needs a value for enc_num_memory_tokens.
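
What the attribute assignment does can be seen on a toy model: after re-pointing `decoder.token_emb` at the encoder's module, both sides hold the same `nn.Embedding`, and the parameter is counted (and updated) once. `TinyWrapper` and `TinySeq2Seq` are hypothetical stand-ins that only mimic the `encoder.token_emb` / `decoder.token_emb` attribute layout used in the snippet above.

```python
import torch
from torch import nn


class TinyWrapper(nn.Module):
    """Hypothetical stand-in for a wrapper that owns a `token_emb` module."""

    def __init__(self, num_tokens, dim):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)


class TinySeq2Seq(nn.Module):
    def __init__(self, num_tokens=256, dim=16):
        super().__init__()
        self.encoder = TinyWrapper(num_tokens, dim)
        self.decoder = TinyWrapper(num_tokens, dim)


model = TinySeq2Seq()
before = sum(p.numel() for p in model.parameters())

# Re-pointing the attribute makes both sides use the same nn.Embedding.
model.decoder.token_emb = model.encoder.token_emb
after = sum(p.numel() for p in model.parameters())

shared = model.decoder.token_emb.weight is model.encoder.token_emb.weight
```

This only makes sense when both sides use the same vocabulary and embedding dimension, and the assignment should happen before the optimizer is created so that the discarded embedding is never registered with it.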

Benchmarking Transformers

(Co-Author of ReZero)
I noticed this repository, and it's very cool that there are many transformer variants being implemented. I wonder if there exists a benchmark for all these experimental features. If not, it might be useful to benchmark all these variants and report their performance in a table (efficiency, pre-training performance on The Pile, for example).

There have been review papers out there that perform some benchmarking of Transformer variants, but they get outdated very quickly. PDFs aren't a great format for this type of thing. It'll be highly useful if there's a working Github that contains these benchmarks and people can collaborate on to "add their method".

I'm interested in discussing ideas/collaborating if others are.

Feature request for adding Memformer memory

I've checked the memformer repository and this one. I think it would be good to add the memory from the memformer to this project as well, since it seems more general and better than the Transformer-XL memory.
