
poly-encoder's Introduction

Bi-Encoder, Poly-Encoder, and Cross-Encoder for Response Selection Tasks

  • This repository is an unofficial re-implementation of Poly-encoders: Transformer Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring.

  • Special thanks to sfzhou5678! Some of the data preprocessing (dataset.py) and training loop code is adapted from his GitHub repo. However, the model architecture and data representation in that repository do not follow the paper exactly, which leads to worse performance. I re-implemented the Bi-Encoder and Poly-Encoder models in encoder.py. In addition, the model and data processing pipeline for the Cross-Encoder are also implemented.

  • Most of the training code in run.py is adapted from examples in the Hugging Face Transformers repository.

  • The most important architectural difference between this implementation and the original paper is that only one BERT encoder is used (instead of two separate ones); see the sketch after this list. Please refer to this issue for details. However, this should not affect performance much.

  • This repository does not implement every detail of the original paper, for example, decaying the learning rate by a factor of 0.4 on plateau. Also, due to limited computing resources, I cannot use the exact parameter settings (such as batch size or context length) from the original paper, and a much smaller BERT model is used. Feel free to tune them or use larger models if you have more computing resources.
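
For illustration only, here is a minimal sketch of the shared-encoder idea mentioned above: a single BERT instance encodes both the context and the candidate responses. This is not the exact code in encoder.py, and the class and method names are made up:

    import torch
    from transformers import BertModel

    class SharedBiEncoder(torch.nn.Module):
        def __init__(self, bert_name_or_path):
            super().__init__()
            # one BERT encoder shared by the context side and the response side
            self.bert = BertModel.from_pretrained(bert_name_or_path)

        def score(self, context_ids, context_mask, response_ids, response_mask):
            ctx_vec = self.bert(context_ids, attention_mask=context_mask)[0][:, 0, :]     # [bs, dim]
            resp_vec = self.bert(response_ids, attention_mask=response_mask)[0][:, 0, :]  # [bs, dim]
            return torch.matmul(ctx_vec, resp_vec.t())  # [bs, bs] in-batch similarity scores

Training with in-batch negatives then amounts to a cross-entropy loss over each row of this score matrix, which is the standard bi-encoder training setup.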

Requirements

  • Please see requirements.txt.

BERT Model Setup

  1. Download a BERT model from Google.

  2. Pick the model you like (I am using uncased_L-4_H-512_A-8.zip), move it into bert_model/, and unzip it.

  3. cd bert_model/, then bash run.sh to convert the checkpoint (a sketch of how the converted weights are later loaded follows below).
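
The exact loading code lives in run.py; as a rough sketch (the file names inside bert_model/ are assumptions about what run.sh produces), the converted checkpoint can be loaded with Hugging Face transformers like this, mirroring the pattern quoted in the last issue below:

    import torch
    from transformers import BertModel, BertTokenizer

    bert_dir = "bert_model/"
    # converted PyTorch weights produced by run.sh (file name assumed)
    model_state_dict = torch.load(bert_dir + "pytorch_model.bin", map_location="cpu")
    tokenizer = BertTokenizer.from_pretrained(bert_dir)  # uses the vocab.txt shipped with the checkpoint
    bert = BertModel.from_pretrained(bert_dir, state_dict=model_state_dict)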

Ubuntu Data

  1. Download and unzip the Ubuntu data.

  2. Rename valid.txt to dev.txt for consistency.

DSTC 7 Data

  1. Download the data from the official competition site. Specifically, download the train (ubuntu_train_subtask_1.json), valid (ubuntu_dev_subtask_1.json), and test (ubuntu_responses_subtask_1.tsv, ubuntu_test_subtask_1.json) splits of subtask 1 and put them in the dstc7/ folder.

  2. cd dstc7/, then bash parse.sh (a rough sketch of what this parsing step produces is shown below).
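
For reference, a rough sketch of the parsing step. The real logic lives in the scripts called by parse.sh; the JSON field names below follow the official DSTC 7 subtask 1 schema, and the "label<TAB>context<TAB>response" output format is an assumption based on how the tab-separated data files are read:

    import json

    def parse_split(json_path, out_path):
        with open(json_path) as f:
            dialogs = json.load(f)
        with open(out_path, "w") as out:
            for dialog in dialogs:
                # concatenate the dialogue history into a single context string
                context = " ".join(m["utterance"] for m in dialog["messages-so-far"])
                correct = {o["candidate-id"] for o in dialog.get("options-for-correct-answers", [])}
                for option in dialog["options-for-next"]:
                    label = 1 if option["candidate-id"] in correct else 0
                    out.write("{}\t{}\t{}\n".format(label, context, option["utterance"]))

    # e.g. parse_split("ubuntu_dev_subtask_1.json", "dev.txt")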

DSTC 7 Augmented Data (from ParlAI)

  1. This dataset setting does not work for the Cross-Encoder. For details, please refer to this issue.

  2. Download the data from the ParlAI website and keep only ubuntu_train_subtask_1_augmented.json.

  3. Move ubuntu_train_subtask_1_augmented.json into dstc7_aug/, then run python3 parse.py.

  4. Copy the dev.txt and test.txt files from dstc7/ into dstc7_aug/, since only the training file is augmented.

  5. You can refer to the original post discussing the construction of this augmented data.

Run Experiments (on dstc7)

  1. Train a Bi-Encoder:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture bi
  2. Train a Poly-Encoder with 16 codes:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture poly --poly_m 16
  3. Train a Cross-Encoder:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture cross
  4. Simply change the directory names to ubuntu to run experiments on the Ubuntu dataset.

Inference

  1. Test the Bi-Encoder:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture bi --eval
  2. Test the Poly-Encoder with 16 codes:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture poly --poly_m 16 --eval
  3. Test the Cross-Encoder:

    python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture cross --eval

Results

  • All experiments were run on a single GTX 1080 GPU with 8 GB of memory and an i7-6700K CPU @ 4.00GHz.

  • The default parameters in run.py are used; please refer to run.py for details.

  • The results are calculated on a sampled portion (1000 instances) of the dev set.

  • da = data augmentation. We report only one result with 64 poly vectors and bert-base (uncased_L-12_H-768_A-12) with data augmentation (dstc7_aug). This result is very close to the numbers reported in the original paper.

Ubuntu:

Model R@1 R@2 R@5 R@10 MRR
Bi-Encoder 0.760 0.855 0.971 1.00 0.844
Poly-Encoder 16 0.766 0.868 0.974 1.00 0.851
Poly-Encoder 64 0.767 0.880 0.979 1.00 0.854
Poly-Encoder 360 0.754 0.858 0.970 1.00 0.842

DSTC 7:

Model R@1 R@2 R@5 R@10 MRR
Bi-Encoder 0.437 0.524 0.644 0.753 0.538
Poly-Encoder 16 0.447 0.534 0.668 0.760 0.550
Poly-Encoder 64 0.438 0.540 0.668 0.755 0.546
Poly-Encoder 360 0.453 0.553 0.665 0.751 0.545
Cross-Encoder 0.502 0.595 0.712 0.790 0.599
da + bert base 0.561 0.659 0.765 0.858 0.659


poly-encoder's People

Contributors

chijames, kaisugi, lydhr


poly-encoder's Issues

Why is the vector at the first token position used here?

Poly-Encoder/encoder.py

Lines 20 to 27 in 6f0d9c4

context_vec = self.bert(context_input_ids, context_input_masks)[0][:,0,:] # [bs,dim]
batch_size, res_cnt, seq_length = responses_input_ids.shape
responses_input_ids = responses_input_ids.view(-1, seq_length)
responses_input_masks = responses_input_masks.view(-1, seq_length)
responses_vec = self.bert(responses_input_ids, responses_input_masks)[0][:,0,:] # [bs,dim]
responses_vec = responses_vec.view(batch_size, res_cnt, -1)

After a sentence passes through BERT we get a representation of the whole sentence; why is the vector at the first token position taken here: [0][:,0,:]?
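
For context, [0][:, 0, :] selects the hidden state of the first token, which for a BERT tokenizer is the [CLS] token; a quick check (illustration only, not from the repository):

    from transformers import BertTokenizer

    tok = BertTokenizer.from_pretrained("bert-base-uncased")
    ids = tok.encode("how do i install this package ?")  # special tokens are added by default
    print(tok.convert_ids_to_tokens(ids)[0])             # -> '[CLS]'

The [CLS] hidden state is the conventional choice for a single sentence-level vector in BERT-style models.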

About the implementation of Poly Encoder

Hi @chijames, thanks so much for this wonderful project!
After digging into the code, I have two questions:

  • Is there any special reason why masking is not implemented in this section?

    Poly-Encoder/encoder.py

    Lines 72 to 78 in e5299e3

    def dot_attention(self, q, k, v):
    # q: [bs, poly_m, dim] or [bs, res_cnt, dim]
    # k=v: [bs, length, dim] or [bs, poly_m, dim]
    attn_weights = torch.matmul(q, k.transpose(2, 1)) # [bs, poly_m, length]
    attn_weights = F.softmax(attn_weights, -1)
    output = torch.matmul(attn_weights, v) # [bs, poly_m, dim]
    return output

  • Can we speed up the construction of poly_code_embeddings by using nn.Parameter? That way, we would not need to create poly_ids and move it to the GPU in every batch. (A sketch of both ideas is given after this issue.)

Thanks for your reply!
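
For reference, a minimal sketch of the two ideas raised above. This is not the code in encoder.py: the existing dot_attention applies no mask, and the poly codes are looked up from poly_code_embeddings via a poly_ids tensor.

    import torch
    import torch.nn.functional as F

    def masked_dot_attention(q, k, v, mask=None):
        # q: [bs, poly_m, dim] or [bs, res_cnt, dim]; k = v: [bs, length, dim]
        # mask: [bs, length] with 1 for real tokens and 0 for padding
        attn_weights = torch.matmul(q, k.transpose(2, 1))  # [bs, poly_m, length]
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask.unsqueeze(1) == 0, float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1)
        return torch.matmul(attn_weights, v)                # [bs, poly_m, dim]

    # nn.Parameter-style poly codes (hypothetical sizes), so no poly_ids tensor has to be
    # created and moved to the GPU every batch:
    poly_m, dim = 16, 512
    poly_codes = torch.nn.Parameter(torch.empty(poly_m, dim).normal_(mean=0.0, std=0.02))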

About Performance

Hi! Thanks a lot for sharing your code!
I have a question about the performance.

You report the performance of your code on DSTC 7 with the Bi-Encoder in the results table above. However, the original paper reports higher numbers for the Bi-Encoder on DSTC 7: with your code I get R@1 = 0.437, while the original paper reports 0.565 on the dev set and 0.668 on the test set. I read your code carefully but found little difference from the setting in the original paper. I also changed your default single BERT into two separate BERTs for the Bi-Encoder, but still cannot match the numbers reported in the original paper. Why?

Hyperparameters when training Cross-Encoder.

Hi! I'm using your code and want to reproduce your results on the DSTC 7 dataset.
When training the Cross-Encoder, I use BERT-small (uncased_L-4_H-512_A-8.zip) and leave all hyperparameters unchanged as in run.py (batch size = 32, max context length = 128, max response length = 32). However, I ran into an OOM error on my Tesla M40 GPU, which has 11 GB of memory.
I wonder how you were able to train the Cross-Encoder on your GPU. I guess the default hyperparameters in run.py are designed for training the Bi-Encoder and Poly-Encoder. Could you please share the hyperparameters you used when training the Cross-Encoder?

Something wrong when calculating t_total

First of all, I really appreciate the nice repo.

The t_total in run.py is calculated by t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs, and t_total is passed into transformers.get_linear_schedule_with_warmup. This indicates the total number of steps of the training process.

However, I guess the total number of steps should be calculated as the number of batches * epochs. Therefore, the code for calculating t_total should be t_total = len(train_dataloader) // (args.train_batch_size * args.gradient_accumulation_steps) * args.num_train_epochs.

If I'm wrong, please let me know what I am missing.
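
For reference, a quick check (illustration only, not from the repository) of what len() returns for a standard PyTorch DataLoader:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.zeros(100, 3))
    loader = DataLoader(dataset, batch_size=32)
    print(len(dataset), len(loader))  # 100 examples, 4 batches per epoch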

Code understanding

Hi, the work is really great.

I am just trying to understand: if labels are None, then the encoders output matrices instead of scalars, but you have not made any provision for this in your code.

Also, what is neg in the Cross-Encoder? Could you please provide some context on the variables you use via comments?

Also, why do the Bi-Encoder and Poly-Encoder models use responses as a 3-dimensional tensor?

A bug in parse.py

I have noticed that in parse.py, the candidate response is concatenated to the context with '\t'. This leads to a mistake when reading the record for training. Consider the case where the candidate response is "", which actually occurs in the DSTC 7 dataset: when splitting this record by '\t' to extract the response, the last utterance of the context is chosen instead.

(This is my first time submitting an issue; I hope I have depicted the bug clearly.)
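
A small illustration of how a trailing empty field can get lost (hypothetical line, not real data; whether this is exactly what happens depends on how the record is read back):

    line = "0\tlast utterance in context\t"  # label, context, and an empty candidate response
    print(line.split("\t"))                  # ['0', 'last utterance in context', ''] - empty response kept
    print(line.strip().split("\t"))          # ['0', 'last utterance in context'] - empty response lost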

Config file missing

There doesn't seem to be the config file you used to run this code; I'm just curious what some of the values you are using are, specifically the hidden size referenced in the Poly-Encoder section that is used to calculate your m.

Licensing

Dear @chijames ,

I came across your Poly-Encoder and would like to adapt it for some work purposes. I was told that I can't use it unless it is open-source licensed. I was wondering if you are willing to allow for that, perhaps through an MIT license etc.?

https://choosealicense.com/licenses/mit/

Hope to hear from you and thank you very much!

Best Regards,
Chor Seng

Are the transformers of bi-encoder trained separately?

(To be honest, I'm not used to "deep learning coding" (PyTorch, Huggingface, etc...), so this might be a silly question. Keep in mind I'm a beginner.)

The original paper says that the context encoder and the candidate encoder are trained separately.

[two screenshots of the relevant passages from the paper]

However, I found in your code that both transformers are invoked through the same self.bert().

https://github.com/chijames/Poly-Encoder/blob/master/encoder.py#L20-L27


Is this OK? I doubt that these two encoders will end up with different weights after training.

FYI: in the official implementation of the BLINK (https://arxiv.org/pdf/1911.03814.pdf) paper, they prepare two separate encoders: https://github.com/facebookresearch/BLINK/blob/master/blink/biencoder/biencoder.py#L37-L48
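
For illustration, a minimal sketch of the two-tower alternative described in the question (two encoders with their own weights, as in the linked BLINK code). This is not how encoder.py is written, and the names are made up:

    import torch
    from transformers import BertModel

    class TwoTowerBiEncoder(torch.nn.Module):
        def __init__(self, bert_name_or_path):
            super().__init__()
            # two separate BERT instances, so the context and candidate encoders
            # can develop different weights during training
            self.context_bert = BertModel.from_pretrained(bert_name_or_path)
            self.response_bert = BertModel.from_pretrained(bert_name_or_path)

        def forward(self, ctx_ids, ctx_mask, resp_ids, resp_mask):
            ctx_vec = self.context_bert(ctx_ids, attention_mask=ctx_mask)[0][:, 0, :]
            resp_vec = self.response_bert(resp_ids, attention_mask=resp_mask)[0][:, 0, :]
            return ctx_vec, resp_vec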

Why not directly use the Hugging Face BERT pretrained weights?

Why do you convert the Google BERT weights instead of directly using the BERT weights from Hugging Face? Is there any performance difference between the two?

# converted weight from google-bert
bert = BertModelClass.from_pretrained(args.bert_model, state_dict=model_state_dict) 

# huggingface weight
bert = BertModelClass.from_pretrained('bert-base-uncased') 
