siat-nlp / galaxy

Official repository of the AAAI'2022 paper "GALAXY: A Generative Pre-trained Model for Task-Oriented Dialog with Semi-Supervised Learning and Explicit Policy Injection"

License: Apache License 2.0

Languages: Python 95.16%, Shell 4.84%
Topics: pre-trained-model, dialogue-generation, task-oriented-dialogue, semi-supervised

GALAXY

This repository contains code and data for the AAAI'2022 paper "GALAXY: A Generative Pre-trained Model for Task-Oriented Dialog with Semi-Supervised Learning and Explicit Policy Injection".

Full version with Appendix: [PDF]

Abstract

Pre-trained models have proved to be powerful in enhancing task-oriented dialog systems. However, current pre-training methods mainly focus on enhancing dialog understanding and generation tasks while neglecting the exploitation of dialog policy. In this paper, we propose GALAXY, a novel pre-trained dialog model that explicitly learns dialog policy from limited labeled dialogs and large-scale unlabeled dialog corpora via semi-supervised learning. Specifically, we introduce a dialog act prediction task for policy optimization during pre-training and employ a consistency regularization term to refine the learned representation with the help of unlabeled dialogs. We also implement a gating mechanism to weigh suitable unlabeled dialog samples. Empirical results show that GALAXY substantially improves the performance of task-oriented dialog systems, and achieves new state-of-the-art results on benchmark datasets: In-Car, MultiWOZ2.0 and MultiWOZ2.1, improving their end-to-end combined scores by 2.5, 5.3 and 5.5 points, respectively. We also show that GALAXY has a stronger few-shot ability than existing models under various low-resource settings.

Main Results

GALAXY performs end-to-end dialog modeling and achieves new state-of-the-art results on four TOD benchmark datasets: MultiWOZ2.0, MultiWOZ2.1, In-Car Assistant and CamRest.

End-to-End Modeling   Inform   Success   BLEU    Combined Score
MultiWOZ2.0           94.40    85.30     20.50   110.35
MultiWOZ2.1           95.30    86.20     20.01   110.76

End-to-End Modeling   Match    SuccF1    BLEU    Combined Score
In-Car Assistant      85.26    83.60     23.03   107.46
CamRest               98.50    87.73     24.15   117.26
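
The README does not define the Combined Score column. The rows above are consistent with the usual MultiWOZ convention, Combined = (Inform + Success) × 0.5 + BLEU, and, assuming Match and SuccF1 take the place of Inform and Success, with the In-Car Assistant and CamRest rows as well. A minimal sketch that recomputes the column from the other three (the function name is illustrative, not part of the repository):

```python
def combined_score(inform: float, success: float, bleu: float) -> float:
    """Combined score under the assumed convention (Inform + Success) / 2 + BLEU."""
    return (inform + success) / 2 + bleu


# (Inform or Match, Success or SuccF1, BLEU) copied from the tables above.
rows = {
    "MultiWOZ2.0": (94.40, 85.30, 20.50),
    "MultiWOZ2.1": (95.30, 86.20, 20.01),
    "In-Car Assistant": (85.26, 83.60, 23.03),
    "CamRest": (98.50, 87.73, 24.15),
}

for name, (first, second, bleu) in rows.items():
    print(f"{name}: {combined_score(first, second, bleu):.3f}")
```

The recomputed values agree with the reported column up to rounding in the last digit.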

‼️ New SOTA results on MultiWOZ (End-to-End Modeling & Policy Optimization), evaluated with the standardized scoring scripts that are officially recommended for fair evaluation. We have also added the new results to the official leaderboard and the new predictions to this repository.

MultiWOZ              Inform   Success   BLEU    Combined Score
End-to-End Modeling   85.40    75.70     19.64   100.2
Policy Optimization   92.80    83.50     19.92   108.1

Requirements

- torch == 1.8.0+cu111
- scikit-learn == 0.23.1
- numpy == 1.18.5
- nltk == 3.5
- spacy == 2.3.5
- scipy == 1.5.0
- regex == 2020.6.8
- tqdm == 4.60.0

We use SpaCy for tokenization. Install the Python packages and the SpaCy English model with the commands pip install -r requirements.txt and python -m spacy download en_core_web_sm.
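
Because the versions above are pinned exactly, it can help to compare them with what is actually installed before training. The following is a small sketch using the standard-library importlib.metadata module (Python 3.8+). The helper names are illustrative, and a reported mismatch is only a hint, since newer versions may or may not work with this code.

```python
from importlib import metadata

# Pins copied from the Requirements list above.
PINS = {
    "torch": "1.8.0+cu111",
    "scikit-learn": "0.23.1",
    "numpy": "1.18.5",
    "nltk": "3.5",
    "spacy": "2.3.5",
    "scipy": "1.5.0",
    "regex": "2020.6.8",
    "tqdm": "4.60.0",
}


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Map each package whose installed version differs to (wanted, found)."""
    mismatches = {}
    for package, wanted in pins.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, found) in find_mismatches(PINS).items():
        print(f"{package}: expected {wanted}, found {found or 'not installed'}")
```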

Preparation

Path Definition

Define your own paths <YOUR_PROJECT_PATH> and <YOUR_SAVE_PATH> in scripts as follows:

PROJECT_NAME="GALAXY"  # project name (fixed)
PROJECT_ROOT=<YOUR_PROJECT_PATH>/${PROJECT_NAME}  # root path of this project
SAVE_ROOT=<YOUR_SAVE_PATH>/${PROJECT_NAME}  # root path of model's output

Data Preparation

Download data from this link.

The downloaded zip file data.zip contains pre-training corpora and four TOD benchmark datasets: MultiWOZ2.0, MultiWOZ2.1, In-Car Assistant and CamRest, which have already been processed. You need to put the unzipped directory data into the project directory GALAXY for the subsequent training.

Pre-training

Pre-training Corpora

  • UniDA: a new labeled dialog dataset consisting of 975,780 utterances, annotated with 20 frequently used dialog acts (DAs) according to our proposed comprehensive unified DA taxonomy for task-oriented dialog.
  • UnDial: a large-scale unlabeled dialog dataset consisting of 35M utterances with careful processing, ranging from online forum chatting logs to customer service conversations.

Pre-trained Checkpoint

  • GALAXY: an uncased model with DA classification head (12-layers, 768-hidden, 12-heads, 109M parameters)

Unzip the downloaded model file model.zip, then put the unzipped directory model into the project directory GALAXY for further fine-tuning.

Training

We pre-train GALAXY on limited labeled dialogs (UniDA) and large-scale unlabeled dialog corpora (UnDial) via semi-supervised learning. You can pre-train GALAXY from scratch by running the following scripts:

# Step 1: Preprocess pre-training corpora
sh scripts/pre_train/preprocess.sh

# Step 2.1: Multi-GPU training on one machine
sh scripts/pre_train/train_single.sh

# Step 2.2: Multi-GPU training across multiple machines (distributed training)
sh scripts/pre_train/train_multi.sh

NOTE: For multi-GPU training, choose either Step 2.1 or Step 2.2, not both. Step 2.2 requires a GPU cluster that can support distributed training.

Fine-tuning

Fine-tuned Checkpoints

Download checkpoints from this link.

The downloaded zip file outputs.zip contains our best fine-tuned checkpoints on different datasets:

  • the 7-th epoch on MultiWOZ2.0 (60 training epochs in total)
  • the 5-th epoch on MultiWOZ2.1 (60 training epochs in total)
  • the 89-th epoch on In-Car Assistant (100 training epochs in total)
  • the 18-th epoch on CamRest (60 training epochs in total)

To reproduce our reported results, put the unzipped directory outputs into the directory ${SAVE_ROOT} (set in the scripts). You can then run the inference scripts for each dataset, described in the Inference section below.

Training

We fine-tune GALAXY on the four TOD datasets and focus on the end-to-end dialog modeling (E2E) task. You can run the fine-tuning yourself with the following training scripts:

# Training on MultiWOZ2.0 (8 GPUs)
sh scripts/multiwoz2.0/train.sh

# Training on MultiWOZ2.1 (8 GPUs)
sh scripts/multiwoz2.1/train.sh

# Training on In-Car Assistant (1 GPU)
sh scripts/kvret/train.sh

# Training on CamRest (1 GPU)
sh scripts/camrest/train.sh

NOTE: For MultiWOZ2.0 and MultiWOZ2.1, we keep the DA prediction task during fine-tuning to reduce the discrepancy between pre-training and fine-tuning. We discard this task on In-Car Assistant and CamRest because these two datasets lack useful DAs. We support both multi-GPU and single-GPU training; for single-GPU training, tune the hyper-parameters ${BATCH_SIZE} and ${GRADIENT_ACCUMULATION_STEPS} jointly so that the effective batch size matches the one originally set in the scripts.
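
The joint tuning mentioned in the note is plain arithmetic: the effective batch size per optimizer update is the per-step batch size times the number of GPUs times the number of accumulation steps. The sketch below assumes that ${BATCH_SIZE} is a per-GPU value, which should be checked against the scripts; the function name and the example numbers are illustrative and are not the values used in the repository.

```python
def accumulation_steps(target_batch, per_step_batch, num_gpus=1):
    """Accumulation steps needed so that
    per_step_batch * num_gpus * steps == target_batch."""
    per_update = per_step_batch * num_gpus
    if per_update <= 0 or target_batch % per_update != 0:
        raise ValueError(
            f"target batch {target_batch} is not a multiple of {per_update}"
        )
    return target_batch // per_update


# Hypothetical example: a script tuned for 8 GPUs with 4 samples per GPU
# has an effective batch of 32. On one GPU with the same per-step batch,
# 8 accumulation steps restore that effective batch size.
print(accumulation_steps(32, 4, num_gpus=8))  # 1
print(accumulation_steps(32, 4, num_gpus=1))  # 8
```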

Inference

Once you have fine-tuned checkpoints (either ours or ones you fine-tuned yourself), you can run inference on the test sets of these datasets with the following scripts:

# Inference on MultiWOZ2.0 (1 GPU)
sh scripts/multiwoz2.0/infer.sh

# Inference on MultiWOZ2.1 (1 GPU)
sh scripts/multiwoz2.1/infer.sh

# Inference on In-Car Assistant (1 GPU)
sh scripts/kvret/infer.sh

# Inference on CamRest (1 GPU)
sh scripts/camrest/infer.sh

NOTE: For reproduction, the best hyper-parameters are already set in the corresponding scripts, so you can run them as they are. If you fine-tune GALAXY yourself, checkpoints from the 4th to the 7th epoch (out of 60) are likely to give the best inference performance on MultiWOZ2.0/2.1.

References

  • For the implementation of the UniLM architecture, we refer to the code of Pytorch-PLATO, a PyTorch implementation of the PLATO model.
  • For data preparation and evaluation on MultiWOZ2.0/2.1, we refer to the code of UBAR.
  • For data preparation and evaluation on In-Car Assistant/CamRest, we refer to the code of LABES.

Citation

If you use our code or find GALAXY useful in your work, please cite our paper as:

@article{he2022galaxy,
  title={GALAXY: A Generative Pre-trained Model for Task-Oriented Dialog with Semi-Supervised Learning and Explicit Policy Injection},
  author={He, Wanwei and Dai, Yinpei and Zheng, Yinhe and Wu, Yuchuan and Cao, Zheng and Liu, Dermot and Jiang, Peng and Yang, Min and Huang, Fei and Si, Luo and others},
  journal={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2022}
}

Contact

For personal communication related to GALAXY, please contact Wanwei He ([email protected]).

galaxy's Issues

Questions about the labeled `UniDA` with development/test sets.

In Table 1 and the README, I found that among the 975,780 utterances, some datasets such as MultiWOZ and SimJoint also contribute their development and test sets to pre-training.

But the paper then evaluates on the MultiWOZ test set. Would this make the evaluation unfair, given that the model has already seen part of the labeled MultiWOZ data?

MultiWOZ 2.2 implementation (data processing)

Hi,
Thank you for releasing the code.

I want to run GALAXY on MultiWOZ 2.2, but there's no code for generating data_for_galaxy.json and generating data_for_galaxy_encoded.data.json.
Could you release the code for creating these files for MultiWOZ 2.2?

Thank you.

Reproducing GALAXY

Thank you for releasing the code to the public!
I am trying to reproduce the pre-trained checkpoint you shared on GitHub, but for some reason I could not obtain the same checkpoint. Several questions came to mind:
Q1: Stopping criteria for choosing pre-training and fine-tuning checkpoints. It seems to me that the stopping criterion is not based on validation loss. What was the criterion for choosing the final epoch number? For example, you mention epoch 14 for pre-training and epoch 7 for MultiWOZ2.0. I wonder how you arrived at these numbers.
Q2: The number of pre-training data. The UniDA dataset you shared on Github has 463,039, but this seems smaller than the sum of the training sets in eight datasets used for UniDA (according to the paper). Did you get the same checkpoint with the data you currently uploaded?
Q3: GPU machines used for pre-training. It would be great if you could share what GPU machines you used to pre-train the GALAXY checkpoint. I am guessing that might be one of the reasons why I do not get the same result. Thanks!

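When generating responses, should the input use the golden previous system responses or the model-generated system responses?

In the test phase, I noticed that in the code, the context part of the input used to generate the system response seems to consist of the system responses previously generated by the model itself and the user utterances from the dataset.

However, when I looked at the MultiWOZ benchmark and read the code of some earlier models listed there, their input at test time seems to always use the golden system responses rather than the model's own predictions. Obviously, using the model's own generated utterances as context would lower the model's performance.

I had always assumed that evaluating the system response in E2E TOD should use the model's own predictions as context, but now I find that both approaches seem to exist, which confuses me a lot. It may also be that I have not read the code thoroughly enough. I hope you can clarify. Thank you very much!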
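
The two evaluation settings contrasted in this question can be made concrete with a small sketch. The helper below is hypothetical and is not the repository's code; its flag only echoes the USE_TRUE_PREV_RESP variable that appears in the infer.sh script quoted elsewhere on this page. It builds the dialog history for the current turn either from golden system responses or from the model's own earlier outputs.

```python
def build_context(turns, generated_responses, use_true_prev_resp):
    """Build the history fed to the model for the last turn in `turns`.

    turns: list of dicts with 'user' and 'resp' (golden system response).
    generated_responses: model outputs for all turns before the last one.
    use_true_prev_resp: True  -> golden responses fill the history,
                        False -> the model's own outputs fill it.
    """
    context = []
    last = len(turns) - 1
    for i, turn in enumerate(turns):
        context.append(("user", turn["user"]))
        if i < last:
            resp = turn["resp"] if use_true_prev_resp else generated_responses[i]
            context.append(("system", resp))
    return context


turns = [
    {"user": "i need a cheap hotel", "resp": "golden reply 1"},
    {"user": "book it for 2 nights", "resp": "golden reply 2"},
]
print(build_context(turns, ["model reply 1"], use_true_prev_resp=False))
print(build_context(turns, ["model reply 1"], use_true_prev_resp=True))
```

With the flag off, errors in earlier generated responses propagate into later turns, which is why the two settings are not directly comparable.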

FileNotFoundError

sh scripts/multiwoz2.0/infer.sh

params as below:

#!/bin/bash
set -ux

# CUDA environment settings.
export CUDA_VISIBLE_DEVICES=0

# Parameters.
DATA_NAME=multiwoz
PROJECT_NAME=GALAXY
MODEL=UnifiedTransformer
PROJECT_ROOT=/data/cll/${PROJECT_NAME}
SAVE_ROOT=/data/cll/${PROJECT_NAME}
VOCAB_PATH=${PROJECT_ROOT}/model/Bert/vocab.txt
VERSION=2.0
LOAD_MODEL_DIR=110-35
LOAD_MODEL_NAME=state_epoch_7
INIT_CHECKPOINT=${SAVE_ROOT}/outputs/${DATA_NAME}${VERSION}/${LOAD_MODEL_DIR}/${LOAD_MODEL_NAME}
WITH_JOINT_ACT=false
USE_TRUE_PREV_BSPN=false
USE_TRUE_PREV_ASPN=false
USE_TRUE_PREV_RESP=false
USE_TRUE_CURR_BSPN=false
USE_TRUE_CURR_ASPN=false
USE_TRUE_DB_POINTER=false
USE_ALL_PREVIOUS_CONTEXT=true
BATCH_SIZE=1
BEAM_SIZE=1
NUM_GPU=1
SEED=10
SAVE_DIR=${SAVE_ROOT}/outputs/${DATA_NAME}${VERSION}/${LOAD_MODEL_DIR}.infer

# Main run.
python -u run.py \
  --do_infer=true \
  --model=${MODEL} \
  --save_dir=${SAVE_DIR} \
  --data_name=${DATA_NAME} \
  --data_root=${PROJECT_ROOT} \
  --vocab_path=${VOCAB_PATH} \
  --init_checkpoint=${INIT_CHECKPOINT} \
  --with_joint_act=${WITH_JOINT_ACT} \
  --use_true_prev_bspn=${USE_TRUE_PREV_BSPN} \
  --use_true_prev_aspn=${USE_TRUE_PREV_ASPN} \
  --use_true_prev_resp=${USE_TRUE_PREV_RESP} \
  --use_true_curr_bspn=${USE_TRUE_CURR_BSPN} \
  --use_true_curr_aspn=${USE_TRUE_CURR_ASPN} \
  --use_true_db_pointer=${USE_TRUE_DB_POINTER} \
  --use_all_previous_context=${USE_ALL_PREVIOUS_CONTEXT} \
  --batch_size=${BATCH_SIZE} \
  --beam_size=${BEAM_SIZE} \
  --version=${VERSION} \
  --gpu=${NUM_GPU} \
  --seed=${SEED} \
  --max_len=1024 \
  --max_ctx_turn=20 \
  --num_act=20 \
  --num_type_embeddings=2 \
  --data_processed=data_for_galaxy_encoded.data.json

error:
Traceback (most recent call last):
  File "run.py", line 130, in <module>
    main()
  File "run.py", line 74, in main
    bpe = MultiWOZBPETextField(hparams)
  File "/data/cll/GALAXY/galaxy/data/field.py", line 356, in __init__
    self._build_vocab()
  File "/data/cll/GALAXY/galaxy/data/field.py", line 498, in _build_vocab
    self.vocab.load_vocab(vp)
  File "/data/cll/GALAXY/galaxy/utils/utils.py", line 199, in load_vocab
    self._freq_dict = json.loads(open(vocab_path + '.freq.json', 'r').read())
FileNotFoundError: [Errno 2] No such file or directory: '/data/cll/GALAXY/data/multiwoz2.0/vocab.freq.json'

The GALAXY project directory contains no data/multiwoz2.0/vocab.freq.json. Where can I get this file? The same error occurs with scripts/camrest/infer.sh: FileNotFoundError: [Errno 2] No such file or directory: 'data/camrest/CamRestOTGY.json'
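
Both missing paths sit under the data directory that the README says to unzip from data.zip into the project directory, so a quick pre-flight check can show whether that step was completed before any inference script runs. This is a minimal sketch: the helper name is illustrative, and the file list contains only the paths quoted in this issue, not the full set of files the scripts need.

```python
from pathlib import Path

# Paths taken from the error messages and VOCAB_PATH in this issue,
# relative to ${PROJECT_ROOT}. Not an exhaustive list.
EXPECTED = [
    "data/multiwoz2.0/vocab.freq.json",
    "data/camrest/CamRestOTGY.json",
    "model/Bert/vocab.txt",
]


def missing_files(project_root, relative_paths):
    """Return the relative paths that are not regular files under project_root."""
    root = Path(project_root)
    return [rel for rel in relative_paths if not (root / rel).is_file()]


# Check the current directory; pass your own ${PROJECT_ROOT} instead of ".".
for rel in missing_files(".", EXPECTED):
    print(f"missing: {rel}")
```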

Two questions about the evaluation

Hi,

Great thanks for providing this fantastic repo!

I have two questions about the evaluations:

  1. How many random seeds did you use to get the main evaluation results on the MultiWOZ2.0 dataset, e.g., Table 3 in your AAAI paper? If more than one seed is used, what are the other seeds except the SEED=10 in GALAXY/scripts/multiwoz2.0/train.sh?
  2. In GALAXY/scripts/multiwoz2.0/infer.sh there is a command LOAD_MODEL_NAME=state_epoch_7. May I ask how you select this checkpoint (the 7-th/60 training epochs)? Is there a way that we can automatically select the best checkpoint?

Looking forward to hearing from you!
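
The page does not record how the authors chose their checkpoints. One common approach to the second question, sketched here under that caveat, is to run inference on the development set for every saved epoch and keep the epoch with the highest combined score. The function name is illustrative and the metric values are placeholders, not measured results.

```python
def best_epoch(dev_metrics):
    """Pick the epoch with the highest combined score on the dev set.

    dev_metrics: {epoch: (inform, success, bleu)}
    Combined score is assumed to be (inform + success) / 2 + bleu.
    Returns (epoch, combined_score); ties go to the earliest entry.
    """
    scored = {
        epoch: (inform + success) / 2 + bleu
        for epoch, (inform, success, bleu) in dev_metrics.items()
    }
    epoch = max(scored, key=scored.get)
    return epoch, scored[epoch]


# Placeholder numbers for illustration only.
dev_metrics = {
    1: (80.0, 70.0, 18.0),
    2: (82.0, 72.0, 18.5),
    3: (81.0, 71.0, 18.0),
}
print(best_epoch(dev_metrics))  # (2, 95.5)
```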

Question about dynamic booking pointer during dialogue generation

I am interested in coding a little demo with the pretrained multiwoz model. However I am not able to figure out how to inject book info into db pointer dynamically. When should the model trigger to check if booking is possible or not? Let's say in a real world scenario if the predicted dialogue act is [offerbook] do we then check to see if booking is possible or not? It feels like we also need the user act here - something like [accept book] or [reject book] and only after then if predicted user act is [accept book] then system should check whether a booking is possible or not and add that result to db pointer.

Here we see that the ground-truth book pointer is used. What should the process and sequence of actions be to obtain the book pointer in a real-world scenario? I have tried keeping it as [book_nores], but this causes the dialogue to loop, repeatedly asking whether the user would like to book. One solution is to change it to [book_success] once the predicted dialogue act contains [offerbook], but when the user then says "no i changed my mind" the system still outputs something like "booking was successful." because [book_success] was added to the db pointer and the user's preference had no effect.

I have a very dirty kaggle notebook, you may take a look at my attempt under Inference section.
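
The flow proposed in this question (consult the booking backend only after the system has offered a booking and the user has accepted) can be sketched as follows. This illustrates the question's idea and is not the repository's inference code. The function name, the user_accepts flag (a real system would need some user-act classifier to produce it) and the [book_fail] token are assumptions; [offerbook], [book_nores] and [book_success] are the tokens quoted in the post.

```python
def book_pointer(prev_system_acts, user_accepts, booking_available):
    """Choose the booking token for the current turn's db pointer."""
    # No pending offer, or the user declined: report "no booking result".
    if "[offerbook]" not in prev_system_acts or not user_accepts:
        return "[book_nores]"
    # Offer accepted: only now consult the (hypothetical) booking backend.
    return "[book_success]" if booking_available else "[book_fail]"


# The user changed their mind after an offer, so no booking is attempted.
print(book_pointer(["[offerbook]"], user_accepts=False, booking_available=True))
# prints [book_nores]
```

Gating on the user's acceptance is what prevents the "booking was successful." reply after "no i changed my mind" described above.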

About domain overlap in the dataset

Hello!
I found that the MultiWOZ dataset is already included in your pre-training data UniDA, while it is also used as your fine-tuning data. Could this make the results of your low-resource experiments unfair?
Thanks for your response in advance!

How to obtain delexicalized representations?

@HwwAncient Hello, thanks for your work!

I'm referring to the downloaded MultiWOZ data. In data_for_galaxy.json, terms user_delex and resp are delexicalized responses. I have following questions:

  1. How are they generated?
  2. How are they used in training and evaluation?

Question for codes and dataset

Thank you for opening codes for your impressive methods and results.

I'd like to ask when your code and dataset will be released.

bug in the code

When we run pretrain_trainer.py with a small number of batches (say 16), it does not stop after 16 batches but runs over all available batches (2372275/32 = 74133 batches), even though we load only 16 batches through the DataLoader object in data_loader.py. I could not find the reason for this. Please help resolve this issue.
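
The issue does not identify the cause. Independently of it, one generic way to cap how many batches a training loop consumes is to wrap the batch iterable in itertools.islice, as in the sketch below. A plain range stands in for the repository's DataLoader, and the helper name is illustrative.

```python
from itertools import islice


def limited_batches(batches, max_batches):
    """Yield at most max_batches items from any iterable of batches."""
    return islice(batches, max_batches)


# 2372275 samples at batch size 32 give 74133 full batches, as in the report.
total_batches = 2372275 // 32
print(total_batches)  # 74133

# Stand-in for a data loader: an iterable with total_batches items.
seen = sum(1 for _ in limited_batches(range(total_batches), 16))
print(seen)  # 16
```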
