
Positional Embeddings in PyTorch

Nomenclature

Nobody likes it, but the same thing inevitably goes by many slightly different names.

The name consists of two words: the first can be "position" or "positional", and the second "embedding" or "encoding". In this package, it is called positional embedding.

In brief

Positional embedding is critical for a transformer to distinguish between permutations of its input. However, the countless variants of positional embeddings can be dazzling. They can be awkward to understand and implement, sometimes taking up the majority of the space in your PyTorch code.

The aim of this package is to build a collection of popular positional embedding modules and provide unified APIs. Ideally, positional embeddings should be isolated from the transformer architecture and used in a plug-and-play manner, i.e. everything outside the positional embedding modules should be permutation invariant.

Besides, this package is meant to provide plain, easy-to-follow, and naive implementations.

APIs

After comparing several positional embeddings, I summarize their behavior into two APIs.

forward_input

Positional embeddings can integrate positional information directly into the input (e.g. the word embeddings). This API returns the input with positional information added.

PositionalEmbedding.forward_input

"""Generate positional embedding to be added to the input.

Args:        

    input_: torch.Tensor: shape: [batch_size, max_length, embed_dim]
        The input tensor.

    positions: torch.Tensor, shape: [batch_size, max_length, input_]
        Absolute positions of input tokens.


Returns:
input_: torch.Tensor, shape: [batch_size, max_length, input_]
    A tensor with both input and positional information.
"""

forward_attn

Some implementations (especially relative positional embeddings) directly generate an attention matrix from positional embeddings and add it to the qk attention matrix, i.e. as an attention bias. Other implementations modify the queries and keys before the attention matrix is calculated so that they carry positional information.

These facts reflect the tight coupling between positional embeddings and the transformer. As a design choice, I decided to leave the responsibility of calculating the attention matrix to the positional embedding.

PositionalEmbedding.forward_attn

"""Generate attention logits from queries, keys and their positions.

Args:

    q: torch.Tensor, shape: [batch_size, num_heads, q_max_length, head_dim]
        The query tensor.

    k: torch.Tensor, shape: [batch_size, num_heads, k_max_length, head_dim]
        The key tensor.
    
    positions_q: torch.Tensor, shape: [batch_size, q_max_length]
        Absolute positions of query tokens.
    
    positions_k: torch.Tensor, shape: [batch_size, k_max_length]
        Absolute positions of key tokens.


Returns:
attn_logits: torch.Tensor, shape: [batch_size, q_max_length, k_max_length]
    Attention logits (before padding mask, before softmax, and before scaling)
"""

I know that calculating the attention matrix (qk similarity) is generally regarded as a characteristic step of a transformer module, and that many architectural modifications happen there. However, to isolate positional embeddings from the transformer, I have to make this decision. This means the O(n^2) complexity now belongs to the positional embedding instead of the transformer, and architectural modifications, such as sparse attention or additive attention, must now be reflected in the positional embedding modules.
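
To make the design concrete, here is a minimal sketch (not the package's actual code) of an attention step that stays permutation invariant by delegating all position-dependent computation, including the qk similarity, to a positional embedding module exposing the forward_attn API above; the scaling and masking details are assumptions.

```python
import math
import torch
import torch.nn.functional as F

def attention(pos_emb, q, k, v, positions_q, positions_k, padding_mask=None):
    """q, k, v: [batch, num_heads, length, head_dim]; positions_*: [batch, length]."""
    # All positional information (and the qk similarity itself) comes from the
    # positional embedding module, so everything below is permutation invariant.
    attn_logits = pos_emb.forward_attn(q, k, positions_q, positions_k)

    # The documented return shape is [batch, q_len, k_len]; broadcast it over heads
    # and apply the usual 1/sqrt(head_dim) scaling here, outside the module.
    attn_logits = attn_logits.unsqueeze(1) / math.sqrt(q.size(-1))

    if padding_mask is not None:  # [batch, k_len], True at padded key positions
        attn_logits = attn_logits.masked_fill(padding_mask[:, None, None, :], float("-inf"))

    attn = F.softmax(attn_logits, dim=-1)  # [batch, num_heads, q_len, k_len]
    return attn @ v                        # [batch, num_heads, q_len, head_dim]
```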

Basic usage

#To be added#

Supported positional embeddings

  • Sinusoidal positional embedding (SinusoidalPositionalEmbedding) in "Attention Is All You Need".
  • Learnable positional embedding (LearnedPositionalEmbedding) in BERT and GPT.
  • Relative positional embedding (TransformerXLPositionalEmbedding) in Transformer-XL.
  • Relative positional embedding (T5PositionalEmbedding) in T5.
  • Unified positional embedding (UnifiedPositionalEmbedding) in TUPE.
  • Relative positional embedding (EnformerPositionalEmbedding) in Enformer.

Installation

pip install positional-embeddings-pytorch

However, this package is highly experimental. It can serve as a reference for implementing and choosing positional embeddings, but I strongly discourage dropping it directly into your code. Instead, users are expected to have prior knowledge of positional embeddings and to check the code before using it.

Future work

  • Positional embedding for decoder.
  • Positional embedding with memory.
  • Add support for [CLS] tokens.

(Current implementation only considers transformer encoder without memory, and does not support special tokens such as [CLS].)

References

pytorch/fairseq: Facebook AI Research Sequence-to-Sequence Toolkit written in Python. (github.com)

huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX. (github.com)

kimiyoung/transformer-xl (github.com)

guolinke/TUPE: Transformer with Untied Positional Encoding (TUPE). Code of paper "Rethinking Positional Encoding in Language Pre-training". Improve existing models like BERT. (github.com)

T5 relative positional embedding (github.com)

Author(s)

Yiming Qu.

Undergraduate at Tsinghua University. Biology and Data Science.

Research Intern at Microsoft Research. Computational Biology.

