This repository contains the official TensorFlow implementation of the following paper:
Residual Shuffle-Exchange Networks for Fast Processing of Long Sequences
by Andis Draguns, Emīls Ozoliņš, Agris Šostaks, Matīss Apinis, Kārlis Freivalds
[arXiv]
Abstract: Attention is a commonly used mechanism in sequence processing, but it is of O(n²) complexity which prevents its application to long sequences. The recently introduced neural Shuffle-Exchange network offers a computation-efficient alternative, enabling the modelling of long-range dependencies in O(n log n) time. The model, however, is quite complex, involving a sophisticated gating mechanism derived from the Gated Recurrent Unit.
In this paper, we present a simple and lightweight variant of the Shuffle-Exchange network, which is based on a residual network employing GELU and Layer Normalization. The proposed architecture not only scales to longer sequences but also converges faster and provides better accuracy. It surpasses the Shuffle-Exchange network on the LAMBADA language modelling task and achieves state-of-the-art performance on the MusicNet dataset for music transcription while being efficient in the number of parameters.
We show how to combine the improved Shuffle-Exchange network with convolutional layers, establishing it as a useful building block in long sequence processing applications.
Residual Shuffle-Exchange networks are a simpler and faster replacement for the recently proposed Neural Shuffle-Exchange network architecture. They have O(n log n) complexity and enable processing of sequences up to 2 million symbols in length, where standard methods (e.g., attention mechanisms) fail. The Residual Shuffle-Exchange network can serve as a useful building block in long sequence processing applications.
Click the gif to see the full video on YouTube:
Our paper describes Residual Shuffle-Exchange networks in detail and provides full results on long binary addition, long binary multiplication, sorting tasks, the LAMBADA question answering task and multi-instrument musical note recognition using the MusicNet dataset.
Here are the accuracy results on the MusicNet transcription task of identifying the musical notes performed from audio waveforms (freely-licensed classical music recordings):
Model | Learnable parameters (M) | Average precision score (%) |
---|---|---|
cgRNN | 2.36 | 53.0 |
Deep Real Network | 10.0 | 69.8 |
Deep Complex Network | 8.8 | 72.9 |
Complex Transformer | 11.61 | 74.22 |
Translation-invariant net | unknown | 77.3 |
Residual Shuffle-Exchange network | 3.06 | 78.02 |
Note: Our model achieves state-of-the-art performance while being parameter-efficient, and it works directly on the audio waveform, whereas the previous state-of-the-art models used specialised architectures operating on complex-number representations of the Fourier-transformed waveform.
Here are the accuracy results on the LAMBADA question answering task of predicting a target word in its broader context (on average 4.6 sentences picked from novels):
Model | Learnable parameters (M) | Test accuracy (%) |
---|---|---|
Random word from passage | - | 1.6 |
Gated-Attention Reader | unknown | 49.0 |
Neural Shuffle-Exchange network | 33 | 52.28 |
Residual Shuffle-Exchange network | 11 | 54.34 |
Universal Transformer | 152 | 56.0 |
Human performance | - | 86.0 |
GPT-3 | 175000 | 86.4 |
Note: Compared to the Neural Shuffle-Exchange network, our model runs faster and can be evaluated on 4 times longer sequences using the same amount of GPU memory, and on 128 times longer sequences than the Universal Transformer.
Residual Shuffle-Exchange networks are a lightweight variant of Shuffle-Exchange networks: continuous, differentiable neural networks with a regular-layered structure consisting of alternating Switch and Shuffle layers.
The Switch Layer divides the input into adjacent pairs of values and applies a Residual Switch Unit (a learnable 2-to-2 function employing GELU and Layer Normalization) to each pair, producing two outputs.
Here is an illustration of a Residual Switch Unit, which replaces the Switch Unit from Shuffle-Exchange networks:
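As an intuition for the computation inside the unit, here is a simplified NumPy sketch. It is an illustration of the structure described above (Layer Normalization, GELU, and a residual connection around a learnable 2-to-2 map), not the exact parameterisation from the paper; the weight matrices `W1` and `W2` and their shapes are hypothetical:

```python
import numpy as np

def gelu(x):
    # Gaussian Error Linear Unit (tanh approximation).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def layer_norm(x, eps=1e-6):
    # Normalize a 1-D activation vector to zero mean and unit variance.
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def residual_switch_unit(pair, W1, W2):
    # `pair` is the concatenation of two adjacent input vectors.
    hidden = gelu(layer_norm(pair) @ W1)
    # Residual connection: the transformed pair is added back to the input;
    # the output is split back into two adjacent values downstream.
    return pair + hidden @ W2
```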
The Shuffle Layer follows where inputs are permuted according to a perfect-shuffle permutation (i.e., how a deck of cards is shuffled by splitting it into halves and then interleaving them) – a cyclic bit shift rotating left in the first part of the network and (inversely) rotating right in the second part.
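For intuition, the perfect-shuffle permutation can be sketched in a few lines. This is a hypothetical illustration, not code from the repository, and it assumes the sequence length n is a power of two: element i moves to the index obtained by cyclically rotating the bits of i (left for the forward shuffle, right for the inverse shuffle):

```python
def rotate_left(index, num_bits):
    # Cyclic left rotation of `index` within `num_bits` bits.
    return ((index << 1) | (index >> (num_bits - 1))) & ((1 << num_bits) - 1)

def rotate_right(index, num_bits):
    # Cyclic right rotation of `index` within `num_bits` bits.
    return (index >> 1) | ((index & 1) << (num_bits - 1))

def perfect_shuffle(seq):
    n = len(seq)                   # n must be a power of two
    num_bits = n.bit_length() - 1
    out = [None] * n
    for i, value in enumerate(seq):
        out[rotate_left(i, num_bits)] = value
    return out

def inverse_perfect_shuffle(seq):
    n = len(seq)
    num_bits = n.bit_length() - 1
    out = [None] * n
    for i, value in enumerate(seq):
        out[rotate_right(i, num_bits)] = value
    return out
```

For example, `perfect_shuffle([0, 1, 2, 3, 4, 5, 6, 7])` yields `[0, 4, 1, 5, 2, 6, 3, 7]`: exactly the riffle of the two halves of the "deck".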
The Residual Shuffle-Exchange network is organized in blocks by alternating these two kinds of layers in the pattern of the Beneš network. Such a network can represent a wide class of functions including any permutation of the input values.
Here is an illustration of a whole Residual Shuffle-Exchange network model consisting of two blocks with 8 inputs:
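The block structure can be outlined schematically as follows. This is a hypothetical sketch, not repository code: `switch_layer`, `shuffle`, and `inverse_shuffle` stand for the layers described above and are passed in as callables, and the layer counts follow the classic Beneš network (2k − 1 Switch levels for n = 2**k inputs):

```python
def benes_block(x, switch_layer, shuffle, inverse_shuffle, k):
    # First half of the block: Switch followed by Shuffle (rotate indices left).
    for _ in range(k - 1):
        x = shuffle(switch_layer(x))
    # Second half mirrors it with the inverse Shuffle (rotate indices right).
    for _ in range(k - 1):
        x = inverse_shuffle(switch_layer(x))
    # A final Switch layer closes the block, for 2k - 1 Switch levels in total.
    return switch_layer(x)
```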
- Python 3.6 or higher.
- TensorFlow 1.14.0.
To start training the Residual Shuffle-Exchange network on binary addition, run the terminal command:
python3 RSE_trainer.py
To select the sequence processing task for which to train the Residual Shuffle-Exchange network, edit the config.py
file, which contains various hyperparameters and other suggested settings.
For the MusicNet transcription task, uncomment the corresponding lines in config.py under “Task configuration”:
...
"""
Task configuration.
"""
...
# task = "musicnet"
# input_type = tf.float32
...
To download and parse the MusicNet dataset, run:
wget https://homes.cs.washington.edu/~thickstn/media/musicnet.npz
python3 -u resample.py musicnet.npz musicnet_11khz.npz 44100 11000
rm musicnet.npz
python3 -u parse_file.py
rm musicnet_11khz.npz
This might take a while. After parsing, make sure that config.py contains the correct directory for the MusicNet data. To evaluate the trained model on the MusicNet test set, run tester.py.
For the LAMBADA question answering task, uncomment the corresponding lines in config.py under “Task configuration”:
...
"""
Task configuration.
"""
...
# task = "lambada"
# n_input = lambada_vocab_size
# n_output = 3
# n_hidden = 48*8
# #input_dropout_keep_prob = 1.0
# input_word_dropout_keep_prob = 0.95
# use_front_padding = True
# use_pre_trained_embedding = True
# disperse_padding = False
# label_smoothing = 0.1
# batch_size = 64
# bins = [256]
...
To download the LAMBADA dataset, see the original publication by Paperno et al.
To download the pre-trained fastText 1M English word embedding, see the downloads section of the fastText library website and extract the archive to the directory given by the config.py variable base_folder under “Embedding configuration”:
...
"""
Embedding configuration
"""
use_pre_trained_embedding = False
base_folder = "/host-dir/embeddings/"
embedding_file = base_folder + "fast_word_embedding.vec"
emb_vector_file = base_folder + "emb_vectors.bin"
emb_word_dictionary = base_folder + "word_dict.bin"
...
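For reference, the downloaded embedding file follows the standard fastText text format (a header line with the vocabulary size and dimension, then one word and its vector per line). A hypothetical sketch of reading it, not code from this repository (the repository converts it to the .bin files listed above by its own means):

```python
import numpy as np

def load_vec(path, limit=None):
    # Read a fastText .vec file into a {word: vector} dict.
    vectors = {}
    with open(path, encoding="utf-8") as f:
        count, dim = map(int, f.readline().split())  # header: "count dim"
        for i, line in enumerate(f):
            if limit is not None and i >= limit:
                break  # optionally cap the vocabulary size
            parts = line.rstrip().split(" ")
            vectors[parts[0]] = np.asarray(parts[1:], dtype=np.float32)
    return vectors
```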
To enable the pre-trained embedding, set the config.py variable use_pre_trained_embedding to True:
...
use_pre_trained_embedding = True
...
To start training the Residual Shuffle-Exchange network, use the terminal command:
python3 DNGPU_trainer.py
If you're running Windows, edit the config.py file before training to change the directory-related variables to Windows file path format:
...
"""
Local storage (checkpoints, etc).
"""
...
out_dir = ".\host-dir\gpu" + gpu_instance
model_file = out_dir + "\\varWeights.ckpt"
image_path = out_dir + "\\images"
...
"""
MusicNet configuration
"""
musicnet_data_dir = ".\host-dir\musicnet\musicnet"
...
"""
Lambada configuration
"""
lambada_data_dir = ".\host-dir\lambada-dataset"
...
"""
Embedding configuration
"""
...
base_folder = ".\host-dir\embeddings"
embedding_file = base_folder + "fast_word_embedding.vec"
emb_vector_file = base_folder + "emb_vectors.bin"
emb_word_dictionary = base_folder + "word_dict.bin"
...
For help or issues using Residual Shuffle-Exchange networks, please submit a GitHub issue.
For personal communication related to Residual Shuffle-Exchange networks, please contact Kārlis Freivalds ([email protected]).