poly-encoder's Introduction

Bi-Encoder, Poly-Encoder, and Cross-Encoder for Response Selection Tasks

  • This repository is an unofficial re-implementation of Poly-encoders: Transformer Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring.

  • Special thanks to sfzhou5678! Some of the data preprocessing (dataset.py) and training loop code is adapted from his GitHub repo. However, the model architecture and data representation in that repository do not follow the paper exactly, which leads to worse performance. I re-implemented the Bi-Encoder and Poly-Encoder models in encoder.py. The Cross-Encoder model and its data processing pipeline are also implemented.

  • Most of the training code in run.py is adapted from examples in the Hugging Face repository.

  • The most important architectural difference between this implementation and the original paper is that only one BERT encoder is used for both contexts and candidates (instead of two separate ones). Please refer to this issue for details. However, this should not affect the performance much.

  • This repository does not implement every detail of the original paper; for example, the learning rate is not decayed by a factor of 0.4 on plateau. Also, due to limited computing resources, I cannot use the paper's exact parameter settings, such as batch size and context length, and a much smaller BERT model is used. Feel free to tune these settings or use a larger model if you have more computing resources.
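
For readers who want to add the missing plateau decay themselves, here is a minimal sketch in plain Python. It is not part of run.py. The class name `PlateauDecay`, the `patience` parameter, and the starting learning rate are assumptions made for illustration; only the 0.4 factor comes from the description above.

```python
class PlateauDecay:
    """Multiply the learning rate by `factor` when a validation metric stops improving."""

    def __init__(self, lr, factor=0.4, patience=1):
        self.lr = lr
        self.factor = factor      # 0.4, as mentioned for the original paper
        self.patience = patience  # assumption: non-improving evaluations to tolerate
        self.best = float("-inf")
        self.bad_evals = 0

    def step(self, metric):
        """Record one validation metric (higher is better) and return the current lr."""
        if metric > self.best:
            self.best = metric
            self.bad_evals = 0
        else:
            self.bad_evals += 1
            if self.bad_evals >= self.patience:
                self.lr *= self.factor
                self.bad_evals = 0
        return self.lr


# Hypothetical run: the starting lr and the R@1 values are made up for illustration.
scheduler = PlateauDecay(lr=5e-5, factor=0.4, patience=2)
history = [scheduler.step(r1) for r1 in [0.40, 0.43, 0.43, 0.42, 0.44]]
```

If you train with PyTorch, `torch.optim.lr_scheduler.ReduceLROnPlateau` with `factor=0.4` provides the same behaviour without custom code.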

Requirements

  • Please see requirements.txt.

Bert Model Setup

  1. Download a BERT model from Google.

  2. Pick the model you like (I am using uncased_L-4_H-512_A-8.zip) and move it into bert_model/ then unzip it.

  3. cd bert_model/ then bash run.sh

Ubuntu Data

  1. Download and unzip the ubuntu data.

  2. Rename valid.txt to dev.txt for consistency.

DSTC 7 Data

  1. Download the data from the official competition site. Specifically, download the train (ubuntu_train_subtask_1.json), valid (ubuntu_dev_subtask_1.json), and test (ubuntu_responses_subtask_1.tsv, ubuntu_test_subtask_1.json) splits of subtask 1 and put them in the dstc7/ folder.

  2. cd dstc7/ then bash parse.sh

DSTC 7 Augmented Data (from ParlAI)

  1. This dataset setting does not work for the Cross-Encoder. For details, please refer to this issue.

  2. Download the data from the ParlAI website and keep only ubuntu_train_subtask_1_augmented.json.

  3. Move ubuntu_train_subtask_1_augmented.json into dstc7_aug/ then python3 parse.py.

  4. Copy the dev.txt and test.txt files from dstc7/ into dstc7_aug/, since only the training file is augmented.

  5. You can refer to the original post discussing the construction of this augmented data.

Run Experiments (on dstc7)

  1. Train a Bi-Encoder:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture bi
  2. Train a Poly-Encoder with 16 codes:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture poly --poly_m 16
  3. Train a Cross-Encoder:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture cross
  4. To run the same experiments on the Ubuntu dataset, simply change the directory names, replacing dstc7 with ubuntu in --train_dir and --output_dir.
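
As background for the --architecture poly and --poly_m options used above, the following is a minimal sketch of how a Poly-Encoder scores one candidate, as described in the paper. It is not the code in encoder.py: the function names are hypothetical, and small hand-written vectors stand in for BERT outputs and for the learned codes.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys):
    """Dot-product attention: a weighted average of `keys`, which also serve as values."""
    weights = softmax([dot(query, k) for k in keys])
    dim = len(keys[0])
    return [sum(w * k[i] for w, k in zip(weights, keys)) for i in range(dim)]


def poly_encoder_score(context_tokens, codes, candidate):
    # 1) Each of the m codes attends over the context token vectors,
    #    giving m context vectors (m corresponds to --poly_m).
    context_vecs = [attend(code, context_tokens) for code in codes]
    # 2) The candidate embedding attends over those m context vectors.
    final_context = attend(candidate, context_vecs)
    # 3) The score is the dot product of the final context vector and the candidate.
    return dot(final_context, candidate)


# Toy inputs (made up): 3 context tokens, m = 2 codes, 2 candidates, dimension 2.
context_tokens = [[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
codes = [[1.0, 0.0], [0.0, 1.0]]
candidates = [[1.0, 0.0], [0.0, 1.0]]
scores = [poly_encoder_score(context_tokens, codes, c) for c in candidates]
```

A Bi-Encoder skips steps 1 and 2 and compares a single context vector with the candidate, while a Cross-Encoder encodes context and candidate jointly, which is more accurate but much slower at inference time.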

Inference

  1. Test a Bi-Encoder:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture bi --eval
  2. Test a Poly-Encoder with 16 codes:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture poly --poly_m 16 --eval
  3. Test a Cross-Encoder:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture cross --eval

Results

  • All experiments were run on a single GTX 1080 GPU with 8 GB of memory and an i7-6700K CPU @ 4.00GHz.

  • Default parameters in run.py are used; please refer to run.py for details.

  • The results are calculated on a sampled portion (1000 instances) of the dev set.

  • da = data augmentation. We report only one result with data augmentation (dstc7_aug), using 64 poly vectors and bert-base (uncased_L-12_H-768_A-12). This result is very close to the numbers reported in the original paper.

Ubuntu:

Model             R@1    R@2    R@5    R@10   MRR
Bi-Encoder        0.760  0.855  0.971  1.00   0.844
Poly-Encoder 16   0.766  0.868  0.974  1.00   0.851
Poly-Encoder 64   0.767  0.880  0.979  1.00   0.854
Poly-Encoder 360  0.754  0.858  0.970  1.00   0.842

DSTC 7:

Model             R@1    R@2    R@5    R@10   MRR
Bi-Encoder        0.437  0.524  0.644  0.753  0.538
Poly-Encoder 16   0.447  0.534  0.668  0.760  0.550
Poly-Encoder 64   0.438  0.540  0.668  0.755  0.546
Poly-Encoder 360  0.453  0.553  0.665  0.751  0.545
Cross-Encoder     0.502  0.595  0.712  0.790  0.599
da + bert base    0.561  0.659  0.765  0.858  0.659
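
For reference, the R@k and MRR columns above follow the usual definitions, which can be sketched as below. This is an illustrative sketch, not the evaluation code in run.py. The function names and the toy scores are made up, and ties are assumed to be broken in favour of the correct candidate.

```python
def rank_of_correct(scores, correct_index):
    """1-based rank of the correct candidate when candidates are sorted by descending score."""
    target = scores[correct_index]
    # Assumption: only strictly higher scores push the correct candidate down.
    return 1 + sum(1 for s in scores if s > target)


def recall_at_k(ranks, k):
    """Fraction of examples whose correct candidate is ranked within the top k."""
    return sum(1 for r in ranks if r <= k) / len(ranks)


def mean_reciprocal_rank(ranks):
    """Average of 1 / rank over all examples."""
    return sum(1.0 / r for r in ranks) / len(ranks)


# Toy examples: (candidate scores, index of the correct candidate).
examples = [
    ([0.9, 0.1, 0.3], 0),  # correct candidate ranked 1st
    ([0.2, 0.8, 0.5], 2),  # correct candidate ranked 2nd
    ([0.7, 0.6, 0.1], 2),  # correct candidate ranked 3rd
]
ranks = [rank_of_correct(s, i) for s, i in examples]
r_at_1 = recall_at_k(ranks, 1)
r_at_2 = recall_at_k(ranks, 2)
mrr = mean_reciprocal_rank(ranks)
```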
