
mixer's Introduction

MIXER - Sequence Level Training with Recurrent Neural Networks

http://arxiv.org/abs/1511.06732

This is a self-contained software package accompanying the paper titled Sequence Level Training with Recurrent Neural Networks: http://arxiv.org/abs/1511.06732. The code allows you to reproduce our results on the machine translation task.

The code implements MIXER; it runs both training and evaluation.

Preparing the training data

run prepareData.sh

Examples

Here are some examples of how to use the code.

  • To run an LSTM with the default parameter setting used to generate MIXER's entry for machine translation (see table in fig.5 of http://arxiv.org/abs/1511.06732), type
th -i main.lua
  • To run an LSTM with the following hyper-parameters:
      • hidden units: 128
      • minibatch size: 64
      • learning rate: 0.1
      • number of time steps we unfold: 15
    type
th -i main.lua -nhid 128 -bsz 64 -lr 0.1 -bptt 15

To list all the available options, type

th main.lua --help

Requirements

The software is written in Lua. It requires the following packages:

  • Torch 7
  • nngraph
  • cutorch
  • cunn

It runs on a standard Linux box with a GPU.

Installing

Download the files into an appropriate directory and run the code from there. See below.

How it works

The top-level file is main.lua. Run it with Torch, for example:

th -i main.lua -<option1_name> option1_val -<option2_name> option2_val ...

Structure of the code.

  • main.lua the script that launches training and testing. The user can pass options to set various hyper-parameters, such as the learning rate, number of hidden units, etc.
  • Trainer.lua a simple class that loops over the dataset for a given number of epochs to train the model, loops over the validation/test set to evaluate it, and saves model checkpoints.
  • model_factory.lua this is a function which returns the network operating at a single time step.
  • Mixer.lua the class implementing the unrolled recurrent network, cloning as many time steps as necessary of whatever is returned by model_factory. It implements the basic fprop/bprop through the recurrent model.
  • ReinforceSampler.lua class that is used to sample from a tensor storing log-probabilities.
  • ReinforceCriterion.lua criterion which is used to compute reward once the end of sequence is reached.
  • ClassNLLCriterionWeighted.lua wrapper around ClassNLLCriterion which multiplies the output of ClassNLLCriterion by a scalar to weigh the loss.
  • LinearNoBackpropInput.lua just like Linear but without computing derivatives w.r.t. input.
  • DataSource.lua data provider that takes as input a tokenized dataset in binary format and returns mini-batches.
  • reward_factory.lua class that is used to compute BLEU and ROUGE scores (both at the sentence and corpus level).
  • util.lua auxiliary functions.
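As background for how Trainer.lua and Mixer.lua interact: MIXER anneals from cross-entropy (XENT) training to REINFORCE — each sequence is trained with XENT on its first steps and REINFORCE on the remaining ones, and the REINFORCE portion grows as training proceeds. A minimal Python sketch of such a schedule (the function and parameter names here are hypothetical illustrations, not the repository's actual Lua options):

```python
def mixer_schedule(seq_len, n_epochs, delta_step=2, epochs_per_delta=2):
    """Yield (epoch, n_xent_steps, n_reinforce_steps) for a MIXER-style anneal.

    Each sequence is trained with cross-entropy on its first n_xent_steps
    positions and with REINFORCE on the trailing n_reinforce_steps; every
    epochs_per_delta epochs the REINFORCE portion grows by delta_step until
    it covers the whole sequence.
    """
    delta = 0  # number of trailing steps trained with REINFORCE
    for epoch in range(n_epochs):
        if epoch > 0 and epoch % epochs_per_delta == 0:
            delta = min(seq_len, delta + delta_step)
        yield epoch, seq_len - delta, delta
```

For example, with seq_len=10 the schedule starts fully XENT and hands two more trailing steps per anneal step to REINFORCE.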

License

"MIXER"'s software is BSD-licensed. We also provide an additional patent grant.

Other Details

See the CONTRIBUTING file for how to help out.

mixer's People

Contributors

ranzato


mixer's Issues

cannot load libcutorch.so even though 'luarocks install cutorch' succeeded

envy@ub1404:/os_pri/github/MIXER$ th -i main.lua
/home/envy/torch/install/bin/luajit: /home/envy/torch/install/share/lua/5.1/trepl/init.lua:384: /home/envy/torch/install/share/lua/5.1/cutorch/init.lua:2: cannot load '/home/envy/torch/install/lib/lua/5.1/libcutorch.so'
stack traceback:
[C]: in function 'error'
/home/envy/torch/install/share/lua/5.1/trepl/init.lua:384: in function 'require'
main.lua:20: in main chunk
[C]: in function 'dofile'
...envy/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:145: in main chunk
[C]: at 0x00406260
envy@ub1404:/os_pri/github/MIXER$

envy@ub1404:/os_pri/github/MIXER$ ll /home/envy/torch/install/lib/lua/5.1/libcutorch.so
-rwxr-xr-x 1 envy envy 568530 Mar 9 21:09 /home/envy/torch/install/lib/lua/5.1/libcutorch.so*
envy@ub1404:/os_pri/github/MIXER$

ReinforceSampler update Gradient issue

The training runs into problem after a few epochs, can you help me with that?

Epoch: 1. Training time: 731.75s. WordsXE/s: 3586.04, WordsRF/s: 53.08
Training: Ent: 6.79167 || Ppl: 110.78875 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 6.00815 || Ppl: 64.36271
Epoch: 2. Training time: 711.39s. WordsXE/s: 3688.66, WordsRF/s: 54.60
Training: Ent: 5.42869 || Ppl: 43.07238 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 5.50559 || Ppl: 45.43051
Epoch: 3. Training time: 712.21s. WordsXE/s: 3684.43, WordsRF/s: 54.54
Training: Ent: 4.97243 || Ppl: 31.39422 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 5.15709 || Ppl: 35.68114
Epoch: 4. Training time: 710.02s. WordsXE/s: 3695.77, WordsRF/s: 54.71
Training: Ent: 4.67853 || Ppl: 25.60814 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 4.93649 || Ppl: 30.62186
Epoch: 5. Training time: 713.61s. WordsXE/s: 3677.21, WordsRF/s: 54.43
Training: Ent: 4.46283 || Ppl: 22.05195 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 4.73206 || Ppl: 26.57623
Epoch: 6. Training time: 711.31s. WordsXE/s: 3689.08, WordsRF/s: 54.61
Training: Ent: 4.29383 || Ppl: 19.61420 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 4.62939 || Ppl: 24.75053
Epoch: 7. Training time: 707.76s. WordsXE/s: 3707.59, WordsRF/s: 54.88
Training: Ent: 4.15686 || Ppl: 17.83768 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 4.62608 || Ppl: 24.69392
Epoch: 8. Training time: 716.06s. WordsXE/s: 3664.60, WordsRF/s: 54.24
Training: Ent: 4.04386 || Ppl: 16.49387 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 4.61798 || Ppl: 24.55557
Epoch: 9. Training time: 712.12s. WordsXE/s: 3684.88, WordsRF/s: 54.54
Training: Ent: 3.95212 || Ppl: 15.47767 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 4.45345 || Ppl: 21.90900
Epoch: 10. Training time: 722.68s. WordsXE/s: 3631.05, WordsRF/s: 53.75
Training: Ent: 3.86949 || Ppl: 14.61611 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 4.43678 || Ppl: 21.65734
Epoch: 11. Training time: 714.16s. WordsXE/s: 3674.35, WordsRF/s: 54.39
Training: Ent: 3.80143 || Ppl: 13.94266 || Avg. reward/token: 0.000 || Cum. reward error: 0.000 || Nsamples Xent: 2624082 || Nsamples Rf: 38842
Validation: Ent: 4.43160 || Ppl: 21.57964
/data/users/v-wenhch/torch/install/bin/luajit: bad argument #2 to '?' (out of range)
stack traceback:
[C]: at 0x7facb696f2f0
[C]: in function '__newindex'
/data/users/v-wenhch/MIXER/ReinforceSampler.lua:54: in function 'updateGradInput'
...v-wenhch/torch/install/share/lua/5.1/nngraph/gmodule.lua:420: in function 'neteval'
...v-wenhch/torch/install/share/lua/5.1/nngraph/gmodule.lua:454: in function 'updateGradInput'
...users/v-wenhch/torch/install/share/lua/5.1/nn/Module.lua:31: in function 'backward'
/data/users/v-wenhch/MIXER/Mixer.lua:289: in function 'train_one_batch'
/data/users/v-wenhch/MIXER/Trainer.lua:67: in function 'train'
/data/users/v-wenhch/MIXER/Trainer.lua:134: in function 'run'
main.lua:176: in main chunk
[C]: in function 'dofile'
...nhch/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:150: in main chunk
[C]: at 0x00405d50

Is this a bug?

I shrank the train/test/valid data to 1000, 100, and 200 lines respectively for debugging;
however, it reports an error at:

local ctr = 0
for i, v in pairs(self.all_targets) do
    local gross_size = v:size(1)
    if gross_size >= self.bin_thresh then
        ctr = ctr + 1
        self.shard_ids[ctr] = i
    end
end
-- create a permutation vector of the shards
if self.dtype == 'train' then
    self.perm_vec = torch.randperm(#self.shard_ids)
else
    self.perm_vec = torch.range(1, #self.shard_ids)  -- !!! fails here: #self.shard_ids is 0 !!!
end
self.curr_shard_num = 0
self.curr_shard_id = -1
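The symptom in the snippet above is that the shard filter leaves shard_ids empty, so torch.range(1, 0) fails. A minimal Python sketch of the same filtering logic (names borrowed from the Lua above; the threshold values are made up for illustration) shows why a tiny debug dataset triggers it:

```python
def filter_shards(shard_sizes, bin_thresh):
    # mimic the loop above: keep only shards with at least bin_thresh targets
    return [i for i, size in enumerate(shard_sizes) if size >= bin_thresh]

# With only ~1000 debug sentences, every shard can fall below the threshold:
small = filter_shards([100, 200, 50], bin_thresh=800)   # -> [] (empty, as in the report)
full  = filter_shards([900, 100, 850], bin_thresh=800)  # -> [0, 2]
```

If the filtered list comes out empty, reducing self.bin_thresh or using a larger dataset should avoid hitting the empty-range error.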

Question about online baseline prediction in MIXER

This issue may be seen as extension of #4.
However, I am making new issue to ask more specific question.

I understand that the average reward term \bar{r}_t comes from the 'online baseline prediction' method introduced in the Reinforcement Learning Neural Turing Machine (RL-NTM) paper. Specifically, this term seems to come from E_{p_\theta(a_{t:T}|a_{1:t-1})}[R(a_{t:T})].

Let me ask some questions related to \bar{r}_t. I am grateful for your time and effort in reading these questions.

Question 1) Why is \bar{r}_t called the 'average reward'? Can you clarify whether \bar{r}_t = b_t (the baseline prediction for time t) or not?

Question 2) How can we be sure that an LSTM RNN can estimate E_{p_\theta(a_{t:T}|a_{1:t-1})}[r(a_{1:T})] by linear regression? Does h_t contain enough information to estimate this term? Especially at earlier time steps, it may be hard for the network to predict the future reward.

Question 3) Why is the affine transformation for the linear regression initialized with weight = 0 and bias = 0.01?

Thank you !
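For readers following this thread: as best we can reconstruct from the paper (and from the presence of LinearNoBackpropInput.lua in the repo), the baseline is a per-step linear regressor on the hidden state, trained with a squared loss to predict the final reward, and its gradient is not propagated back into the RNN:

```latex
\bar{r}_t = W_r h_t + b_r,
\qquad
L_{\text{baseline}} = \sum_t \left( \bar{r}_t - r(w_{1:T}) \right)^2
```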

Question regarding gradient from Reinforce Criterion

First of all, thanks for sharing source code for your awesome work.
I am trying to apply MIXER objective function to my model.

I am asking two questions about output gradient formula from Reinforce Criterion.
(formula (11) in paper: http://arxiv.org/abs/1511.06732)

Question 1) To understand the full derivation of the gradient, the paper recommends "Reinforcement learning neural turing machines". Can you clarify which formulas in the reference paper (http://arxiv.org/abs/1505.00521) correspond to the derivation of the gradient?

Question 2) As far as I understand, 'T' in formula (11) is the length of the sequence generated by the RNN (e.g. the number of tokens the RNN generates until it outputs the end-of-sentence token). When t = T in formula (11), how can we calculate r_(T+1)? Here is my guess: r_(T+1) comes at the time step where the input of the RNN is the end-of-sentence token. Is this right?

I really appreciate your help. Thank you :)
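As general background for this thread (this is the standard REINFORCE estimator with a baseline, not a transcription of the paper's equation (11)): for a sampled sequence with reward r and baseline b, the policy gradient has the form

```latex
\nabla_\theta \, \mathbb{E}_{p_\theta}[r]
  = \mathbb{E}_{p_\theta}\!\left[ \left( r - b \right) \nabla_\theta \log p_\theta(w_{1:T}) \right]
```

Subtracting the baseline b leaves the estimator unbiased while reducing its variance.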

cuth?

What on earth is "cuth"?

I installed cutorch (luarocks install cutorch) just in case, but still...

cuth -i main.lua

No command 'cuth' found, did you mean:
Command 'cut' from package 'coreutils' (main)
cuth: command not found

Is it really this complicated?

I use
graph.dot(net.fg, "net.fg", "net.fg.png")
graph.dot(net.bg, "net.bg", "net.bg.png")
to draw the graph of the network returned by
local net, size_hid_layers = mdls.makeNetSingleStep(config.model, dict_target, dict_source)
in main.lua.

It looks very complicated; is that expected?

[images: net.fg and net.bg graphs]
