busizshen / njunmt-tf

This project forked from zhaocq-nlp/njunmt-tf

An open-source neural machine translation system developed by Natural Language Processing Group, Nanjing University.

License: Apache License 2.0

NJUNMT-tf

NJUNMT-tf is a general-purpose sequence modeling toolkit built on TensorFlow, with neural machine translation (NMT) as its main target task.

Key features

NJUNMT-tf builds NMT models almost from scratch, without high-level TensorFlow APIs, which often hide the details of network components and lead to an obscure code structure that is difficult to understand and modify. NJUNMT-tf depends only on basic TensorFlow modules such as array_ops, math_ops and nn_ops, so every operation in the code is explicit and under your control.

NJUNMT-tf focuses on modularity and extensibility, using standard TensorFlow modules and practices to support advanced modeling capabilities:

  • arbitrarily complex encoder architectures, e.g. bidirectional RNN, unidirectional RNN and self-attention encoders.
  • arbitrarily complex decoder architectures, e.g. conditional GRU/LSTM, attention and self-attention decoders.
  • hybrid encoder-decoder models, e.g. a self-attention encoder with an RNN decoder, or vice versa.

All of the above can be combined to train novel and complex architectures.
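Self-attention, listed above as both an encoder and a decoder option, is built in NJUNMT-tf from basic TensorFlow ops. The sketch below is not NJUNMT-tf's code: it shows single-head scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V, in plain Python (no masking, no multi-head projections) to make the underlying arithmetic concrete.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Plain matrix product: (n x k) @ (k x m) -> (n x m).
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head softmax(Q K^T / sqrt(d_k)) V without masking."""
    d_k = len(k[0])
    scores = matmul(q, transpose(k))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v)


# Self-attention: queries, keys and values all come from the same sequence.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = scaled_dot_product_attention(x, x, x)
print([[round(val, 3) for val in row] for row in out])
```

Each output row is a convex combination of the value rows, weighted by how strongly that query position matches every key position.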

The code also supports:

  • model ensembling.
  • learning rate decay based on the loss on evaluation data.
  • model validation on evaluation data using BLEU, with an early stopping strategy.
  • monitoring with TensorBoard.
  • BPE (byte-pair encoding) support.
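Loss-based learning rate decay and BLEU-based early stopping can each be sketched as a small piece of state updated after every evaluation. This is a minimal sketch under assumptions: the plateau rule and the parameter names (`decay_factor`, `patience`, `min_lr`) are hypothetical and not taken from NJUNMT-tf's config schema, and BLEU scores are passed in rather than computed.

```python
from dataclasses import dataclass


@dataclass
class PlateauDecay:
    """Hypothetical rule: multiply the learning rate by `decay_factor` when the
    evaluation loss has not improved for `patience` consecutive evaluations."""
    lr: float
    decay_factor: float = 0.5
    patience: int = 2
    min_lr: float = 1e-5
    best_loss: float = float("inf")
    bad_evals: int = 0

    def step(self, eval_loss: float) -> float:
        if eval_loss < self.best_loss:
            self.best_loss = eval_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
            if self.bad_evals >= self.patience:
                self.lr = max(self.lr * self.decay_factor, self.min_lr)
                self.bad_evals = 0  # restart the patience window after decaying
        return self.lr


@dataclass
class EarlyStopping:
    """Hypothetical rule: stop once BLEU has not improved for `patience` evaluations."""
    patience: int = 3
    best_bleu: float = float("-inf")
    bad_evals: int = 0

    def should_stop(self, bleu: float) -> bool:
        if bleu > self.best_bleu:
            self.best_bleu = bleu
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


sched = PlateauDecay(lr=1.0, patience=2)
print([sched.step(loss) for loss in [5.0, 4.0, 4.1, 4.2, 3.9]])
```

With `patience=2`, the two evaluations that fail to beat a loss of 4.0 halve the learning rate, and a later improvement does not raise it again.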

Requirements

  • tensorflow (>=1.4)
  • pyyaml

Quickstart

Here is a minimal workflow to get started with NJUNMT-tf. This example uses a toy Chinese-English machine translation dataset with a toy configuration.

1. Build the word vocabularies:

python -m bin.generate_vocab testdata/toy.zh --max_vocab_size 100  > testdata/vocab.zh
python -m bin.generate_vocab testdata/toy.en0 --max_vocab_size 100  > testdata/vocab.en
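The `--max_vocab_size 100` option caps the vocabulary size. A common way to build such a vocabulary is to count tokens and keep the most frequent ones, one per line. The sketch below assumes whitespace-tokenized input; it is not the implementation of `bin.generate_vocab`, whose exact output format (for example, whether it includes counts or special symbols) is not described in this README.

```python
import sys
from collections import Counter
from typing import Iterable, List, TextIO


def build_vocab(lines: Iterable[str], max_vocab_size: int) -> List[str]:
    """Return up to `max_vocab_size` whitespace-separated tokens, most frequent first."""
    counts: Counter = Counter()
    for line in lines:
        counts.update(line.split())
    # Sort by descending count; break ties alphabetically so the output is deterministic.
    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    return [token for token, _ in ranked[:max_vocab_size]]


def write_vocab(tokens: Iterable[str], out: TextIO) -> None:
    """Write one token per line, as when redirecting the vocabulary to a file."""
    for token in tokens:
        out.write(token + "\n")


# Toy in-memory corpus standing in for a tokenized training file.
corpus = ["the cat sat", "the dog sat"]
write_vocab(build_vocab(corpus, max_vocab_size=100), sys.stdout)
```

To read a real corpus, pass an open file handle instead of the list, since iterating over a text file yields its lines.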

2. Train with preset sequence-to-sequence parameters:

export CUDA_VISIBLE_DEVICES=
python -m bin.train --model_dir test_model \
    --config_paths "
        ./njunmt/example_configs/toy_seq2seq.yml,
        ./njunmt/example_configs/toy_training_options.yml,
        ./default_configs/default_optimizer.yml"

3. Translate a test file with the latest checkpoint:

export CUDA_VISIBLE_DEVICES=
python -m bin.infer --model_dir test_model \
  --infer "
    beam_size: 4
    source_words_vocabulary: testdata/vocab.zh
    target_words_vocabulary: testdata/vocab.en" \
  --infer_data "
    - features_file: testdata/toy.zh
      labels_file: testdata/toy.en
      output_file: toy.trans
      output_attention: false"
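The `beam_size: 4` option above sets the width of beam search, and the `infer` options described under Configuration also include a length penalty rate. The following is a simplified sketch of beam search over a hypothetical next-token scoring function. The GNMT-style penalty ((5 + length) / 6) ** alpha is an assumption; this README does not state which formula NJUNMT-tf uses.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical model interface: given a prefix of token ids, return
# log-probabilities for the next token.
NextLogProbs = Callable[[List[int]], Dict[int, float]]


def length_penalty(length: int, alpha: float) -> float:
    # GNMT-style penalty (assumed, not necessarily NJUNMT-tf's formula).
    return ((5.0 + length) / 6.0) ** alpha


def beam_search(step_fn: NextLogProbs, eos_id: int, beam_size: int = 4,
                max_len: int = 10, alpha: float = 0.6) -> List[int]:
    beams: List[Tuple[List[int], float]] = [([], 0.0)]  # (tokens, summed log-prob)
    finished: List[Tuple[List[int], float]] = []
    for _ in range(max_len):
        candidates = [(tokens + [tok], score + logp)
                      for tokens, score in beams
                      for tok, logp in step_fn(tokens).items()]
        # Keep the best `beam_size` expansions by raw log-probability.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            # Simplification: hypotheses ending in EOS leave the beam, so it may shrink.
            (finished if tokens[-1] == eos_id else beams).append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off at max_len still compete
    # Rank complete hypotheses by length-normalized score.
    best = max(finished, key=lambda c: c[1] / length_penalty(len(c[0]), alpha))
    return best[0]


def toy_model(prefix: List[int]) -> Dict[int, float]:
    # Prefers token 1 early on, then strongly prefers EOS (id 0) after two tokens.
    if len(prefix) >= 2:
        return {0: math.log(0.9), 1: math.log(0.1)}
    return {1: math.log(0.7), 2: math.log(0.3)}


print(beam_search(toy_model, eos_id=0, beam_size=2, max_len=5))
```

Because log-probabilities are negative, dividing by a penalty greater than 1 raises longer hypotheses' scores, which counteracts beam search's bias toward short outputs; `alpha = 0` disables the penalty.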

Note: do not expect any good translation results with this toy example. Consider training on larger parallel datasets instead.

Configuration

There are two ways to set the hyperparameters of a run:

  • TensorFlow command-line FLAGS
  • YAML-style config files

For example, the following config file specifies the datasets used for training:

# datasets.yml
data:
  train_features_file: testdata/toy.zh
  train_labels_file: testdata/toy.en0
  eval_features_file: testdata/toy.zh
  eval_labels_file: testdata/toy.en
  source_words_vocabulary: testdata/vocab.zh
  target_words_vocabulary: testdata/vocab.en

You can either use the command:

python -m bin.train --config_paths "datasets.yml" ...

or

python -m bin.train --data "
    train_features_file: testdata/toy.zh
    train_labels_file: testdata/toy.en0
    eval_features_file: testdata/toy.zh
    eval_labels_file: testdata/toy.en
    source_words_vocabulary: testdata/vocab.zh
    target_words_vocabulary: testdata/vocab.en" ...

Both have the same effect.

The available FLAGS (or the top levels of yaml configs) for bin.train are as follows:

  • config_paths: the paths for config files
  • model_dir: the directory for saving checkpoints
  • train: training options, e.g. batch size, maximum length
  • data: training data, evaluation data, vocabulary and (optional) BPE codes
  • hooks: a list of training hooks (not available in the current version)
  • metrics: a list of validation metrics on evaluation data
  • model: the class name of the model
  • model_params: parameters for the model
  • optimizer_params: parameters for the optimizer

The available FLAGS (or the top levels of yaml configs) for bin.infer are as follows:

  • config_paths: the paths for config files
  • model_dir: the checkpoint directory, or several directories separated by commas for model ensembling
  • infer: inference options, e.g. beam size, length penalty rate
  • infer_data: a list of data files to be translated
  • weight_scheme: the weight scheme for model ensembling (only "average" is available now)
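With the "average" weight scheme, an ensemble combines the predictions of several models at each decoding step. The sketch below assumes uniform averaging of next-token probability distributions; whether NJUNMT-tf averages probabilities or log-probabilities is not stated in this README, and the function name is hypothetical.

```python
from typing import Dict, List


def average_ensemble(distributions: List[Dict[str, float]]) -> Dict[str, float]:
    """Uniformly average next-token probability distributions from several models.

    A token missing from one model's distribution counts as probability 0 for that model.
    """
    if not distributions:
        raise ValueError("need at least one model distribution")
    vocab = set().union(*distributions)
    n = len(distributions)
    return {tok: sum(d.get(tok, 0.0) for d in distributions) / n for tok in vocab}


# Two models disagree; the ensemble picks the token with the highest mean probability.
model_a = {"hello": 0.6, "hi": 0.4}
model_b = {"hello": 0.2, "hi": 0.8}
avg = average_ensemble([model_a, model_b])
print(max(avg, key=avg.get))  # prints "hi"
```

Averaging keeps the result a valid probability distribution, so it can feed directly into the same beam search used for a single model.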

Note that:

  • each FLAG value should be a YAML-formatted string
  • hyperparameters provided by FLAGS override those given in config files
  • invalid parameters will abort the program; see sample.yml for a more detailed description of each parameter.
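The precedence rule above (config files first, then FLAGS on top) can be sketched as a sequence of dictionary merges. This is a minimal sketch, assuming each config file and FLAG string has already been parsed into a dict (for example with pyyaml's `yaml.safe_load`) and that nested keys merge one by one; whether NJUNMT-tf merges nested sections or replaces whole top-level sections is an assumption. The function names are hypothetical.

```python
import copy
from typing import Any, Dict, List, Optional


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `base` with `override` merged in; nested dicts merge key by key."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def resolve_config(config_files: List[Dict[str, Any]],
                   flags: Dict[str, Optional[Dict[str, Any]]]) -> Dict[str, Any]:
    """Apply config files in the order given, then let non-empty FLAGS override them."""
    result: Dict[str, Any] = {}
    for cfg in config_files:
        result = deep_merge(result, cfg)
    for name, value in flags.items():
        if value:  # unset FLAGS leave the file values untouched
            result = deep_merge(result, {name: value})
    return result


def split_config_paths(config_paths: str) -> List[str]:
    """Split a --config_paths string on commas, ignoring whitespace and newlines."""
    return [p.strip() for p in config_paths.split(",") if p.strip()]


datasets_yml = {"data": {"train_features_file": "testdata/toy.zh",
                         "eval_labels_file": "testdata/toy.en"}}
flags = {"data": {"eval_labels_file": "testdata/toy.en0"}, "train": None}
print(resolve_config([datasets_yml], flags))
```

Here the FLAG replaces only `eval_labels_file`, while `train_features_file` from the file survives. The multi-line `--config_paths` string from the Quickstart splits cleanly into three paths with `split_config_paths`.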

TODO

The following features remain unimplemented:

  • multi-GPU training
  • scheduled sampling
  • minimum risk training

In addition, reliable results on open datasets (WMT) have yet to be reported.

Acknowledgments

The implementation is inspired by the following:

Contact

Any comments or suggestions are welcome.

Please email [email protected].
