
xformer

Simple and Flexible API for training transformer models on sequence to sequence tasks.

Installation

pip install git+https://github.com/apoorvnandan/xformer.git

Purpose:

Provide a simple interface, built on top of pytorch and tensorflow, for modelling any sequence to sequence problem with transformer models. The models have built-in functions for training and for generating outputs (only with greedy decoding at the moment), and callbacks allow use case specific customizations.

Check our minimal examples below!

Current scope:

Input:

  • Case 1: Input is a sequence of token indices (shape: (batch_size,input_len), type: long)
  • Case 2: Input is a sequence of feature vectors (shape: (batch_size,input_len,input_dim), type: float)

Output:

  • Output is a sequence of token indices (shape: (batch_size, output_len), type: long)
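
For reference, here is a minimal sketch (in pytorch, with made-up sizes and vocabulary) of batches with these shapes and dtypes:

import torch

batch_size, input_len, output_len, input_dim = 8, 20, 15, 80

# Case 1: input is a sequence of token indices
enc_inp_tokens = torch.randint(0, 100, (batch_size, input_len), dtype=torch.long)

# Case 2: input is a sequence of feature vectors (e.g. audio frames)
enc_inp_feats = torch.rand(batch_size, input_len, input_dim)

# Output: a sequence of token indices
dec_out = torch.randint(0, 100, (batch_size, output_len), dtype=torch.long)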

This covers a lot of popular applications of sequence to sequence models:

  • Machine Translation
  • Abstractive Summarisation
  • Automatic Speech Recognition
  • Video captioning
  • Chatbots

We provide built-in callbacks that help with the above use cases, and you can also write a custom callback to execute your code at any point in the training loop.

Minimal Examples

Machine Translation: (with pytorch API)

from my_data import (
    src_tokenizer, 
    trg_tokenizer, 
    train_loader, 
    test_loader
)

# data loader can be any iterable that returns the following:
print(next(iter(train_loader)))  # {"enc_inp": <input batch (N, max_src_len)>, "dec_out": <target batch (N, max_trg_len)>}

from xformer import Transformer
from xformer.callbacks import BleuScore

model = Transformer.from_config(
    num_classes=trg_tokenizer.size(),
    embed_input=True,
    input_vocab_size=src_tokenizer.size(),
    src_pad_idx=src_tokenizer.pad_token_idx,
    trg_pad_idx=trg_tokenizer.pad_token_idx
)
bleu_cb = BleuScore(
    trg_tokenizer, test_loader, trg_tokenizer.bos_token_idx, trg_tokenizer.eos_token_idx
)
model.fit(train_loader, n_epochs=2, callbacks=[bleu_cb])
Epoch 1/2
1070/1070 [==============================] - 279s 261ms/step - loss: 0.7809
test bleu score: 84.12
Epoch 2/2
1070/1070 [==============================] - 266s 248ms/step - loss: 0.1499
test bleu score: 88.49
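
After training, translations can be generated with greedy decoding. A minimal sketch, assuming model.generate takes a single encoder input along with the target BOS/EOS indices and returns the predicted token indices first (as in the custom callback example further below):

batch = next(iter(test_loader))
preds, _ = model.generate(
    batch["enc_inp"][0, :],
    trg_tokenizer.bos_token_idx,
    trg_tokenizer.eos_token_idx,
)
print(trg_tokenizer.decode(preds))  # decoded translation for the first test sentence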

Speech to text: (with tensorflow API)

from my_data import ds, val_ds
from xformer.tf import Transformer
from xformer.tf.callbacks import DisplayOutputs

# ds, val_ds are tf.data.Dataset objects which generate a batch like this
for i in ds.take(1):
    print(i['src'].shape)  # (4, 59, 80)
    print(i['trg'].shape)  # (4, 13)

model = Transformer(
    input_type="feats",
    nvocab=1000,
    ninp=80,
    nhid=64,
    nhead=2,
    nff=128,
    src_maxlen=59,
    trg_maxlen=12,
    nlayers=2,
    nclasses=30,
)
for i in val_ds.take(1):
    batch = i  # Use the first batch of validation set to display outputs

# vocabulary to convert predicted indices to characters
idx_to_char = ["-"] + [chr(i + 96) for i in range(1, 27)] + ["<", ">", " "]
cb = DisplayOutputs(batch, idx_to_char)
model.compile(optimizer="adam")
model.fit(ds, callbacks=[cb], epochs=15)
Epoch 1/15
475/475 [==============================] - 15s 22ms/step - loss: 0.5407
target:     <one six>----
prediction: <four four>--

target:     <six seven>--
prediction: <three three>

target:     <seven three>
prediction: <three three>

target:     <three nine>-
prediction: <five three>-
...
...
Epoch 15/15
475/475 [==============================] - 11s 23ms/step - loss: 0.0409
target:     <one six>----
prediction: <one six>----

target:     <six seven>--
prediction: <six seven>--

target:     <seven three>
prediction: <seven three>

target:     <three nine>-
prediction: <three nine>-
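
ds and val_ds above can be any tf.data.Dataset that yields dicts with 'src' and 'trg' keys in the shapes shown earlier. A minimal sketch with placeholder random data (array sizes chosen to match the example above):

import numpy as np
import tensorflow as tf

# 100 placeholder utterances: 59 frames of 80 audio features, targets of length 13
src = np.random.rand(100, 59, 80).astype("float32")
trg = np.random.randint(0, 30, size=(100, 13)).astype("int32")

ds = tf.data.Dataset.from_tensor_slices({"src": src, "trg": trg}).batch(4)
val_ds = ds.take(5)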

Design

Heavily inspired by huggingface and keras, and built on top of pytorch and tensorflow 2.x. The idea is that you only write the necessary details about your model and data, and train a transformer model without any extra code. Callbacks provide an easy way to keep the training code flexible.

All the models contain the entire architecture as an nn.Module or a keras.Model, so you can easily use them in other setups (e.g. using the encoder part to extract input representations and using them for a classification task).
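
For instance, in the pytorch API the underlying nn.Module is exposed as model.model (as used in the custom callback example below), so its parts can be plugged into other pipelines. A rough sketch, reusing model and src_tokenizer from the translation example above and assuming the module exposes its encoder as an encoder attribute (that attribute name is an assumption, check the actual module definition):

import torch

net = model.model  # the underlying nn.Module
net.eval()
with torch.no_grad():
    # hypothetical: encode a batch of source token indices into representations
    src_batch = torch.randint(0, src_tokenizer.size(), (4, 20), dtype=torch.long)
    representations = net.encoder(src_batch)  # assumed attribute name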

The progress bars have been coded using tf.keras.utils.Progbar.

Custom Callback

Callbacks allow you to execute your code at the following points in the training loop.

  • Before the training loop starts - by overriding on_train_start(self, model)
  • At the start of every epoch - by overriding on_epoch_start(self, model)
  • At the end of every epoch - by overriding on_epoch_end(self, model)
  • When the training ends - by overriding on_fit_end(self, model)

For tensorflow, any custom Keras callback can be passed into the fit function.
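
For example, a standard Keras callback like the following can be passed to model.fit in the tensorflow API unchanged:

import tensorflow as tf

class PrintLoss(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # print the training loss reported by Keras at the end of each epoch
        print(f"epoch {epoch + 1} loss: {logs['loss']:.4f}")

# e.g. model.fit(ds, callbacks=[PrintLoss()], epochs=10)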

For pytorch, simply extend the base xformer.callbacks.Callback class and override the appropriate methods. Each of these methods takes a single argument, model, which is the transformer model object being trained, so anything you can do with the original model object you can also do inside the callback.

Example: a callback that prints the model's outputs for a few inputs at the end of every epoch.

# Step 1: Extend the base Callback class
from xformer.callbacks import Callback

class ExactMatchAccuracy(Callback):

    # Step 2: Initialise with the objects needed to execute your custom code
    def __init__(self, input_batch, trg_tokenizer):
        self.batch = input_batch
        self.tokenizer = trg_tokenizer

    # Step 3: Override the appropriate method to execute your code in the training loop.
    # The argument 'model' is the transformer model object being trained.
    def on_epoch_end(self, model):
        model.model.eval()  # model.model contains the `nn.Module`
        enc_inp = self.batch['enc_inp']
        dec_out = self.batch['dec_out']
        bs = enc_inp.shape[0]
        trg_bos_idx = self.tokenizer.bos_token_idx
        trg_eos_idx = self.tokenizer.eos_token_idx
        for i in range(bs):
            preds, _ = model.generate(enc_inp[i, :], trg_bos_idx, trg_eos_idx)
            pred_text = self.tokenizer.decode(preds)
            target_text = self.tokenizer.decode(dec_out[i, :].numpy())
            pad_token = self.tokenizer.idx_to_token[self.tokenizer.pad_token_idx]
            target_text = target_text.replace(pad_token, '').strip()  # remove pads
            print('output:', pred_text)
            print('target:', target_text)
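
The callback can then be passed to the fit call just like the built-in ones, e.g. using a batch from the test loader in the translation example above:

cb = ExactMatchAccuracy(next(iter(test_loader)), trg_tokenizer)
model.fit(train_loader, n_epochs=2, callbacks=[cb])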

Dev Notes

  • Write minimum usable version with core functionality
  • Test machine translation with small Eng-French dataset.
  • Write API Reference
  • Test speech to text
  • Match notations between tf and pytorch modules
  • Improve code quality and docstrings
  • Put package on pip
  • Multi GPU support
  • Support for different schedulers and optimizers
  • Upload some pretrained models

Acknowledgements
