imatge-upc / skiprnn-2017-telecombcn

Skip RNN: Learning to Skip State Updates in Recurrent Neural Networks (ICLR 2018)

Home Page: https://imatge-upc.github.io/skiprnn-2017-telecombcn/

License: MIT License

Python 100.00%
deep-learning recurrent-neural-networks

skiprnn-2017-telecombcn's Introduction

Skip RNN: Learning to Skip State Updates in Recurrent Neural Networks

Víctor Campos, Brendan Jou, Jordi Torres, Xavier Giró-i-Nieto, Shih-Fu Chang

A joint collaboration between:

Barcelona Supercomputing Center (BSC), Google Inc., Universitat Politècnica de Catalunya (UPC), Columbia University

Abstract

Recurrent Neural Networks (RNNs) continue to show outstanding performance in sequence modeling tasks. However, training RNNs on long sequences often faces challenges like slow inference, vanishing gradients and difficulty in capturing long-term dependencies. In backpropagation through time settings, these issues are tightly coupled with the large, sequential computational graph resulting from unfolding the RNN in time. We introduce the Skip RNN model which extends existing RNN models by learning to skip state updates and shortens the effective size of the computational graph. This model can also be encouraged to perform fewer state updates through a budget constraint. We evaluate the proposed model on various tasks and show how it can reduce the number of required RNN updates while preserving, and sometimes even improving, the performance of the baseline RNN models.

Publication

Victor Campos, Brendan Jou, Xavier Giro-i-Nieto, Jordi Torres, and Shih-Fu Chang. "Skip RNN: Learning to Skip State Updates in Recurrent Neural Networks", In International Conference on Learning Representations, 2018.

@inproceedings{campos2018skip,
title={Skip RNN: Learning to Skip State Updates in Recurrent Neural Networks},
author={Campos, V{\'\i}ctor and Jou, Brendan and Gir{\'o}-i-Nieto, Xavier and Torres, Jordi and Chang, Shih-Fu},
booktitle={International Conference on Learning Representations},
year={2018}
}

Code

Dependencies

This code was developed with Python 3.6.0 and TensorFlow 1.13.1. An older version of the code for TensorFlow 1.0.0 is available under the tags menu. To download and install TensorFlow, please follow the official guide.

Using the models

The models are ready to be used with TensorFlow's tf.nn.dynamic_rnn and can be found under src/rnn_cells/skip_rnn_cells.py. We provide four different RNN cells:

  • SkipLSTMCell: single SkipLSTM layer
  • SkipGRUCell: single SkipGRU layer
  • MultiSkipLSTMCell: stack of multiple SkipLSTM layers
  • MultiSkipGRUCell: stack of multiple SkipGRU layers

A usage example can be found below:

import tensorflow as tf
from rnn_cells.skip_rnn_cells import SkipLSTMCell

# Define constants and hyperparameters
NUM_CELLS = 110
BATCH_SIZE = 256
INPUT_SIZE = 10
COST_PER_SAMPLE = 1e-05

# Placeholder for the input tensor with shape (batch, time, input_dims)
x = tf.placeholder(tf.float32, [None, None, INPUT_SIZE])

# Create SkipLSTM and trainable initial state
cell = SkipLSTMCell(NUM_CELLS)
initial_state = cell.trainable_initial_state(BATCH_SIZE)

# Dynamic RNN unfolding
rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32, initial_state=initial_state)

# Split the output into the actual RNN output and the state update gate
rnn_outputs, updated_states = rnn_outputs.h, rnn_outputs.state_gate

# Add a penalization for each state update (i.e. used sample)
budget_loss = tf.reduce_mean(tf.reduce_sum(COST_PER_SAMPLE * updated_states, 1), 0)
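
To make the last line of the example concrete, the following plain-Python sketch mirrors the same reduction (sum over time, then mean over the batch) on a small hand-written batch. It is an illustration only: the gate values are made up, the helper names `budget_loss` and `used_fraction` are hypothetical and not part of this repository, and the gates are assumed to be laid out as (batch, time).

```python
COST_PER_SAMPLE = 1e-05  # same value as in the usage example


def budget_loss(updated_states, cost_per_sample=COST_PER_SAMPLE):
    """Sum the per-update cost over time, then average over the batch."""
    per_sequence = [cost_per_sample * sum(gates) for gates in updated_states]
    return sum(per_sequence) / len(per_sequence)


def used_fraction(updated_states):
    """Fraction of time steps in the batch at which the state was updated."""
    total_steps = sum(len(gates) for gates in updated_states)
    total_updates = sum(sum(gates) for gates in updated_states)
    return total_updates / total_steps


# Hypothetical binary gate values: 2 sequences, 5 time steps each
gates = [[1, 0, 0, 1, 0],
         [1, 1, 0, 0, 0]]

print(budget_loss(gates))    # cost of two updates per sequence
print(used_fraction(gates))  # 4 updates out of 10 time steps
```

A larger `COST_PER_SAMPLE` makes every state update more expensive in the total loss, which is how the budget constraint encourages fewer updates.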

PyTorch version

This repository contains a PyTorch implementation of Skip RNN by Albert Berenguel.

Acknowledgments

We would like to especially thank the technical support team at the Barcelona Supercomputing Center, as well as Oscar Mañas for updating the original codebase to TensorFlow 1.13.1, adding TensorBoard support and improving the data loading pipeline.

This work has been supported by the grant SEV2015-0493 of the Severo Ochoa Program awarded by the Spanish Government, project TIN2015-65316 by the Spanish Ministry of Science and Innovation, and contract 2014-SGR-1051 by Generalitat de Catalunya.
We gratefully acknowledge the support of NVIDIA Corporation through the BSC/UPC NVIDIA GPU Center of Excellence.
The Image Processing Group at the UPC is an SGR14 Consolidated Research Group recognized and sponsored by the Catalan Government (Generalitat de Catalunya) through its AGAUR office.
This work has been developed in the framework of the project BigGraph TEC2013-43935-R, funded by the Spanish Ministerio de Economía y Competitividad and the European Regional Development Fund (ERDF).

Contact

If you have a general question about our work or code that may be of interest to other researchers, please use the public issues section of this GitHub repository. Alternatively, drop us an e-mail at [email protected].

skiprnn-2017-telecombcn's People

Contributors

oscmansan, victorcampos7, xavigiro

skiprnn-2017-telecombcn's Issues

How to SKIP time steps

Hi,
I have read several papers on sentence simplification and simplified representations, most of which skip words with discrete actions and learn the policy through reinforcement learning. However, in their implementations the cell states are still computed at every time step, so hardly any time is saved during training or inference.
As far as I know, this work is the first to modify the RNN cell itself so that it skips some of the input, which makes it novel and interesting. What impressed me most was this statement:

The number of skipped time steps can be computed ahead of time.

This raises the possibility of reducing time and computation by pre-calculating how many steps are to be skipped.
According to your method, if u_t equals 1 at time step t (a state update at time t), then the number of steps skipped until the next state update can be calculated directly with

skipped_steps = ceil(1.0 / sigma(W * s_t + b))

However, I cannot find the part of your implementation where this precomputed number of skipped steps is used :( . Did I miss it somehow?
If I want to skip the state computation itself and save its cost, must that be done by overriding the dynamic_rnn method, or is it already handled by some optimization inside TensorFlow?
Thank you for your time!
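
For reference, here is a minimal plain-Python sketch of the gate accumulation described in the Skip RNN paper. It is not code from this repository, and the function names are hypothetical. It assumes that the increment (called `delta_u` here, the sigmoid output in the expression above) stays constant while the state is not updated, and that an update happens as soon as the accumulated gate value reaches 0.5, since the binarized gate is obtained by rounding. Under those assumptions the count is governed by the 0.5 threshold, so it can differ from the expression quoted in the question.

```python
import math


def simulate_skipped_steps(delta_u):
    """Count skipped steps after an update by accumulating the gate value.

    Assumes 0 < delta_u < 1 and that delta_u does not change while the
    state is being copied (no update means no new state to compute it from).
    """
    u_tilde = delta_u  # accumulated value right after a state update
    skipped = 0
    while u_tilde < 0.5:  # rounds to 0, so this step is skipped
        u_tilde += min(delta_u, 1.0 - u_tilde)
        skipped += 1
    return skipped


def closed_form_skipped_steps(delta_u):
    """Smallest n with n * delta_u >= 0.5, minus one, under the same assumptions."""
    return math.ceil(0.5 / delta_u) - 1


for delta_u in (0.125, 0.2, 0.25, 0.75):
    print(delta_u, simulate_skipped_steps(delta_u), closed_form_skipped_steps(delta_u))
```

Knowing this count ahead of time does not by itself save computation: the loop that unrolls the RNN would also have to avoid evaluating the cell on the skipped steps.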

Loss becomes NaN

Hi. I was very impressed by your paper and would like to use your code, but I have run into a problem. My code is below.

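import tensorflow as tf
from rnn_cells.skip_rnn_cells import MultiSkipGRUCell

# models, video_level_models, FLAGS, gru_size and number_of_layers come from
# the rest of the poster's project and are not shown in this snippet.

class SkipGruModel(models.BaseModel):

    def create_model(self, model_input, vocab_size, num_frames, is_training=True, **unused_params):

        # Stack of SkipGRU layers, each with gru_size units
        stacked_GRU = MultiSkipGRUCell([gru_size] * number_of_layers)

        loss = 0.0  # not used further in this snippet
        with tf.variable_scope("RNN"):
            outputs, state = tf.nn.dynamic_rnn(stacked_GRU, model_input,
                                               sequence_length=num_frames,
                                               dtype=tf.float32)

        aggregated_model = getattr(video_level_models,
                                   FLAGS.video_level_classifier_model)

        # Classify from the hidden state of the last layer
        return aggregated_model().create_model(
            model_input=state[-1].h,
            vocab_size=vocab_size,
            is_training=is_training,
            **unused_params)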

The input is a variable-length sequence of 1024-dimensional vectors. I want to predict the probability of each label, so I use a cross-entropy loss.

However, the loss is NaN at the very first step.

What should I do?
