
Residual Shuffle-Exchange Networks: Official TensorFlow Implementation

This repository contains the official TensorFlow implementation of the following paper:

Residual Shuffle-Exchange Networks for Fast Processing of Long Sequences

by Andis Draguns, Emīls Ozoliņš, Agris Šostaks, Matīss Apinis, Kārlis Freivalds

[arXiv]

Abstract: Attention is a commonly used mechanism in sequence processing, but it is of O(n²) complexity which prevents its application to long sequences. The recently introduced neural Shuffle-Exchange network offers a computation-efficient alternative, enabling the modelling of long-range dependencies in O(n log n) time. The model, however, is quite complex, involving a sophisticated gating mechanism derived from the Gated Recurrent Unit.

In this paper, we present a simple and lightweight variant of the Shuffle-Exchange network, which is based on a residual network employing GELU and Layer Normalization. The proposed architecture not only scales to longer sequences but also converges faster and provides better accuracy. It surpasses the Shuffle-Exchange network on the LAMBADA language modelling task and achieves state-of-the-art performance on the MusicNet dataset for music transcription while being efficient in the number of parameters.

We show how to combine the improved Shuffle-Exchange network with convolutional layers, establishing it as a useful building block in long sequence processing applications.

Introduction

Residual Shuffle-Exchange networks are a simpler and faster replacement for the recently proposed Neural Shuffle-Exchange network architecture. They have O(n log n) complexity and enable processing of sequences up to 2 million symbols in length, where standard methods such as attention mechanisms fail. The Residual Shuffle-Exchange network can serve as a useful building block for long-sequence processing applications.
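To get a feel for the gap between O(n²) attention and an O(n log n) architecture at the sequence lengths mentioned above, here is a minimal back-of-the-envelope sketch (illustrative operation counts only, not measurements from this repository):

```python
import math

def speedup(n):
    """Ratio of O(n^2) attention cost to O(n log n) shuffle-exchange cost,
    ignoring constant factors."""
    return (n * n) / (n * math.log2(n))

for n in [1_000, 1_000_000, 2_000_000]:
    print(f"n={n:>9,}: attention ops / shuffle-exchange ops = {speedup(n):,.0f}x")
```

At 2 million symbols the asymptotic gap alone is roughly five orders of magnitude, which is why quadratic attention becomes infeasible at this scale.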

Demo

Click the gif to see the full video on YouTube:

Preview of results

Our paper describes Residual Shuffle-Exchange networks in detail and provides full results on long binary addition, long binary multiplication, sorting tasks, the LAMBADA question answering task and multi-instrument musical note recognition using the MusicNet dataset.

Here are the accuracy results on the MusicNet transcription task of identifying the musical notes performed from audio waveforms (freely-licensed classical music recordings):

| Model | Learnable parameters (M) | Average precision score (%) |
| --- | --- | --- |
| cgRNN | 2.36 | 53.0 |
| Deep Real Network | 10.0 | 69.8 |
| Deep Complex Network | 8.8 | 72.9 |
| Complex Transformer | 11.61 | 74.22 |
| Translation-invariant net | unknown | 77.3 |
| Residual Shuffle-Exchange network | 3.06 | 78.02 |

Note: Our model achieves state-of-the-art performance while being parameter-efficient and operating directly on the audio waveform, whereas the previous state-of-the-art models used specialised architectures with complex-valued representations of the Fourier-transformed waveform.

Here are the accuracy results on the LAMBADA question answering task of predicting a target word in its broader context (on average 4.6 sentences picked from novels):

| Model | Learnable parameters (M) | Test accuracy (%) |
| --- | --- | --- |
| Random word from passage | - | 1.6 |
| Gated-Attention Reader | unknown | 49.0 |
| Neural Shuffle-Exchange network | 33 | 52.28 |
| Residual Shuffle-Exchange network | 11 | 54.34 |
| Universal Transformer | 152 | 56.0 |
| Human performance | - | 86.0 |
| GPT-3 | 175000 | 86.4 |

Note: Our model runs faster and, using the same amount of GPU memory, can be evaluated on sequences 4 times longer than the Neural Shuffle-Exchange network model and 128 times longer than the Universal Transformer model.

What are Residual Shuffle-Exchange networks?

Residual Shuffle-Exchange networks are a lightweight variant of Shuffle-Exchange networks: continuous, differentiable neural networks with a regular-layered structure consisting of alternating Switch and Shuffle layers.

The Switch Layer divides the input into adjacent pairs of values and applies a Residual Switch Unit (a learnable 2-to-2 function employing GELU and Layer Normalization) to each pair, producing two outputs.
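The computation of a Residual Switch Unit can be sketched schematically as follows. This is an illustrative NumPy sketch only; the parameter names (`W`, `b`, `scale`) and the exact arrangement of operations are assumptions for exposition, not code from this repository, which should be consulted for the precise formulation:

```python
import numpy as np

def gelu(x):
    # tanh approximation of the Gaussian Error Linear Unit
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def layer_norm(x, eps=1e-6):
    # normalize over the feature dimension
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + eps)

def residual_switch_unit(pair, W, b, scale):
    """Learnable 2-to-2 map over an adjacent pair of feature vectors.
    `pair` has shape (2, m); W, b, scale are illustrative parameters."""
    flat = pair.reshape(-1)                        # concatenate the two inputs
    candidate = gelu(layer_norm(flat) @ W + b)     # transformed branch
    out = scale * flat + candidate                 # residual connection
    return out.reshape(pair.shape)                 # back to two outputs

m = 4
rng = np.random.default_rng(0)
pair = rng.standard_normal((2, m))
W = rng.standard_normal((2 * m, 2 * m)) * 0.1
out = residual_switch_unit(pair, W, np.zeros(2 * m), 0.9)
print(out.shape)  # (2, 4): two outputs for two inputs
```

The key structural points, which the sketch preserves, are the residual connection around the learnable transformation and the use of GELU and Layer Normalization instead of the gating mechanism of the original Switch Unit.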

Here is an illustration of a Residual Switch Unit, which replaces the Switch Unit from Shuffle-Exchange networks:

The Shuffle Layer follows where inputs are permuted according to a perfect-shuffle permutation (i.e., how a deck of cards is shuffled by splitting it into halves and then interleaving them) – a cyclic bit shift rotating left in the first part of the network and (inversely) rotating right in the second part.
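The perfect-shuffle permutation described above is simple to state in code. The following is a plain NumPy illustration (not code from this repository) showing both the riffle-shuffle view and the resulting order:

```python
import numpy as np

def perfect_shuffle(x):
    """Riffle shuffle: split the sequence into halves and interleave them.
    Element at source index i lands at the cyclic left bit-rotation of i."""
    n = len(x)
    first, second = x[:n // 2], x[n // 2:]
    out = np.empty_like(x)
    out[0::2] = first   # even destination positions take the first half
    out[1::2] = second  # odd destination positions take the second half
    return out

x = np.arange(8)
print(perfect_shuffle(x))  # [0 4 1 5 2 6 3 7]
```

For example, element 4 (binary 100) moves to position 1 (binary 001), a cyclic left rotation of its 3-bit index; the inverse shuffle in the second part of the network rotates right instead.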

The Residual Shuffle-Exchange network is organized in blocks by alternating these two kinds of layers in the pattern of the Beneš network. Such a network can represent a wide class of functions including any permutation of the input values.

Here is an illustration of a whole Residual Shuffle-Exchange network model consisting of two blocks with 8 inputs:

System requirements

  • Python 3.6 or higher.
  • TensorFlow 1.14.0.

Running the experiments

To start training the Residual Shuffle-Exchange network on binary addition, run the terminal command:

```sh
python3 RSE_trainer.py
```

To select the sequence processing task on which to train the Residual Shuffle-Exchange network, edit the config.py file, which contains various hyperparameters and other suggested settings.

For the MusicNet transcription task see the following:

```python
...
"""
    Task configuration.
"""
...
# task = "musicnet"
# input_type = tf.float32
...
```

To download and parse the MusicNet dataset, run:

```sh
wget https://homes.cs.washington.edu/~thickstn/media/musicnet.npz
python3 -u resample.py musicnet.npz musicnet_11khz.npz 44100 11000
rm musicnet.npz
python3 -u parse_file.py
rm musicnet_11khz.npz
```

This might take a while. After parsing the file, make sure that config.py contains the correct directory for the MusicNet data. To evaluate the trained model on the MusicNet test set, run tester.py.

For the LAMBADA question answering task see the following:

```python
...
"""
    Task configuration.
"""
...
# task = "lambada"
# n_input = lambada_vocab_size
# n_output = 3
# n_hidden = 48*8
# #input_dropout_keep_prob = 1.0
# input_word_dropout_keep_prob = 0.95
# use_front_padding = True
# use_pre_trained_embedding = True
# disperse_padding = False
# label_smoothing = 0.1
# batch_size = 64
# bins = [256]
...
```

To download the LAMBADA dataset, see the original publication by Paperno et al.

To download the pre-trained fastText 1M English word embedding, see the downloads section of the fastText library website and extract it to the directory listed in the config.py variable base_folder under "Embedding configuration":

```python
...
"""
    Embedding configuration
"""
use_pre_trained_embedding = False
base_folder = "/host-dir/embeddings/"
embedding_file = base_folder + "fast_word_embedding.vec"
emb_vector_file = base_folder + "emb_vectors.bin"
emb_word_dictionary = base_folder + "word_dict.bin"
...
```

To enable the pre-trained embedding, change the config.py variable use_pre_trained_embedding to True:

```python
...
use_pre_trained_embedding = True
...
```

To start training the Residual Shuffle-Exchange network, use the terminal command:

```sh
python3 DNGPU_trainer.py
```

If you are running Windows, edit the config.py file before training the Residual Shuffle-Exchange network to change the directory-related variables to Windows file path format:

```python
...
"""
    Local storage (checkpoints, etc).
"""
...
out_dir = ".\host-dir\gpu" + gpu_instance
model_file = out_dir + "\\varWeights.ckpt"
image_path = out_dir + "\\images"
...
"""
    MusicNet configuration
"""
musicnet_data_dir = ".\host-dir\musicnet\musicnet"
...
"""
    Lambada configuration
"""
lambada_data_dir = ".\host-dir\lambada-dataset"
...
"""
    Embedding configuration
"""
...
base_folder = ".\host-dir\embeddings"
embedding_file = base_folder + "fast_word_embedding.vec"
emb_vector_file = base_folder + "emb_vectors.bin"
emb_word_dictionary = base_folder + "word_dict.bin"
...
```

Contact information

For help or issues using Residual Shuffle-Exchange networks, please submit a GitHub issue.

For personal communication related to Residual Shuffle-Exchange networks, please contact Kārlis Freivalds ([email protected]).
