
AWD-LSTM Language Model

Averaged Stochastic Gradient Descent with Weight Dropped LSTM

This repository contains the code used for Salesforce Research's Regularizing and Optimizing LSTM Language Models paper, originally forked from the PyTorch word level language modeling example. The model comes with instructions to train a word level language model over the Penn Treebank (PTB) and WikiText-2 (WT2) datasets, though the model is likely extensible to many other datasets.

  • Install PyTorch 0.1.12_2
  • Run getdata.sh to acquire the Penn Treebank and WikiText-2 datasets
  • Train the base model using main.py
  • Finetune the model using finetune.py
  • Apply the continuous cache pointer to the finetuned model using pointer.py

If you use this code or our results in your research, please cite:

@article{merityRegOpt,
  title={{Regularizing and Optimizing LSTM Language Models}},
  author={Merity, Stephen and Keskar, Nitish Shirish and Socher, Richard},
  journal={arXiv preprint arXiv:1708.02182},
  year={2017}
}

Software Requirements

This codebase requires Python 3 and PyTorch 0.1.12_2.

Note the older version of PyTorch: upgrading to a later version would require minor updates and would prevent exact reproduction of the results below. Pull requests that update the codebase to later PyTorch versions are welcome, especially if they report baseline numbers too :)

Experiments

The codebase was modified while the paper was being written, so exact reproduction is not possible due to minor differences in random seeds and the like. The guide below produces results largely similar to the reported numbers.

For data setup, run ./getdata.sh. This script collects the Mikolov pre-processed Penn Treebank and the WikiText-2 datasets and places them in the data directory.

Important: If you're going to continue experimentation beyond reproduction, comment out the test code and use the validation metrics until reporting your final results. This is proper experimental practice and is especially important when tuning hyperparameters, such as those used by the pointer.

Penn Treebank (PTB)

The instructions below train a PTB model that achieves perplexities of 61.2 / 58.9 (validation / test) without finetuning, 58.8 / 56.6 with finetuning, and 53.5 / 53.0 with the continuous cache pointer augmentation.

First, train the model:

python main.py --batch_size 20 --data data/penn --dropouti 0.4 --seed 28 --epoch 300 --save PTB.pt

The first epoch should result in a validation perplexity of 308.03.

To then fine-tune that model:

python finetune.py --batch_size 20 --data data/penn --dropouti 0.4 --seed 28 --epoch 300 --save PTB.pt

The validation perplexity after the first epoch should be 60.85.

Note: Fine-tuning modifies the original saved model in PTB.pt; if you wish to keep the original weights, copy the file first.
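For example, one way to keep a pristine copy of the base model before fine-tuning (the backup filename below is arbitrary):

  import shutil
  # Back up the base model before finetune.py overwrites PTB.pt
  shutil.copyfile('PTB.pt', 'PTB.base.pt')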

Finally, to run the pointer:

python pointer.py --data data/penn --save PTB.pt --lambdasm 0.1 --theta 1.0 --window 500 --bptt 5000

Note that the model in the paper was trained for 500 epochs and the batch size was 40, in comparison to 300 and 20 for the model above. The window size for this pointer is chosen to be 500 instead of 2000 as in the paper.

Note: BPTT just changes the length of the sequence pushed onto the GPU but won't impact the final result.
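For intuition, here is a minimal sketch of the continuous cache pointer (Grave et al. 2016) that pointer.py applies; the function and tensor shapes are illustrative, not the repository's API. The cache attends over the hidden states of the last window tokens, sharpened by theta, and mixes the resulting distribution with the model's softmax using lambdasm:

  import torch
  import torch.nn.functional as F

  def cache_pointer_probs(logits, cache_hiddens, cache_targets, hidden,
                          theta=1.0, lambdasm=0.1):
      # logits: (ntokens,) model output at the current step
      # cache_hiddens: (window, nhid) hidden states of the last `window` steps
      # cache_targets: (window,) word id that followed each cached state
      # hidden: (nhid,) current hidden state
      p_vocab = F.softmax(logits, dim=-1)
      # Similarity of the current state to each cached state, scaled by theta
      attn = F.softmax(theta * (cache_hiddens @ hidden), dim=0)
      # Accumulate the attention mass onto the cached words' vocabulary entries
      p_cache = torch.zeros_like(p_vocab).index_add_(0, cache_targets, attn)
      # Linear interpolation controlled by --lambdasm
      return lambdasm * p_cache + (1.0 - lambdasm) * p_vocab

No parameters are trained here; lambdasm, theta, and window are tuned on the validation set, which is why the hyperparameter warning above matters.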

WikiText-2 (WT2)

The instructions below train a WT2 model that achieves perplexities of 69.1 / 66.1 (validation / test) without finetuning, 68.7 / 65.8 with finetuning, and 53.6 / 52.0 (51.95 to be exact) with the continuous cache pointer augmentation.

python main.py --seed 20923 --epochs 750 --data data/wikitext-2 --save WT2.pt

The first epoch should result in a validation perplexity of 629.93.

python -u finetune.py --seed 1111 --epochs 750 --data data/wikitext-2 --save WT2.pt

The validation perplexity after the first epoch should be 69.14.

Note: Fine-tuning modifies the original saved model in WT2.pt; if you wish to keep the original weights, copy the file first.

Finally, run the pointer:

python pointer.py --save WT2.pt --lambdasm 0.1279 --theta 0.662 --window 3785 --bptt 2000 --data data/wikitext-2

Note: BPTT just changes the length of the sequence pushed onto the GPU but won't impact the final result.

Speed

All the augmentations to the LSTM, including our variant of DropConnect (Wan et al. 2013) termed weight dropping, which adds recurrent dropout, remain compatible with NVIDIA's cuDNN LSTM implementation. PyTorch will automatically use the cuDNN backend when run on CUDA with cuDNN installed. This keeps the model fast to train even when convergence may take many hundreds of epochs.
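A minimal sketch of the weight-dropping idea, assuming a recent PyTorch (the repository's own implementation differs in details): the dropout mask is applied once to the raw hidden-to-hidden weight matrix before the forward pass, rather than per timestep inside the recurrence, so the fused cuDNN kernel runs unmodified and the same connections stay dropped across all timesteps.

  import torch.nn as nn
  import torch.nn.functional as F

  class WeightDrop(nn.Module):
      # DropConnect on named weight matrices (e.g. an LSTM's weight_hh_l0),
      # applied once per forward pass so cuDNN's fused kernel still runs.
      def __init__(self, module, weight_names, dropout=0.5):
          super().__init__()
          self.module, self.weight_names, self.dropout = module, weight_names, dropout
          for name in weight_names:
              raw = getattr(module, name)
              del module._parameters[name]
              module.register_parameter(name + '_raw', nn.Parameter(raw.data))

      def forward(self, *args):
          for name in self.weight_names:
              raw = getattr(self.module, name + '_raw')
              # One mask for the whole sequence: a dropped recurrent
              # connection stays dropped at every timestep
              setattr(self.module, name,
                      F.dropout(raw, p=self.dropout, training=self.training))
          return self.module(*args)

For example, WeightDrop(nn.LSTM(400, 1150), ['weight_hh_l0'], dropout=0.5) drops hidden-to-hidden connections while leaving the input-to-hidden weights untouched.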

The default speeds for the model during training on an NVIDIA Quadro GP100:

  • Penn Treebank: approximately 45 seconds per epoch for batch size 40, approximately 65 seconds per epoch with batch size 20
  • WikiText-2: approximately 105 seconds per epoch for batch size 80

Speeds are approximately three times slower on a K80. On a K80, or other cards with less memory, you may wish to enable the cap on the maximum sampled sequence length to prevent out-of-memory (OOM) errors, especially for WikiText-2.
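The cap matters because training samples a variable BPTT length for each batch. A hedged sketch of that sampling, following the scheme in the paper (max_seq_len is an illustrative name for the cap, not necessarily the repository's flag):

  import numpy as np

  def sample_seq_len(bptt=70, max_seq_len=None):
      # Occasionally halve the base length so shorter sequences are also
      # seen, then add Gaussian jitter; cap the result to bound the
      # activation memory used on small GPUs.
      base = bptt if np.random.random() < 0.95 else bptt / 2.
      seq_len = max(5, int(np.random.normal(base, 5)))
      return min(seq_len, max_seq_len) if max_seq_len else seq_len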

If speed is a major issue, SGD converges more quickly than our non-monotonically triggered variant of ASGD (NT-ASGD), though it achieves a worse overall perplexity.
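A minimal sketch of that non-monotonic trigger, as described in the paper: plain SGD runs until the validation loss stops improving relative to its recent history, at which point averaging begins. The function below is illustrative, not the repository's exact code:

  def should_trigger_averaging(val_losses, nonmono=5):
      # Switch from SGD to ASGD once the newest validation loss is no better
      # than the best loss observed more than `nonmono` evaluations ago.
      return (len(val_losses) > nonmono
              and val_losses[-1] > min(val_losses[:-nonmono]))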
