UR-LSTM

Description

This repository is built around the paper "Improving the Gating Mechanism of Recurrent Neural Networks" by Albert Gu, Caglar Gulcehre, Tom Paine, Matt Hoffman and Razvan Pascanu.

In it, the authors introduce the UR-LSTM, a variant of the LSTM architecture which robustly improves the performance of the recurrent model, particularly when long-term dependencies are involved.

Unfortunately, to my knowledge the authors did not release any code for either the model or the experiments, although they did provide pseudo-code for the model. Since I thought it was a really cool read, I decided to reimplement the model, as well as some of the experiments, in PyTorch.
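For readers who have not opened the paper, the gating change can be sketched in a few lines. The following is a minimal scalar sketch in plain Python, written from a reading of the paper's pseudo-code and not taken from this package. It assumes the forget bias is drawn so that the initial forget gate is spread roughly uniformly, and that a second "refine" gate r adjusts the forget gate f into an effective gate g. The function names (`sample_forget_bias`, `ur_gate`, `ur_cell_step`) are hypothetical, and the pre-activations that a real cell would compute from the input and previous hidden state are passed in directly as numbers.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sample_forget_bias(hidden_size, rng=random):
    # Assumption: u ~ U(1/hidden_size, 1 - 1/hidden_size) and bias = logit(u),
    # so that sigmoid(bias) == u. Requires hidden_size > 2.
    low = 1.0 / hidden_size
    u = rng.uniform(low, 1.0 - low)
    return -math.log(1.0 / u - 1.0)


def ur_gate(f_pre, r_pre, forget_bias):
    f = sigmoid(f_pre + forget_bias)  # forget gate
    r = sigmoid(r_pre - forget_bias)  # refine gate, biased the opposite way
    # Effective gate: interpolates between f^2 (r = 0) and 1 - (1 - f)^2 (r = 1).
    g = 2.0 * r * f + (1.0 - 2.0 * r) * f * f
    return f, r, g


def ur_cell_step(f_pre, r_pre, u_pre, o_pre, prev_cell, forget_bias):
    # One scalar step: g keeps the old cell value, (1 - g) admits the update.
    _, _, g = ur_gate(f_pre, r_pre, forget_bias)
    new_cell = g * prev_cell + (1.0 - g) * math.tanh(u_pre)
    new_hidden = sigmoid(o_pre) * math.tanh(new_cell)
    return new_hidden, new_cell
```

Since f^2 <= f <= 1 - (1 - f)^2 for any f in [0, 1], the refine gate can move the effective gate to either side of f, which is the point of the sketch. The packaged module is the place to look for the actual tensor implementation.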

I've separated the code for the UR-LSTM, which is packaged and downloadable as a standalone module, from the code for the experiments. To see how to run the experiments, check this page.

Installation

With Python 3.6 or higher:

pip install ur-lstm-torch

I haven't checked whether the model is compatible with older versions of PyTorch, but it should be fine for everything past version 1.0.

Usage

The model can be used in the same way as the native LSTM implementation (doc is here), although I didn't implement the bidirectional variant and removed the bias keyword argument:

import torch
from ur_lstm import URLSTM

input_size = 10
hidden_size = 20
num_layers = 2
batch_first = False
dropout = .5

model = URLSTM(
    input_size, hidden_size, num_layers=num_layers, batch_first=batch_first, dropout=dropout
)

batch_size = 3
seq_length = 5

x = torch.randn(seq_length, batch_size, input_size)
out, state = model(x)

print(out.shape) # (seq_length, batch_size, hidden_size)
print(len(state)) # 2, first is hidden state, second is cell state
print(state[0].shape) # (num_layers, batch_size, hidden_size)
print(state[1].shape) # (num_layers, batch_size, hidden_size)
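The shapes printed above depend on `batch_first`. As a small sanity-check helper, and assuming `URLSTM` follows the `torch.nn.LSTM` convention here (the text says it is used the same way, but this has not been verified against the package), the expected shapes can be written down in plain Python. `expected_shapes` is a hypothetical helper, not part of `ur_lstm`.

```python
def expected_shapes(seq_length, batch_size, hidden_size, num_layers, batch_first=False):
    """Return (output shape, per-state shape) under torch.nn.LSTM conventions."""
    if batch_first:
        out_shape = (batch_size, seq_length, hidden_size)
    else:
        out_shape = (seq_length, batch_size, hidden_size)
    # In torch.nn.LSTM the hidden and cell states are not affected by batch_first.
    state_shape = (num_layers, batch_size, hidden_size)
    return out_shape, state_shape


out_shape, state_shape = expected_shapes(5, 3, 20, 2)
print(out_shape)
print(state_shape)
```

In a test, these tuples could be compared against `tuple(out.shape)` and `tuple(state[0].shape)` from the example above.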

If you want to implement a custom model, you can also import and use the URLSTMCell module in the same way you would the regular LSTMCell (doc is here), although again I removed the bias keyword argument:

import torch
from ur_lstm import URLSTMCell

input_size = 10
hidden_size = 20

cell = URLSTMCell(input_size, hidden_size)

batch_size = 2

x = torch.randn(batch_size, input_size)
state = torch.randn(batch_size, hidden_size), torch.randn(batch_size, hidden_size)
out, state = cell(x, state)

print(out.shape) # (batch_size, hidden_size)
print(len(state)) # 2, first is hidden state, second is cell state
print(state[0].shape) # (batch_size, hidden_size)
print(state[1].shape) # (batch_size, hidden_size)
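A custom model typically calls the cell once per time step and threads the state through. The loop below is a minimal sketch of that pattern, assuming only the call convention shown above (`out, state = cell(x, state)`). `unroll` and `toy_cell` are hypothetical names, and `toy_cell` is a plain-number stand-in used so the sketch runs without PyTorch; it is not an LSTM.

```python
def unroll(cell, inputs, state):
    """Apply `cell` to each time step of `inputs`, carrying the state forward."""
    outputs = []
    for x_t in inputs:  # one time step per iteration
        out, state = cell(x_t, state)  # same return convention as the example above
        outputs.append(out)
    return outputs, state


def toy_cell(x_t, state):
    # Hypothetical stand-in with the same signature: a leaky running sum.
    _, cell_state = state
    cell_state = 0.5 * cell_state + x_t
    hidden = cell_state
    return hidden, (hidden, cell_state)


outputs, final_state = unroll(toy_cell, [1.0, 2.0, 3.0], (0.0, 0.0))
print(outputs)
print(final_state)
```

With the real module, `cell` would be the `URLSTMCell` instance and `inputs` a tensor of shape (seq_length, batch_size, input_size); iterating over such a tensor yields one (batch_size, input_size) slice per step, and the returned list could be combined with `torch.stack(outputs)`.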

License

MIT

Contributors

r0maink