alone_seq2seq's Introduction

All Word Embeddings from One Embedding

This repository contains the source files used in our paper:

All Word Embeddings from One Embedding

Sho Takase, Sosuke Kobayashi

Requirements

  • PyTorch version >= 1.4.0
  • Python version >= 3.6
  • For training new models, you'll also need an NVIDIA GPU and NCCL

Machine Translation

We modified the code to control output length, so results may differ slightly from those reported in our paper.

Please use the released version to reproduce the machine translation results in our paper.

Training

1. Download and pre-process the datasets following the description on this page
2. Train the model

For the binary mask with D_{inter} = 8192, using 4 GPUs:

python -u train.py \
    pre-processed-data-dir \
    --arch transformer_wmt_en_de --optimizer adam --adam-betas '(0.9, 0.98)' \
    --clip-norm 1.0 --lr 0.0015 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --warmup-init-lr 1e-07 --dropout 0.2 --weight-decay 0.0 --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 --max-tokens 3584 --min-lr 1e-09 --update-freq 32  --log-interval 100  --max-update 100000 \
    --one-emb binary --one-emb-relu-dropout 0.15 \
    --one-emb-layernum 2 --one-emb-inter-dim 8192  \
    --share-all-embeddings --stop-relu-dropout-update 4500 --save-dir model-save-dir
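
The learning-rate flags above (`--lr`, `--lr-scheduler inverse_sqrt`, `--warmup-updates`, `--warmup-init-lr`) can be read as a schedule over update steps. The sketch below assumes the usual fairseq definition of `inverse_sqrt` (linear warmup to the peak rate, then decay proportional to the inverse square root of the update number); the function name `inverse_sqrt_lr` is hypothetical and is not part of this repository.

```python
import math


def inverse_sqrt_lr(step, peak_lr=0.0015, warmup_updates=4000, warmup_init_lr=1e-7):
    """Learning rate at a given update: linear warmup, then inverse-sqrt decay.

    Defaults mirror the machine translation command above; the formula is an
    assumption based on the standard fairseq inverse_sqrt scheduler.
    """
    if step < warmup_updates:
        # Linear ramp from warmup_init_lr up to peak_lr.
        return warmup_init_lr + step * (peak_lr - warmup_init_lr) / warmup_updates
    # After warmup, decay proportionally to 1 / sqrt(step).
    return peak_lr * math.sqrt(warmup_updates / step)


for step in (0, 2000, 4000, 16000, 100000):
    print(step, f"{inverse_sqrt_lr(step):.6f}")
```

Under these assumptions the rate peaks at update 4000 and has halved by update 16000.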

To use a real-number filter instead of the binary mask, set the following arguments:

    --one-emb real --one-emb-relu-dropout 0.2

Test (decoding)

Average the latest 10 checkpoints:

python scripts/average_checkpoints.py --inputs model-save-dir --num-epoch-checkpoints 10 --output model-save-dir/averaged.pt
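
Conceptually, checkpoint averaging takes the element-wise mean of each parameter across the selected checkpoints. The following is a simplified illustration using plain Python lists in place of tensors; `average_state_dicts` and the parameter names are hypothetical, and the actual script operates on saved PyTorch checkpoints.

```python
def average_state_dicts(state_dicts):
    """Element-wise mean of parameter dictionaries that share the same keys."""
    if not state_dicts:
        raise ValueError("need at least one checkpoint")
    n = len(state_dicts)
    averaged = {}
    for key in state_dicts[0]:
        # Group the i-th value of this parameter from every checkpoint.
        columns = zip(*(sd[key] for sd in state_dicts))
        averaged[key] = [sum(col) / n for col in columns]
    return averaged


# Toy "checkpoints" with made-up parameter names and values.
checkpoints = [
    {"embed.weight": [1.0, 2.0], "out.bias": [0.0]},
    {"embed.weight": [3.0, 4.0], "out.bias": [1.0]},
]
print(average_state_dicts(checkpoints))
# {'embed.weight': [2.0, 3.0], 'out.bias': [0.5]}
```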

Decode with the averaged checkpoint:

python generate.py pre-processed-data-dir --path model-save-dir/averaged.pt  --beam 4 --lenpen 0.6 --remove-bpe | grep '^H' | sed 's/^H\-//g' | sort -t ' ' -k1,1 -n | cut -f 3-
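
The `grep | sed | sort | cut` pipeline above keeps only the hypothesis lines, restores the original sentence order, and drops the id and score columns. A rough Python equivalent is sketched below, assuming hypothesis lines have the tab-separated form `H-<id>`, `<score>`, `<text>`; the helper name `extract_hypotheses` and the sample lines are illustrative only.

```python
def extract_hypotheses(lines):
    """Return hypothesis strings ordered by sentence id."""
    hyps = []
    for line in lines:
        if not line.startswith("H-"):
            continue  # skip source (S-), target (T-), and other log lines
        # Split into id tag, score, and everything after the second tab.
        tag, _score, text = line.rstrip("\n").split("\t", 2)
        hyps.append((int(tag[2:]), text))
    # Sort numerically by sentence id, then keep only the text.
    return [text for _, text in sorted(hyps)]


# Made-up sample output, with sentences in shuffled order.
sample = [
    "S-1\tguten tag",
    "H-1\t-0.3\tgood day",
    "H-0\t-0.5\thello world",
    "P-0\t-0.2 -0.3",
]
print(extract_hypotheses(sample))
# ['hello world', 'good day']
```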

Summarization

In our paper, we used older versions of PyTorch and fairseq.

Please use this code to reproduce the summarization results in our paper.

Training

For the binary mask with D_{inter} = 1024, using 4 GPUs:

python -u train.py \
    pre-processed-data-dir \
    --arch transformer_wmt_en_de --optimizer adam --adam-betas '(0.9, 0.98)' \
    --clip-norm 1.0 --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --warmup-init-lr 1e-07 --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 --max-tokens 3584 --min-lr 1e-09 --update-freq 16 --log-interval 100 --max-epoch 100 \
    --one-emb binary --one-emb-relu-dropout 0.1 \
    --one-emb-layernum 2 --one-emb-inter-dim 1024 \
    --share-all-embeddings --stop-relu-dropout-update 300 \
    --represent-length-by-lrpe --ordinary-sinpos --save-dir model-save-dir
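
The batch-related flags determine how many tokens contribute to each parameter update. Assuming, as in fairseq, that `--max-tokens` is a per-GPU limit and `--update-freq` accumulates gradients over that many batches, an upper bound on tokens per update can be computed as below (the helper `tokens_per_update` is hypothetical; actual batches may be smaller than the limit).

```python
def tokens_per_update(max_tokens, num_gpus, update_freq):
    """Upper bound on tokens per parameter update under the stated assumptions."""
    return max_tokens * num_gpus * update_freq


# Machine translation setting: --max-tokens 3584, 4 GPUs, --update-freq 32
print(tokens_per_update(3584, 4, 32))  # 458752

# Summarization setting: --max-tokens 3584, 4 GPUs, --update-freq 16
print(tokens_per_update(3584, 4, 16))  # 229376
```

When training on a different number of GPUs, scaling `--update-freq` inversely keeps this product unchanged.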

Test (decoding)

Average the latest 10 checkpoints:

python scripts/average_checkpoints.py --inputs model-save-dir --num-epoch-checkpoints 10 --output model-save-dir/averaged.pt

Decode with the averaged checkpoint:

python generate.py pre-processed-data-dir --path model-save-dir/averaged.pt  --beam 5 --desired-length 75

For comparison with the reported scores, use reranking following this procedure.

Acknowledgements

A large portion of this repo is borrowed from fairseq.
