This project is forked from barronalex/dynamic-memory-networks-in-tensorflow.


License: MIT


Dynamic Memory Networks in TensorFlow

DMN+ implementation in TensorFlow for question answering on the bAbI 10k dataset.

Structure and parameters are taken from Dynamic Memory Networks for Visual and Textual Question Answering, henceforth referred to as Xiong et al.

Adapted from Stanford's cs224d assignment 2 starter code, using methods from Dynamic Memory Networks in Theano for importing the bAbI-10k dataset.

Repository Contents

| file | description |
| --- | --- |
| `dmn_plus.py` | contains the DMN+ model |
| `dmn_train.py` | trains the model on a specified (`-b`) bAbI task |
| `dmn_test.py` | tests the model on a specified (`-b`) bAbI task |
| `babi_input.py` | prepares bAbI data for input into the DMN |
| `attention_gru_cell.py` | contains a custom Attention GRU cell implementation |
| `fetch_babi_data.sh` | shell script to fetch bAbI tasks (from DMNs in Theano) |
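The Attention GRU in the episodic memory module replaces the standard GRU update gate with the scalar attention weight g_t produced by the attention mechanism, so h_t = g_t * h̃_t + (1 − g_t) * h_{t−1} (Xiong et al.). Below is a minimal NumPy sketch of one step with toy dimensions; the names and shapes are illustrative and not taken from `attention_gru_cell.py`:

```python
import numpy as np

def attention_gru_step(x_t, h_prev, g_t, W_r, U_r, b_r, W_h, U_h, b_h):
    """One Attention GRU step: the usual update gate is replaced by the
    scalar attention weight g_t (illustrative sketch, not the repo's code)."""
    r = 1.0 / (1.0 + np.exp(-(W_r @ x_t + U_r @ h_prev + b_r)))  # reset gate
    h_tilde = np.tanh(W_h @ x_t + U_h @ (r * h_prev) + b_h)      # candidate state
    return g_t * h_tilde + (1.0 - g_t) * h_prev                  # attention-gated update

# toy example
d = 4
rng = np.random.default_rng(0)
x_t, h_prev = rng.standard_normal(d), np.zeros(d)
W_r, U_r, W_h, U_h = (rng.standard_normal((d, d)) for _ in range(4))
b_r, b_h = np.zeros(d), np.zeros(d)
h_next = attention_gru_step(x_t, h_prev, 0.9, W_r, U_r, b_r, W_h, U_h, b_h)
print(h_next.shape)  # prints (4,)
```

Note that with g_t = 0 the cell passes the previous state through unchanged, which is what lets the attention weights gate which facts update the episode.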

Usage

Install TensorFlow r1.4

Run the included shell script to fetch the data:

```bash
bash fetch_babi_data.sh
```

Use `dmn_train.py` to train the DMN+ model contained in `dmn_plus.py`:

```bash
python dmn_train.py --babi_task_id 2
```

Once training is finished, test the model on the specified task:

```bash
python dmn_test.py --babi_task_id 2
```

The l2 regularization constant can be set with `-l2-loss` (`-l`). All other parameters are those specified by Xiong et al. and can be found in the `Config` class in `dmn_plus.py`.
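As a rough illustration of how such flags could map onto a `Config` object, here is a hedged argparse sketch; apart from `--babi_task_id`, the flag spellings, `Config` fields, and defaults shown are assumptions, and the real values live in `dmn_plus.py`:

```python
import argparse

class Config:
    """Illustrative hyperparameters (assumed names/values, not the repo's)."""
    babi_id = 1
    l2 = 0.001      # l2 regularization constant
    num_runs = 1    # training runs per task

def parse_args(argv=None):
    # Wire CLI flags to Config defaults; spellings here are hypothetical.
    p = argparse.ArgumentParser()
    p.add_argument("-b", "--babi_task_id", type=int, default=Config.babi_id)
    p.add_argument("-l", "--l2_loss", type=float, default=Config.l2)
    p.add_argument("--num-runs", dest="num_runs", type=int, default=Config.num_runs)
    return p.parse_args(argv)

args = parse_args(["-b", "2"])
print(args.babi_task_id, args.l2_loss, args.num_runs)  # prints 2 0.001 1
```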

Benchmarks

The TensorFlow DMN+ reaches close to state-of-the-art performance on the 10k dataset with weak supervision (no supporting facts).

Each task was trained separately with l2 = 0.001. As the paper suggests, 10 training runs were used for tasks 2, 3, 17, and 18 (configurable with `--num-runs`); the weights that produce the lowest validation loss across all runs are used for testing.
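The best-of-N selection described above can be sketched as follows; `train_one_run` is a hypothetical stand-in for the actual training loop, returning a fake validation loss and checkpoint:

```python
import random

def train_one_run(seed):
    # Hypothetical stand-in for a full training run: returns a fake
    # validation loss and a pretend weight checkpoint.
    rng = random.Random(seed)
    return rng.uniform(0.1, 1.0), {"seed": seed}

def best_of_runs(num_runs):
    # Train num_runs times and keep the weights from the run with the
    # lowest validation loss, as done for tasks 2, 3, 17, and 18.
    best_loss, best_weights = float("inf"), None
    for run in range(num_runs):
        loss, weights = train_one_run(seed=run)
        if loss < best_loss:
            best_loss, best_weights = loss, weights
    return best_loss, best_weights

best_loss, best_weights = best_of_runs(10)
print(round(best_loss, 3), best_weights)
```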

The pre-trained weights which achieve these benchmarks are available in `pretrained`.

I haven't yet had the time to fully optimize the l2 parameter, which is not specified in the paper. My hypothesis is that fully optimizing the l2 regularization would close the remaining significant performance gap between the TensorFlow DMN+ and the original DMN+ on task 3.

Below are the full results (test error, %) for each bAbI task; tasks where both implementations achieved 0 test error are omitted:

| Task ID | TensorFlow DMN+ | Xiong et al. DMN+ |
| ---: | ---: | ---: |
| 2 | 0.9 | 0.3 |
| 3 | 18.4 | 1.1 |
| 5 | 0.5 | 0.5 |
| 7 | 2.8 | 2.4 |
| 8 | 0.5 | 0.0 |
| 9 | 0.1 | 0.0 |
| 14 | 0.0 | 0.2 |
| 16 | 46.2 | 45.3 |
| 17 | 5.0 | 4.2 |
| 18 | 2.2 | 2.1 |

