
Selective Knowledge Distillation for Neural Machine Translation

This is the PyTorch implementation of the paper Selective Knowledge Distillation for Neural Machine Translation (ACL 2021).

We carry out our experiments on the standard Transformer with the fairseq toolkit. If you use any source code included in this repo in your work, please cite the following paper:

@article{wang2021selective,
  title={Selective Knowledge Distillation for Neural Machine Translation},
  author={Wang, Fusheng and Yan, Jianhao and Meng, Fandong and Zhou, Jie},
  journal={arXiv preprint arXiv:2105.12967},
  year={2021}
}

Runtime Environment

  • OS: Ubuntu 16.04.1 LTS 64-bit
  • Python version >= 3.6
  • PyTorch version >= 1.4
  • To install fairseq and develop locally (a quick import check follows this list):
    cd fairseq
    pip install --editable ./
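
As a quick sanity check after installation (a minimal sketch; it only verifies the version requirements listed above), you can confirm that PyTorch and the locally installed fairseq are importable:

import torch
import fairseq

print("torch", torch.__version__)      # expect >= 1.4
print("fairseq", fairseq.__version__)  # should resolve to the local checkout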
    

Training

For selective distillation: First, train a teacher model. The training script is the same as standard fairseq training.

Second, train the selective distillation model. The training script is the same as fairseq's, except for the following arguments:

  • add --use-distillation to enable the knowledge distillation method.
  • add --teacher-ckpt-path to specify the path of the teacher model trained in the first step.
  • add --distil-strategy to select the distillation strategy, e.g., batch_level or global_level (a sketch of the two strategies follows this list).
  • add --distil-rate, the hyper-parameter $r$ that controls the fraction of words that receive distillation knowledge; it is 0.5 in this paper.
  • add --difficult-queue-size, the hyper-parameter $Q_{size}$ that controls the size of the global queue. It does not need to be set when using the batch_level strategy. In our method, the most suitable value is 30k for WMT'14 En-De and 50k for WMT'19 Zh-En.
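
To make the two strategies concrete, the following is a minimal, illustrative sketch, not the repository's actual code: names such as batch_level_mask and GlobalLevelSelector are made up here. It assumes word "difficulty" is measured by per-word cross-entropy, that only the top $r$ fraction of difficult words receive the distillation loss, and that global_level draws its threshold from a FIFO queue of the most recent word losses of size $Q_{size}$.

import torch
from collections import deque


def batch_level_mask(word_ce: torch.Tensor, rate: float = 0.5) -> torch.Tensor:
    # Select the top `rate` fraction of words within the current batch,
    # ranked by their per-word cross-entropy (padding already removed).
    k = max(1, int(rate * word_ce.numel()))
    threshold = torch.topk(word_ce, k).values.min()
    return word_ce >= threshold


class GlobalLevelSelector:
    # Keep a FIFO queue of recent word losses (size Q_size) and use its
    # top-r value as the selection threshold for each new batch.
    def __init__(self, queue_size: int = 30000, rate: float = 0.5):
        self.queue = deque(maxlen=queue_size)
        self.rate = rate

    def __call__(self, word_ce: torch.Tensor) -> torch.Tensor:
        self.queue.extend(word_ce.detach().cpu().tolist())
        losses = torch.tensor(list(self.queue))
        k = max(1, int(self.rate * losses.numel()))
        threshold = torch.topk(losses, k).values.min()
        return word_ce >= threshold

In either case, the distillation loss would be applied only to the selected words, on top of the usual label-smoothed cross-entropy over all words.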

For example, here is the script for global-level training on WMT'14 En-De. The script for WMT'19 Zh-En is the same as for WMT'14 En-De.

output_dir=directory_of_output
teacher_ckpt=path_of_teacher_ckpt/teacher.pt
data_dir=directory_of_data_bin
distil_strategy=global_level
distil_rate=0.5
queue_size=30000


export CUDA_VISIBLE_DEVICES=0,1,2,3

fairseq-train $data_dir --arch transformer_wmt_en_de \
    --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --save-dir $output_dir \
    --max-update 300000 --save-interval-updates 5000 \
    --keep-interval-updates 40 \
    --encoder-normalize-before --decoder-normalize-before \
    --lr 7e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --eval-bleu \
    --eval-bleu-args '{"beam": 4, "max_len_a": 1.2, "max_len_b": 10}' \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --eval-bleu-print-samples \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
    --update-freq 2 --no-epoch-checkpoints \
    --use-distillation --teacher-ckpt-path $teacher_ckpt --distil-strategy $distil_strategy --distil-rate $distil_rate \
    --difficult-queue-size $queue_size

Note

  • You need to evaluate every checkpoint separately on the validation set and choose the one that performs best, because the checkpoint_best.pt and the log generated by default may not be reliable. A sketch of such a sweep is given below.
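
The following is a minimal sketch of such a sweep (paths and the BLEU-line parsing are illustrative and may need adjusting to your fairseq version); it evaluates every interval checkpoint on the validation set with fairseq-generate and keeps the best one:

import glob
import re
import subprocess

data_dir = "directory_of_data_bin"   # the same data-bin used for training
save_dir = "directory_of_output"     # --save-dir from the training script

best_bleu, best_ckpt = -1.0, None
for ckpt in sorted(glob.glob(f"{save_dir}/checkpoint_*_*.pt")):
    out = subprocess.run(
        ["fairseq-generate", data_dir,
         "--path", ckpt,
         "--gen-subset", "valid",
         "--beam", "4", "--remove-bpe"],
        stdout=subprocess.PIPE, universal_newlines=True, check=True,
    ).stdout
    # fairseq-generate typically ends with a line like
    # "Generate valid with beam=4: BLEU4 = 27.31, ..."
    match = re.search(r"BLEU4? = ([\d.]+)", out)
    if match and float(match.group(1)) > best_bleu:
        best_bleu, best_ckpt = float(match.group(1)), ckpt

print("best checkpoint:", best_ckpt, "valid BLEU:", best_bleu)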
