
jitgru's Introduction

JitGRU: GRU with PyTorch's TorchScript

A simple implementation of GRUs using PyTorch's JIT (TorchScript). The API follows that of torch.nn.GRU, and the implementation should run reasonably fast.

But... why?

At the time of writing, PyTorch does not support second-order derivatives for GRUs with CUDA (see this issue). As a result, any loss function that depends on the second derivatives of a GRU doesn't work out of the box. I needed double backward() calls for a project, so here it is!
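To make the "double backward" requirement concrete, here is a minimal sketch of a gradient-penalty-style loss on a GRU cell built from elementary tensor operations. TinyGRUCell is a hypothetical name used only for illustration and is not the code in jit_gru.py. The gate layout (reset, update, new) is assumed to follow torch.nn.GRUCell, and the sketch runs in eager mode rather than through TorchScript.

```python
import torch
from torch import nn, Tensor


class TinyGRUCell(nn.Module):
    """Hypothetical GRU cell made of elementary ops (illustration only)."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        # Gates are stacked as (reset, update, new), as in torch.nn.GRUCell.
        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size) * 0.1)
        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size) * 0.1)
        self.bias_ih = nn.Parameter(torch.zeros(3 * hidden_size))
        self.bias_hh = nn.Parameter(torch.zeros(3 * hidden_size))

    def forward(self, x: Tensor, h: Tensor) -> Tensor:
        gi = torch.mm(x, self.weight_ih.t()) + self.bias_ih
        gh = torch.mm(h, self.weight_hh.t()) + self.bias_hh
        i_r, i_z, i_n = gi.chunk(3, 1)
        h_r, h_z, h_n = gh.chunk(3, 1)
        r = torch.sigmoid(i_r + h_r)      # reset gate
        z = torch.sigmoid(i_z + h_z)      # update gate
        n = torch.tanh(i_n + r * h_n)     # candidate state
        return (1 - z) * n + z * h


torch.manual_seed(0)
cell = TinyGRUCell(input_size=3, hidden_size=7)
x = torch.randn(4, 3, requires_grad=True)   # (batch, input_size)
h = torch.zeros(4, 7)                       # (batch, hidden_size)

out = cell(x, h)
# First derivative w.r.t. the input; create_graph=True keeps it differentiable.
(grad_x,) = torch.autograd.grad(out.sum(), x, create_graph=True)
# A loss that depends on that gradient (gradient-penalty style).
penalty = ((grad_x.norm(2, dim=1) - 1) ** 2).mean()
# Second backward pass: differentiates through the first gradient.
penalty.backward()
print(cell.weight_ih.grad.shape)
```

Because every operation in the cell is an ordinary differentiable primitive, the second backward() call does not depend on a fused GRU kernel supporting double differentiation.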

How to use

The main implementation is available in jit_gru.py. I've implemented equivalents of torch.nn.GRUCell and torch.nn.GRU in that file. Look at the test cases that I've included in the implementation. Those should help you get started.

Bi-Directional GRUs

Support for bi-directional GRUs with variable input lengths was recently added (credits go to @elixir-code). This implementation is available separately in jit_bigru.py. See the included test cases in that file for example usage.
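One way to handle variable input lengths without PackedSequence is to mask each time step by the sequence lengths. The sketch below illustrates that idea only; it is not taken from jit_bigru.py. The helpers run_masked_gru and reverse_padded are hypothetical names, the sketch uses torch.nn.GRUCell in eager mode, and it assumes inputs are laid out as (seq_len, batch, input_size) with integer lengths.

```python
import torch
from torch import nn, Tensor


def run_masked_gru(cell: nn.GRUCell, inputs: Tensor, lengths: Tensor):
    """Run `cell` over padded inputs of shape (seq_len, batch, input_size).

    Hypothetical helper: once a sequence has ended, its hidden state is
    frozen and its outputs are zero.
    """
    seq_len, batch, _ = inputs.shape
    h = inputs.new_zeros(batch, cell.hidden_size)
    outputs = []
    for t in range(seq_len):
        h_new = cell(inputs[t], h)
        mask = (lengths > t).unsqueeze(1)   # (batch, 1), True while still running
        h = torch.where(mask, h_new, h)
        outputs.append(torch.where(mask, h_new, torch.zeros_like(h_new)))
    return torch.stack(outputs), h


def reverse_padded(inputs: Tensor, lengths: Tensor) -> Tensor:
    """Reverse each sequence within its own length; padding stays at the end."""
    seq_len, batch, _ = inputs.shape
    t = torch.arange(seq_len, device=inputs.device).unsqueeze(1)   # (seq_len, 1)
    lens = lengths.unsqueeze(0)                                    # (1, batch)
    idx = torch.where(t < lens, lens - 1 - t, t)                   # (seq_len, batch)
    return inputs[idx, torch.arange(batch, device=inputs.device)]


torch.manual_seed(0)
fwd, bwd = nn.GRUCell(3, 7), nn.GRUCell(3, 7)
x = torch.randn(5, 3, 3)                 # (seq_len, batch, input_size)
lengths = torch.tensor([5, 3, 2])

out_f, h_f = run_masked_gru(fwd, x, lengths)
out_b, h_b = run_masked_gru(bwd, reverse_padded(x, lengths), lengths)
# Put the backward outputs back into forward time order, then concatenate.
out = torch.cat([out_f, reverse_padded(out_b, lengths)], dim=2)
print(out.shape)
```

The backward direction reuses the same masked loop on inputs that were reversed per sequence, so padded steps never influence either direction's hidden state.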

Demo Project

Check out DeepNAG, which contains a GAN-based sequence generation model as well as a non-adversarial sequence generator. The GAN-based sequence generator in that repository is trained with the improved Wasserstein GAN loss function and relies on the code from this repository.

Support/Citing

If you find our work useful, please consider starring this repository and citing our work:

@phdthesis{maghoumi2020dissertation,
  title={{Deep Recurrent Networks for Gesture Recognition and Synthesis}},
  author={Mehran Maghoumi},
  year={2020},
  school={University of Central Florida Orlando, Florida}
}

@misc{maghoumi2020deepnag,
      title={{DeepNAG: Deep Non-Adversarial Gesture Generation}}, 
      author={Mehran Maghoumi and Eugene M. Taranta II and Joseph J. LaViola Jr},
      year={2020},
      eprint={2011.09149},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Contribution

I'm actively using this implementation, so contributions are very welcome, as they help my work too. If you think you can improve this project or implement something more efficiently, feel free to submit a pull request!

License

This project is licensed under the MIT License.


jitgru's Issues

Support for variable length sequences and bi-directional GRU

I am working with an NLP model which uses bi-directional GRUs and also uses a higher-order derivative in the loss function.

Is it possible to extend this work to support variable length sequences and also support the bi-directional variant of the GRU?

I am interested in understanding whether this is technically possible, and whether implementing it using JIT would give speed improvements over the approach of disabling cuDNN for nn.GRU by wrapping its calls in a with torch.backends.cudnn.flags(enabled=False): block.

P.S.: I understand that TorchScript does not support PackedSequence, which makes things difficult.
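For reference, the cuDNN-disabling approach mentioned above can be sketched as follows. This is a minimal illustration with arbitrarily chosen sizes and a made-up loss, not a benchmark and not code from this repository. It falls back to CPU when CUDA is unavailable, in which case the context manager has no practical effect.

```python
import torch
from torch import nn

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
gru = nn.GRU(input_size=3, hidden_size=7).to(device)
x = torch.randn(5, 2, 3, device=device, requires_grad=True)  # (seq_len, batch, input_size)

# With cuDNN disabled, nn.GRU uses PyTorch's native implementation instead of
# the cuDNN kernel.
with torch.backends.cudnn.flags(enabled=False):
    out, _ = gru(x)
    # First derivative, kept in the graph so it can be differentiated again.
    (grad_x,) = torch.autograd.grad(out.sum(), x, create_graph=True)
    loss = grad_x.pow(2).sum()

# Second backward pass through the first gradient.
loss.backward()
print(gru.weight_ih_l0.grad.shape)
```

Whether a TorchScript GRU is faster than this fallback depends on the hardware and PyTorch version, so it would need to be measured rather than assumed.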

Speed comparison

Does JitGRU provide any speedup for CPU or GPU training? Thank you.

Unit test does not work

Hi, just FYI, I get the following with PyTorch v1.4.0a0:

python jit_gru.py 
[2, 3]
[2, 3]
[2, 3]
[2, 3]
[2, 3]
Traceback (most recent call last):
  File "models/gru/jit_gru.py", line 184, in <module>
    test_script_gru_layer(5, 2, 3, 7)
  File "models/gru/jit_gru.py", line 146, in test_script_gru_layer
    assert lstm_param.shape == custom_param.shape
AssertionError

And printing:

print(custom_param.shape)
print(lstm_param.shape)

shows:
torch.Size([21, 7])
torch.Size([21, 3])
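One possible reading of these shapes, not confirmed anywhere in this thread: if the arguments (5, 2, 3, 7) mean (seq_len, batch, input_size, hidden_size), then [21, 3] matches the input-to-hidden weight of torch.nn.GRU and [21, 7] matches its hidden-to-hidden weight, which would point to the two parameter lists being compared in a different order. The sketch below only prints the reference shapes of torch.nn.GRU for those assumed sizes.

```python
import torch
from torch import nn

# Sizes assumed from the failing call: input_size=3, hidden_size=7.
gru = nn.GRU(input_size=3, hidden_size=7)

# Collect parameter shapes by name so they can be matched by name, not position.
shapes = {name: tuple(p.shape) for name, p in gru.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)
```

Here weight_ih_l0 has shape (21, 3) and weight_hh_l0 has shape (21, 7), since the three gates are stacked along the first dimension (3 * hidden_size = 21). Comparing parameters by name rather than by iteration order would avoid this kind of mismatch.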

Error when batch is set to 1

First, thank you for sharing this code.

I copied the original code unchanged. When the batch size is set to 1, it fails with the following error:

Traceback (most recent call last):
  File "jit.py", line 209, in <module>
    test_script_gru_layer(5, 1, 3, 7)
  File "jit.py", line 161, in test_script_gru_layer
    out, out_state = gru_jit(inp, h)
  File "/home/wu/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
The above operation failed in interpreter.
Traceback (most recent call last):
  File "<string>", line 10
                  alpha: number = 1.0):
            result = torch.add(self, other, alpha=alpha)
            self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
                                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            def backward(grad_output):
                grad_other = (grad_output * alpha)._grad_sum_to_size(other_size)
  File "<string>", line 10, in AD_sizes_if_not_equal_multi_1
                  alpha: number = 1.0):
            result = torch.add(self, other, alpha=alpha)
            self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
                                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            def backward(grad_output):
                grad_other = (grad_output * alpha)._grad_sum_to_size(other_size)
  File "<string>", line 10, in AD_sizes_if_not_equal_multi_1
                  alpha: number = 1.0):
            result = torch.add(self, other, alpha=alpha)
            self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
                                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            def backward(grad_output):
                grad_other = (grad_output * alpha)._grad_sum_to_size(other_size)

The above operation failed in interpreter.
Traceback (most recent call last):

Do you have any idea why this is happening?
