
ToD-BERT: Pre-trained Natural Language Understanding for Task-Oriented Dialogues

Authors: Chien-Sheng Wu, Steven Hoi, Richard Socher and Caiming Xiong.

Paper: https://arxiv.org/abs/2004.06871

Introduction

The underlying difference in linguistic patterns between general text and task-oriented dialogue makes existing pre-trained language models less effective in practice. In this work, we unify nine human-human and multi-turn task-oriented dialogue datasets for language modeling. To better model dialogue behavior during pre-training, we incorporate user and system special tokens into the masked language modeling, and we add a contrastive objective function with a simulated response selection task. Our pre-trained task-oriented dialogue BERT (TOD-BERT) outperforms strong baselines like BERT in four downstream task-oriented dialogue applications, including intention detection, dialogue state tracking, dialogue act prediction, and response selection. We also show that TOD-BERT has a stronger few-shot ability, which can mitigate the data scarcity problem in task-oriented dialogue.
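
The contrastive response-selection objective mentioned above can be sketched in a few lines. Below is a minimal PyTorch sketch (not the repo's implementation) of a response-selection loss with in-batch negatives, assuming dialogue contexts and responses are each represented by their [CLS] embeddings:

import torch
import torch.nn.functional as F

# A minimal sketch (not the repo's code): each context in a batch should
# score its own response higher than the other in-batch responses.
# context_cls, response_cls: (B, hidden_size) [CLS] embeddings.
def response_contrastive_loss(context_cls, response_cls):
    logits = context_cls @ response_cls.t()               # (B, B) similarity matrix
    targets = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)               # diagonal pairs are positives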

Citation

If you use any source code, pre-trained models, or datasets included in this repo in your work, please cite the following paper. The BibTeX is listed below:

@article{wu2020tod,
  title={ToD-BERT: Pre-trained Natural Language Understanding for Task-Oriented Dialogues},
  author={Wu, Chien-Sheng and Hoi, Steven and Socher, Richard and Xiong, Caiming},
  journal={arXiv preprint arXiv:2004.06871},
  year={2020}
}

Update

  • (2020.07.10) Loading the model from Hugging Face is now supported.
  • (2020.04.26) Pre-trained models are available.

Pretrained Models

You can easily load the pre-trained models with the Hugging Face Transformers library using the AutoModel class. Several pre-trained versions are supported:

  • TODBERT/TOD-BERT-MLM-V1: TOD-BERT pre-trained only using the MLM objective
  • TODBERT/TOD-BERT-JNT-V1: TOD-BERT pre-trained using both the MLM and RCL objectives
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TODBERT/TOD-BERT-JNT-V1")
tod_bert = AutoModel.from_pretrained("TODBERT/TOD-BERT-JNT-V1")
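
To confirm that the dialogue-role special tokens are available, you can look them up in the tokenizer's vocabulary. This is a quick check; it assumes the checkpoint registers the role tokens as "[USR]" and "[SYS]", matching the usage example further below:

# Sanity check (assumption: role tokens are stored as "[USR]" and "[SYS]")
print(tokenizer.convert_tokens_to_ids(["[USR]", "[SYS]"]))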

You can also download the pre-trained models and load them from a local path:

from transformers import BertConfig, BertModel, BertTokenizer

model_name_or_path = "<path_to_the_downloaded_tod-bert>"  # local checkpoint directory
model_class, tokenizer_class, config_class = BertModel, BertTokenizer, BertConfig
tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
tod_bert = model_class.from_pretrained(model_name_or_path)

Direct Usage

Please refer to the following guide on how to use our pre-trained ToD-BERT models. Full training and evaluation code will be released soon. Our model is built on top of the PyTorch library and the Hugging Face Transformers library. Let's do a very quick overview of the model architecture and code. Detailed examples of the model architecture can be found in the paper.

# Encode a dialogue context: [SYS] and [USR] special tokens mark the speakers
input_text = "[CLS] [SYS] Hello, how can I help you today? [USR] Find me a cheap restaurant near the north town."
input_tokens = tokenizer.tokenize(input_text)
story = torch.tensor(tokenizer.convert_tokens_to_ids(input_tokens), dtype=torch.long)

if len(story.size()) == 1:
    story = story.unsqueeze(0)  # add the batch-size dimension

if torch.cuda.is_available():
    tod_bert = tod_bert.cuda()
    story = story.cuda()

with torch.no_grad():
    input_context = {"input_ids": story, "attention_mask": (story > 0).long()}
    hiddens = tod_bert(**input_context)[0]  # (batch, seq_len, hidden_size)
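
The encoder returns one hidden vector per token. Here is a minimal pooling sketch (not the repo's official code) showing two common ways to collapse them into a single dialogue representation, assuming position 0 holds the [CLS] token as in the input above:

cls_rep = hiddens[:, 0, :]  # [CLS] hidden state: (batch, hidden_size)

mask = input_context["attention_mask"].unsqueeze(-1).float()
mean_rep = (hiddens * mask).sum(dim=1) / mask.sum(dim=1)  # masked mean over tokens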

Training and Testing

If you would like to train the model yourself, you can download the datasets from each of their original papers or sources. You can also directly download a zip file here.

The repository is currently in this structure:

.
├── image
│   └── ...
├── models
│   ├── multi_class_classifier.py
│   ├── multi_label_classifier.py
│   ├── BERT_DST_Picklist.py
│   └── dual_encoder_ranking.py
├── utils
│   ├── multiwoz
│   │   └── ...
│   ├── metrics
│   │   └── ...
│   ├── loss_function
│   │   └── ...
│   ├── dataloader_nlu.py
│   ├── dataloader_dst.py
│   ├── dataloader_dm.py
│   ├── dataloader_nlg.py
│   ├── dataloader_usdl.py
│   └── ...
├── README.md
├── evaluation_pipeline.sh
├── evaluation_ratio_pipeline.sh
├── run_tod_lm_pretraining.sh
├── main.py
└── my_tod_pretraining.py

  • Run Pretraining

❱❱❱ ./run_tod_lm_pretraining.sh 0 bert bert-base-uncased save/pretrain/ToD-BERT-MLM --only_last_turn
❱❱❱ ./run_tod_lm_pretraining.sh 0 bert bert-base-uncased save/pretrain/ToD-BERT-JNT --only_last_turn --add_rs_loss

  • Run Fine-tuning

❱❱❱ ./evaluation_pipeline.sh 0 bert bert-base-uncased save/BERT

  • Run Fine-tuning (Few-Shot)

❱❱❱ ./evaluation_ratio_pipeline.sh 0 bert bert-base-uncased save/BERT --nb_runs=3
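
Note that the positional arguments in these scripts appear to be, in order, the GPU index, the model type, the base checkpoint, and the output directory (as suggested by the invocations above, not documented in this README); the remaining flags are forwarded to the underlying training or evaluation script.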

Report

Feel free to create an issue or send an email to the first author at [email protected].
