
critic-aware-decoding

Critic-Driven Decoding for Mitigating Hallucinations in Data-to-text Generation

Critic classifier training

To use the WebNLG data, download it with:

python critics/dataset_generators/download-webnlg.py SPLIT_NAME

Generating data

  • ver 1. critic (base)
python critics/dataset_generators/gen_train_onlystop.py SPLIT_NAME

SPLIT_NAME is a placeholder for "train", "test", or "dev". To generate all the necessary data, run the command three times, i.e.:

python critics/dataset_generators/gen_train_onlystop.py train
python critics/dataset_generators/gen_train_onlystop.py test
python critics/dataset_generators/gen_train_onlystop.py dev
  • ver 2. critic (base with full sentences)
python critics/dataset_generators/gen_train.py SPLIT_NAME 
  • ver 3. critic (vanilla LM)
python3 ./bin/decode.py \
    --model_name facebook/bart-base \
    --experiment webnlg \
    --in_dir data/webnlg \
    --split SPLIT_NAME \
    --accelerator gpu \
    --devices 1 \
    --beam_size 1 \
    --condition_lambda 1.0 \
    --critic_top_k 5 \
    --batch_size 8 \
    --out_filename FILE_NAME --wrapper data --load_in_8bit

where --model_name is the Hugging Face name of the LM used to generate the data (here: facebook/bart-base).

python critics/dataset_generators/gen_train_fromLM.py SPLIT_NAME FILE_NAME-data
  • ver 4. critic (fine-tuned LM)

Put the checkpoint of the fine-tuned language model at the path experiments/webnlg/CHECKPOINT_NAME. Our BART-based LM fine-tuned on WebNLG can be downloaded from https://we.tl/t-1aufs3tnyS

python3 ./bin/decode.py \
    --experiment webnlg \
    --checkpoint CHECKPOINT_NAME \
    --in_dir data/webnlg \
    --split SPLIT_NAME \
    --accelerator gpu \
    --devices 1 \
    --beam_size 1 \
    --condition_lambda 1.0 \
    --critic_top_k 5 \
    --batch_size 8 \
    --out_filename FILE_NAME --wrapper data --load_in_8bit

python critics/dataset_generators/gen_train_fromLM.py SPLIT_NAME FILE_NAME-data
  • ver 5. critic (fine-tuned LM with full sentences)
python3 ./bin/decode.py \
    --experiment webnlg \
    --checkpoint CHECKPOINT_NAME \
    --in_dir data/webnlg \
    --split SPLIT_NAME \
    --accelerator gpu \
    --devices 1 \
    --beam_size 1 \
    --condition_lambda 1.0 \
    --critic_top_k 5 \
    --batch_size 8 \
    --out_filename FILE_NAME --wrapper data-full --load_in_8bit

python critics/dataset_generators/gen_train_fromLM.py SPLIT_NAME FILE_NAME-data
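Because each generator script has to be run once per split, the repetition can be scripted. The sketch below only assembles (and optionally runs) the commands listed above for the ver. 1 and ver. 2 critics; the helpers `split_commands` and `run_all` are hypothetical and not part of the repository, and the sketch assumes it is launched from the repository root.

```python
import subprocess
import sys

SPLITS = ("train", "test", "dev")

# Generator scripts for the critic versions that need no LM decoding step
# (ver. 1 and ver. 2 above); the paths are taken from this README.
GENERATORS = {
    "base": "critics/dataset_generators/gen_train_onlystop.py",
    "full": "critics/dataset_generators/gen_train.py",
}


def split_commands(version: str, python: str = sys.executable) -> list[list[str]]:
    """Return one argv list per split for the chosen generator (hypothetical helper)."""
    script = GENERATORS[version]
    return [[python, script, split] for split in SPLITS]


def run_all(version: str) -> None:
    """Run the generator for every split in order, stopping at the first failure."""
    for argv in split_commands(version):
        subprocess.run(argv, check=True)
```

For example, `run_all("base")` would run gen_train_onlystop.py for train, test and dev in turn; `check=True` makes a failing split raise instead of silently continuing.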

Training the critic

Put the generated training data into OUT_DIR. This directory should contain three files, train.csv, test.csv and dev.csv, holding the training, test and validation data respectively (these files are produced by the gen_train*.py scripts described above).

python critics/run.py --batch_size 32 --outdir OUT_DIR --model MLPSELU --lr 1e-5
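Before launching run.py, it can help to confirm that OUT_DIR actually holds the three CSV files named above. The sketch below checks only the file names, not their contents; `missing_critic_files` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

# File names required in OUT_DIR, as listed in this README.
REQUIRED_FILES = ("train.csv", "test.csv", "dev.csv")


def missing_critic_files(out_dir: str) -> list[str]:
    """Return the required CSV file names that are absent from out_dir."""
    root = Path(out_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]
```

An empty list means all three splits are in place; otherwise the returned names tell you which gen_train*.py run is still missing.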

Critic-aware decoding

Put the checkpoint of the fine-tuned LM at the path experiments/webnlg/LM_CHECKPOINT_NAME. Our BART-based LM fine-tuned on WebNLG can be downloaded from the link given above for the ver. 4 critic. The checkpoint of the trained critic should be located at CRITIC_CHECKPOINT_NAME, and FILE_NAME specifies the name of the output file with the decoded text.

python3 ./bin/decode.py \
    --experiment webnlg \
    --checkpoint LM_CHECKPOINT_NAME \
    --in_dir data/webnlg \
    --split test \
    --accelerator gpu \
    --devices 1 \
    --beam_size 1 \
    --condition_lambda 0.25 \
    --critic_top_k 5 \
    --linear_warmup \
    --batch_size 8 \
    --critic_ckpt CRITIC_CHECKPOINT_NAME \
    --out_filename FILE_NAME --wrapper classifier --load_in_8bit
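This README does not spell out how the critic and the LM are combined, so the following is only a sketch under stated assumptions: at each greedy step (`--beam_size 1`), the `--critic_top_k` most likely tokens under the LM are rescored by adding `--condition_lambda` times the critic's log-probability that the prefix extended by that token is faithful to the input, and the best rescored token is kept. `rerank_top_k` and the `critic_logprob` callable are hypothetical stand-ins; the actual logic in bin/decode.py may differ, and the `--linear_warmup` option is not modelled here.

```python
import math
from typing import Callable, Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def rerank_top_k(
    lm_logprobs: Sequence[float],
    critic_logprob: Callable[[int], float],
    top_k: int = 5,
    condition_lambda: float = 0.25,
) -> tuple[int, list[tuple[int, float]]]:
    """Pick the next token after rescoring the LM's top-k candidates (assumed scheme).

    lm_logprobs[i] is the LM log-probability of token id i.
    critic_logprob(i) is a hypothetical callable returning the critic's
    log-probability that the current prefix extended with token i is faithful.
    Returns the chosen token id and the rescored (token, score) pairs.
    """
    # Keep only the k tokens the LM considers most likely.
    candidates = sorted(
        range(len(lm_logprobs)), key=lambda i: lm_logprobs[i], reverse=True
    )[:top_k]
    # Add the lambda-weighted critic log-probability to each LM score.
    scored = [(i, lm_logprobs[i] + condition_lambda * critic_logprob(i)) for i in candidates]
    # Greedy choice, mirroring --beam_size 1.
    best = max(scored, key=lambda pair: pair[1])[0]
    return best, scored
```

With `condition_lambda=0` this reduces to plain greedy LM decoding; larger values let the critic override a token the LM prefers when the critic judges it unlikely to be faithful. Restricting the critic to the top-k candidates keeps the number of critic calls per step small.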

Contributors

langus0
