
README for CBART

This repository contains the implementation of the EMNLP 2021 paper: "Parallel Refinements for Lexically Constrained Text Generation with BART".


Abstract

Lexically constrained text generation aims to control the generated text by incorporating some pre-specified keywords into the output. Previous work injects lexical constraints into the output by controlling the decoding process or refining the candidate output iteratively, which tends to generate generic or ungrammatical sentences, and has high computational complexity. To address these challenges, we propose Constrained BART (CBART) for lexically constrained text generation. CBART leverages the pre-trained model BART and transfers part of the generation burden from the decoder to the encoder by decomposing this task into two sub-tasks, thereby improving the sentence quality. Concretely, we extend BART by adding a token-level classifier over the encoder, aiming at instructing the decoder where to replace and insert. Guided by the encoder, the decoder refines multiple tokens of the input in one step by inserting tokens before specific positions and re-predicting tokens with low confidence. To further reduce the inference latency, the decoder predicts all tokens in parallel. Experiment results on One-BillionWord and Yelp show that CBART can generate plausible text with high quality and diversity while significantly accelerating inference.
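To make the encoder sub-task concrete, below is a minimal sketch of how a token-level classification head can sit on top of the BART encoder. This is an illustration written against a recent version of the transformers library (the pinned 3.0.2 may differ slightly), not the repository's actual models/bart.py; the three-way label set (copy / replace / insert) and all names here are assumptions.

import torch
import torch.nn as nn
from transformers import BartModel, BartTokenizer

class TokenLevelClassifier(nn.Module):
    # Hypothetical sketch: predict one edit label per input token
    # (e.g., 0 = copy, 1 = replace, 2 = insert-before), mirroring the
    # encoder sub-task described in the abstract above.
    def __init__(self, model_name="facebook/bart-base", num_labels=3):
        super().__init__()
        self.bart = BartModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.bart.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask):
        # Encode the input sequence with the BART encoder only.
        hidden = self.bart.encoder(input_ids=input_ids,
                                   attention_mask=attention_mask)[0]
        return self.classifier(hidden)  # (batch, seq_len, num_labels)

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TokenLevelClassifier()
batch = tokenizer(["soccer player scored goal"], return_tensors="pt")
logits = model(batch["input_ids"], batch["attention_mask"])
print(logits.shape)  # torch.Size([1, seq_len, 3])

The decoder then conditions on these encoder states and, guided by the predicted labels, rewrites all low-confidence positions in parallel rather than token by token.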


Requirements

Python 3.6
pip install torch==1.4.0
pip install transformers==3.0.2
pip install pympler==0.8


Dataset

All our experiments are conducted on the One-Billion-Word and Yelp review corpora. In this paper, we sample 1M sentences for training and 0.1M sentences for validation from each dataset (the full data used in this paper are available at https://drive.google.com/drive/folders/1Dj7VX2CjSn3-g7FEYuJrT5_JWGdsAHjE?usp=sharing). If you want to train the model from scratch, download the corresponding data first and put it in the corresponding directory, i.e., data/one-billion-words (or data/yelp_review). Note that data/one-billion-words/train.txt and data/one-billion-words/dev.txt only contain a few example sentences.
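As a quick sanity check before training, the sketch below verifies that the expected files are in place (directory and file names are taken from the paragraph above; one sentence per line is an assumption):

from pathlib import Path

# Expected layout: data/<dataset>/{train,dev}.txt, one sentence per line (assumed).
for dataset in ("one-billion-words", "yelp_review"):
    for split in ("train.txt", "dev.txt"):
        path = Path("data") / dataset / split
        if path.exists():
            lines = path.read_text(encoding="utf-8").splitlines()
            print(f"{path}: {len(lines)} sentences, e.g. {lines[0]!r}")
        else:
            print(f"{path}: missing -- download it first")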


Try our model with well-trained model checkpoints

Model | Download link
CBART-base for Yelp review | [link]
CBART-large for Yelp review | [link]
CBART-base for One-Billion-Word | [link]
CBART-large for One-Billion-Word | [link]

If you want to try our models, download these checkpoints, put them into the 'checkpoints' directory, and decompress them with the following command; then you can go directly to Generate sentences with lexical constraints.

tar -xzvf checkpoint_name.tar.gz # replace 'checkpoint_name' with the corresponding checkpoint name.

If you want to train our model on another dataset, please refer to the following steps.


Train our model from scratch

Note that the default dataset is One-Billion-Word; you can freely change it to another dataset.

  • Step 1: Create synthetic data to train CBART
cd utils  
sh create_synthetic_data.sh
  • Step 2: Train CBART
cd models

If you want to train CBART-base on One-Billion-Word:

python bart.py --batch_size 80 --gpu 5 --dataset one-billion-words

If you want to train CBART-large on One-Billion-Word:

python bart.py --batch_size 25 --gpu 5 --dataset one-billion-words --bart large

Generate sentences with lexical constraints

We provide some keywords in "data/one-billion-words/4keywords.txt", where each line contains 4 keywords. In the following, we will generate sentences with 4 keywords. If you want to generate sentences with a different number of keywords, prepare a keyword file and put it at "data/dataset_name/{k}keywords.txt", where '{k}' denotes the number of keywords per line. In that case, you also need to change the hyperparameter "num_keywords" accordingly (e.g., --num_keywords 1 to generate sentences with one keyword).
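A minimal sketch for preparing such a keyword file (the example keyword sets and the space-separated, one-set-per-line format are assumptions based on the description above):

from pathlib import Path

# Hypothetical keyword sets; each line holds k keywords (separator assumed to be a space).
keyword_sets = [
    ["soccer", "player", "scored", "goal"],
    ["restaurant", "food", "service", "delicious"],
]
k = len(keyword_sets[0])
out_path = Path("data/one-billion-words") / f"{k}keywords.txt"
out_path.write_text("\n".join(" ".join(ws) for ws in keyword_sets) + "\n", encoding="utf-8")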

Generate sentences with 4keywords.txt by running greedy decoding on CBART-base:

python main.py --gpu 7 --num_keywords 4 --do_sample 0 --batch_size 10 --bart base --dataset one-billion-words

Generate sentences with 4keywords.txt by running multiple-sequence decoding (p=0.5, c=5) on CBART-base:

python main.py --gpu 7 --num_keywords 4 --do_sample 1 --top_p 0.5 --decoder_chain 5 --batch_size 10 --bart base --dataset one-billion-words

Generate sentences with 4keywords.txt by running multiple-sequence decoding (k=5, c=5) on CBART-base:

python main.py --gpu 7 --num_keywords 4 --do_sample 1 --top_k 5 --decoder_chain 5 --batch_size 10 --bart base --dataset one-billion-words
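For reference, here is a minimal sketch of the standard top-k and top-p (nucleus) filtering that this kind of sampling-based decoding relies on; it illustrates the general technique, not the repository's main.py. The --decoder_chain c option presumably runs c sampled refinement sequences in parallel and keeps the best one.

import torch

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0):
    # Standard top-k / nucleus filtering over a 1-D (vocab,) logits vector.
    if top_k > 0:
        kth_largest = torch.topk(logits, top_k).values[-1]
        logits[logits < kth_largest] = float("-inf")  # keep only the top k tokens
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[1:] = remove[:-1].clone()  # shift right so the first token that
        remove[0] = False                 # crosses the threshold is still kept
        logits[sorted_idx[remove]] = float("-inf")
    return logits

# Sample one token id from top-k-filtered logits (vocab size 50265 for BART).
probs = torch.softmax(top_k_top_p_filtering(torch.randn(50265), top_k=5), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)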

Citation

If you use this code in your research, please cite our paper:

@inproceedings{he-2021-parallel,
    title = "Parallel Refinements for Lexically Constrained Text Generation with {BART}",
    author = "He, Xingwei",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
    month = nov,
    year = "2021",
    address = "Online and Punta Cana, Dominican Republic",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.emnlp-main.681",
    doi = "10.18653/v1/2021.emnlp-main.681",
    pages = "8653--8666",
}


Issues

About the replace label

The data constructed for training uses the labels 0, 1, and 2, where 1 means replace. The paper says that replacement is not needed at inference time, yet at inference the code builds indicate_labels += [1] + [0] * (len(ids) - 1), which adds many 1s. Why is that?

How did you extract the keywords?

How did you extract the keywords from the datasets (Yelp, One-Billion-Word)?
I wonder if it was done manually or if there is a Python library or model you used.

About the synthetic data (create_synthetic_data.py)

Hello, I am studying this code. In the script that creates the synthetic data, what do the three lists, incorrect_input_ids_list, label_ids_list, and target_ids_list, represent, and which part of the paper do they correspond to? Thanks.
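A plausible reading, under the assumption that the script follows the paper's setup: incorrect_input_ids_list holds corrupted versions of the sentences, label_ids_list holds one edit label per remaining token, and target_ids_list holds the original sentences the decoder must reconstruct. The sketch below illustrates that idea; it is not the actual create_synthetic_data.py, and the label scheme (0 = copy, 1 = replace, 2 = insert-before) is assumed.

import random

def make_synthetic_example(token_ids, vocab_size, p_replace=0.15, p_delete=0.15):
    incorrect_input_ids, label_ids = [], []
    insert_pending = False  # set when a token was just deleted
    for tid in token_ids:
        r = random.random()
        if r < p_delete:
            insert_pending = True  # deletion: the decoder must insert a token back
            continue
        if r < p_delete + p_replace:
            incorrect_input_ids.append(random.randrange(vocab_size))  # corrupt this token
            label_ids.append(1)  # replace
        else:
            incorrect_input_ids.append(tid)
            label_ids.append(0)  # copy
        if insert_pending:
            label_ids[-1] = 2  # insert before this token (simplification: overrides replace)
            insert_pending = False
    target_ids = list(token_ids)  # decoder target: the original, uncorrupted sentence
    return incorrect_input_ids, label_ids, target_ids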

About the full mask

If the full mask equals 1, the decoder can be seen as BERT, right?
