dst-as-prompting's Introduction

SDP-DST: DST-as-Prompting

This is the original implementation of "Dialogue State Tracking with a Language Model using Schema-Driven Prompting" by Chia-Hsuan Lee, Hao Cheng and Mari Ostendorf.

Installation | Preprocess | Training | Evaluation | Citation

Installation

conda create -n DST-prompt python=3.7
cd DST-as-Prompting
conda env update -n DST-prompt -f env.yml
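
Then activate the environment so the pip installs below go into it:

conda activate DST-prompt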

To use the Hugging Face seq2seq training scripts, install transformers from source at the pinned commit:

pip install git+https://github.com/huggingface/transformers.git@2c2a31ffbcfe03339b1721348781aac4fc05bc5e

Then install the requirements for the Hugging Face training script:

cd transformers/examples/pytorch/summarization/
pip install -r requirements.txt
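
To sanity-check the source install (a quick check; the exact version string depends on the pinned commit):

python -c "import transformers; print(transformers.__version__)"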

Download and Preprocess Data

Please download the data from the MultiWOZ GitHub repository:

cd ~/DST-as-Prompting
git clone https://github.com/budzianowski/multiwoz.git

In the commands below, $DATA_DIR refers to multiwoz/data/MultiWOZ_2.2.
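You can set it in your shell:

export DATA_DIR=multiwoz/data/MultiWOZ_2.2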

cd ~/DST-as-Prompting
python preprocess.py $DATA_DIR
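
Preprocessing writes the train.json, dev.json, and test.json files used for training below, as well as the test.idx index consumed later by postprocess.py. As a quick sanity check, you can peek at one preprocessed example (a minimal sketch; it assumes JSON-lines files whose "dialogue" and "state" fields match the --text_column/--summary_column flags used in training):

import json, os

# Peek at the first preprocessed training example.
data_dir = os.environ.get("DATA_DIR", "multiwoz/data/MultiWOZ_2.2")
with open(os.path.join(data_dir, "train.json")) as f:
    example = json.loads(f.readline())
print(example["dialogue"][:200])  # dialogue context plus schema prompt
print(example["state"])           # target state string the model learns to generate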

Training

cd transformers
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_predict \
    --train_file "$DATA_DIR/train.json" \
    --validation_file "$DATA_DIR/dev.json" \
    --test_file "$DATA_DIR/test.json" \
    --source_prefix "" \
    --output_dir /tmp/t5small_mwoz2.2 \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --predict_with_generate \
    --text_column="dialogue" \
    --summary_column="state" \
    --save_steps=500000
  • --model_name_or_path: name of the Hugging Face model to fine-tune, e.g. t5-small, t5-base, etc.

This should take about 32 hours to train on a single GPU. At the end of training, the script runs prediction on the file passed as --test_file and stores the results at <output_dir>/generated_predictions.txt (here, /tmp/t5small_mwoz2.2/generated_predictions.txt).
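
Once training finishes, you can also load the checkpoint directly for ad-hoc inference. A minimal sketch, assuming the input string follows the same format as the preprocessed "dialogue" field (the example text below is hypothetical):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_dir = "/tmp/t5small_mwoz2.2"  # the --output_dir from the training run
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

# Hypothetical input; the real format is whatever preprocess.py writes
# into the "dialogue" field (dialogue context followed by a schema prompt).
text = "[user] i need a cheap hotel in the north [slot] hotel pricerange"
inputs = tokenizer(text, return_tensors="pt", truncation=True)
output_ids = model.generate(**inputs, max_length=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))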

Evaluation

cd ~/DST-as-Prompting

In the command below, $output_dir is the --output_dir used for training (e.g. /tmp/t5small_mwoz2.2).

python postprocess.py --data_dir "$DATA_DIR" --out_dir "$DATA_DIR/dummy/" --test_idx "$DATA_DIR/test.idx" \
    --prediction_txt "$output_dir/generated_predictions.txt"

python eval.py --data_dir "$DATA_DIR" --prediction_dir "$DATA_DIR/dummy/" \
    --output_metric_file "$DATA_DIR/dummy/prediction_score"
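
For reference, the metric discussed in the issues below is joint goal accuracy (JGA): a turn counts as correct only if the entire predicted dialogue state matches the gold state. A toy illustration of the idea (not the repo's eval.py, which also handles schema files and output normalization):

# Toy joint goal accuracy: a turn is correct only when the whole
# predicted state equals the gold state.
def joint_goal_accuracy(gold_states, pred_states):
    assert len(gold_states) == len(pred_states)
    correct = sum(g == p for g, p in zip(gold_states, pred_states))
    return correct / len(gold_states)

gold = [{"hotel-pricerange": "cheap", "hotel-area": "north"}]
pred = [{"hotel-pricerange": "cheap", "hotel-area": "north"}]
print(joint_goal_accuracy(gold, pred))  # 1.0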

Citation and Contact

If you find our code or paper useful, please cite the paper:

@inproceedings{lee2021dialogue,
  title={Dialogue State Tracking with a Language Model using Schema-Driven Prompting},
  author={Lee, Chia-Hsuan and Cheng, Hao and Ostendorf, Mari},
  booktitle={Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing},
  pages={4937--4949},
  year={2021}
}

Please contact Chia-Hsuan Lee (chiahlee[at]uw.edu) for questions and suggestions.


dst-as-prompting's Issues

Hope for the code

Hi,

I think this is very interesting work, both simple and efficient!

I would really like to see the code for further study.

Best wishes

On reproducing the experimental results in the paper

Hi,

Congrats on having this concise and solid work accepted at EMNLP 2021! I am currently following your research and trying to reproduce the experimental results from the original paper using your code. However, I have had some trouble matching the reported JGA scores.

My experiments were all on MultiWOZ v2.2, with domain and slot descriptions. Here are my hyperparameter settings and corresponding results.

  • T5-small, lr=5e-5, n_epochs=3, batch size=8: JGA=55.3
  • T5-base, lr=5e-5, n_epochs=3, batch size=8: JGA=56.0
  • T5-base, lr=5e-4, n_epochs=2, batch size=16 (bs=8 with grad_accumulation=2): JGA=56.1
  • T5-base, lr=5e-4, n_epochs=2, batch size=64 (bs=8 with grad_accumulation=8): JGA=56.2 [same as paper]

The experiments were run on a single A100 40GB with Python==3.9.12, PyTorch==1.12.1, and CUDA==11.6; the other hyperparameters were left at their defaults. There is still a gap between my results and the JGA score in the paper, which is 57.6.

I am wondering if there are some other tricks to achieve better results. If so, would it be okay to share them? Much appreciated!
Looking forward to your reply :-D

Best

Is there a custom `schema.json` for MultiWOZ?

Hi,

I am trying to use eval.py, but I am getting an error that a file cannot be found (example below):
FileNotFoundError: [Errno 2] No such file or directory: 'multiwoz/data/MultiWOZ_2.2/test/schema.json'

The official MultiWOZ repository does not have a schema.json in data/MultiWOZ_2.2/test or data/MultiWOZ_2.2/train, so I wonder if you used customized schema.json files for the test and train datasets.

If so, could you share the customized schema files? Even brief schema information would help.

I am grateful for your great work and look forward to hearing from you.

Regards,
Yeseul

Two questions from the paper I would like to confirm

Hi,

I think this is a very interesting work, and I have two questions I want to check:

  1. For Schema-Based Prompt DST with Independent Decoding, during inference, does the model predict the domain, slot, and value in parallel or in order? I'm confused about whether T5 needs to predict 8 times for the 8 domains in each sample, then predict the slots in each domain, and finally predict the slot values many times. How can all of this be done in one pass?
  2. The example with descriptions impresses me a lot, where the model converts 4:45 PM to 16:45; that's quite amazing. How does the model do that? (A screenshot of the example from the paper was attached.)

Looking forward to your reply.

Best

Questions on preprocessing data

Hi,
Thanks for your great work!

Could you offer the code for preprocessing data on MultiWOZ 2.1?

Looking forward to your reply.
Best

Training script for t5-base

Hi, thank you for the nice code. It works fine with t5-small.
I also followed the settings for training t5-base from your paper, but the model does not seem to train properly. The evaluation loss is much higher than with t5-small, and the prediction results are also terrible. I think the hyperparameters I set are still not correct. Could you also provide your script for training t5-base? Thank you!

This is the script I am using:
CUDA_VISIBLE_DEVICES=0,1 python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google/t5-base \
    --do_train \
    --do_predict \
    --train_file "$DATA_DIR/train.json" \
    --validation_file "$DATA_DIR/dev.json" \
    --test_file "$DATA_DIR/test.json" \
    --source_prefix "" \
    --output_dir "$OUTPUT_DIR/t5-base-mwoz2.2" \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --gradient_accumulation_steps 8 \
    --predict_with_generate \
    --learning_rate 5e-4 \
    --num_train_epochs 2 \
    --text_column="dialogue" \
    --summary_column="state" \
    --save_steps=25000
