rajk853 / drl-gec

Grammar Error Correction via Deep Reinforcement Learning

Home Page: https://rajk853.github.io/DRL-GEC/

Python 94.61% Shell 0.04% Jupyter Notebook 5.35%
bert-model deep-learning grammar-error-correction nlp pytorch reinforcement-learning sequence-labeling transformer

drl-gec's Introduction

Grammar Error Correction (GEC) using Deep Reinforcement Learning (DRL)

In this project, we fine-tune a Sequence-to-Label model derived from GECToR using DRL.

Figure: Rendering of the GEC environment.

The contributions of this project are as follows:

  1. A basic RL environment for the GEC task.
  2. A batched version of the REINFORCE algorithm to optimize the GEC model.
  3. A simple action-search algorithm that allows safe exploration during RL training.
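
As a rough illustration of the second item, the sketch below computes a REINFORCE loss averaged over a batch of episodes in plain Python. It is not the project's implementation: the function names, the discount factor `gamma`, and the use of plain floats instead of tensors are assumptions made for readability. In actual training the log-probabilities would come from the model and carry gradients.

```python
import math


def discounted_returns(rewards, gamma=0.99):
    """Return G_t = r_t + gamma * G_{t+1} for every step of one episode."""
    returns, running = [], 0.0
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    return returns[::-1]


def reinforce_loss(batch_log_probs, batch_rewards, gamma=0.99):
    """Average over episodes of -sum_t log pi(a_t | s_t) * G_t.

    batch_log_probs and batch_rewards are lists of per-episode lists
    (hypothetical layout chosen for this sketch).
    """
    episode_losses = []
    for log_probs, rewards in zip(batch_log_probs, batch_rewards):
        returns = discounted_returns(rewards, gamma)
        episode_losses.append(-sum(lp * g for lp, g in zip(log_probs, returns)))
    return sum(episode_losses) / len(episode_losses)


# One two-step episode in which each action had probability 0.5.
loss = reinforce_loss([[math.log(0.5), math.log(0.5)]], [[1.0, 1.0]], gamma=0.5)
```

Minimizing this quantity raises the probability of actions that were followed by high returns.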

Setup

Conda

The commands below set up a conda environment with the required packages and submodules.

git clone --recurse-submodules git@github.com:RajK853/DRL-GEC.git
cd DRL-GEC
conda env create -f environment.yml

Datasets

In this project, we used the following public GEC datasets:

  1. PIE Synthetic Dataset
  2. W&I+LOCNESS Dataset

Data Processing

To JSON format

The RL environment uses the datasets in the JSON format.

To process the W&I+LOCNESS dataset from M2 format into the JSON format, use the following command:

python m2_to_json.py --m2_path M2_PATH --json_path JSON_PATH [--min_len MIN_LEN] [--max_len MAX_LEN] [--min_sim MIN_SIM] [--only_proper_sent] [--spell_check] [--remove_ellipsis]

The arguments are described below:

--m2_path M2_PATH         # Path to the input M2 file
--json_path JSON_PATH     # Path to the output JSON files

# Optional Arguments
--min_len MIN_LEN       # Min number of tokens in original sentence
--max_len MAX_LEN       # Max number of tokens in original sentence
--min_sim MIN_SIM       # Min avg similarity between original and references
--only_proper_sent      # Allow only proper reference sentences
--spell_check           # Check spelling errors in original and references
--remove_ellipsis       # Remove (source) sentences with ellipsis
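
For orientation, the sketch below shows one way an M2 block could be turned into a source sentence plus a list of reference sentences. It is not the logic of m2_to_json.py: the function name and the `source` / `references` keys are hypothetical, and none of the optional filters above are applied. It assumes the usual M2 layout, with an `S` line followed by `A start end|||type|||correction|||...|||annotator` lines.

```python
def parse_m2_block(block):
    """Convert one M2 block (an "S" line plus its "A" lines) into a dict."""
    lines = block.strip().splitlines()
    source_tokens = lines[0][2:].split()  # drop the leading "S "

    edits_by_annotator = {}
    for line in lines[1:]:
        if not line.startswith("A "):
            continue
        fields = line[2:].split("|||")
        start, end = (int(x) for x in fields[0].split())
        error_type, correction, annotator = fields[1], fields[2], fields[-1]
        edits = edits_by_annotator.setdefault(annotator, [])
        # "noop" marks a sentence that the annotator left unchanged.
        if error_type != "noop":
            edits.append((start, end, correction))

    references = []
    for annotator in sorted(edits_by_annotator):
        tokens = list(source_tokens)
        # Apply edits from right to left so earlier token offsets stay valid.
        for start, end, correction in sorted(edits_by_annotator[annotator], reverse=True):
            tokens[start:end] = correction.split()
        references.append(" ".join(tokens))

    return {"source": " ".join(source_tokens), "references": references}


example = parse_m2_block(
    "S This are a sentence .\n"
    "A 1 2|||R:VERB:SVA|||is|||REQUIRED|||-NONE-|||0"
)
```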

Use the notebooks/PIE_to_JSON.ipynb notebook to process the PIE Synthetic dataset into the JSON format. This notebook adapts the m2_to_json.py script to convert a parallel dataset into the JSON format.

Note that the RL environment looks for a data.json file (or data_filtered.json when only solvable examples are used) in the data/processed/dataset_name directory.

To GECToR format

We use the preprocessing script from GECToR to prepare datasets for the Supervised Learning (SL) method.

Before running it, change the following import in gector/utils/preprocess_data.py to a relative import.

Original:

from helpers import write_lines, ...

Modified:

from .helpers import write_lines, ...

To convert datasets into GECToR format, use the following command:

python json_to_gector.py --json_path JSON_PATH [--output_path OUTPUT_PATH] [--chunk_size CHUNK_SIZE]

The arguments are described below:

--json_path JSON_PATH      # Path to the input JSON file
--output_path OUTPUT_PATH  # Path to the output GECToR file

# Optional Arguments
--chunk_size CHUNK_SIZE    # Chunk size during processing

Only Solvable Examples

The Sequence-to-Label model cannot correct a sentence if a required label is missing from the model's label vocabulary. We therefore provide a script that filters such unsolvable sentences out of the JSON file:

python filter_unsolvable.py --json_path JSON_PATH [--label_path LABEL_PATH]

The arguments are described below:

--json_path JSON_PATH      # Path to the input JSON file

# Optional Arguments
--label_path LABEL_PATH    # Path to the label vocabulary

This generates a filtered JSON file with the _filtered suffix, i.e., data.json -> data_filtered.json.
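
The idea behind the filter can be sketched as follows. This is only an illustration, not the code of filter_unsolvable.py: the helper names are hypothetical, the example labels follow GECToR's naming style, and how the required labels are derived from a sentence pair is left out. The example directory name wi_locness is also just a placeholder.

```python
from pathlib import Path


def is_solvable(required_labels, label_vocab):
    """An example is solvable only if every label it needs is in the vocabulary."""
    return all(label in label_vocab for label in required_labels)


def filtered_path(json_path):
    """Map data.json to data_filtered.json in the same directory."""
    path = Path(json_path)
    return path.with_name(f"{path.stem}_filtered{path.suffix}")


# Hypothetical label vocabulary in GECToR's naming style.
vocab = {"$KEEP", "$DELETE", "$REPLACE_is", "$APPEND_the"}
solvable = is_solvable(["$KEEP", "$REPLACE_is"], vocab)
unsolvable = is_solvable(["$KEEP", "$REPLACE_were"], vocab)
output_path = filtered_path("data/processed/wi_locness/data.json")
```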

Training

This project uses the train_sl.py and train_rl.py scripts to train a model using SL and RL, respectively. Both scripts take a training configuration in the YAML format. The configurations we used to pre-train and fine-tune our models are in the configs subdirectory. Use the following commands to pre-train and fine-tune the models using SL and RL.

python train_sl.py configs/sl_pretrain.yaml   # SL Pre-Training
python train_sl.py configs/sl_finetune.yaml   # SL Fine-Tuning
python train_rl.py configs/rl_finetune.yaml   # RL Fine-Tuning

Evaluation

In this project, models can be evaluated on the following GEC benchmarks:

  1. CoNLL-2014
  2. BEA-2019
  3. JFLEG

Each benchmark has its own evaluation script. The evaluate.py script evaluates a model on all of the above benchmarks.

python evaluate.py --model_dir MODEL_DIR [--label_path LABEL_PATH] [--max_iter MAX_ITER] [--force]

The arguments are described below:

--model_dir MODEL_DIR    # Path to directory with the trained models `model-best.pt` and `model-last.pt`

# Optional Arguments
--label_path LABEL_PATH  # Path to the label vocabulary
--max_iter MAX_ITER      # Max number of prediction iterations
--force                  # Override previous evaluation results

To evaluate on BEA-2019, upload the zipped model outputs to the benchmark's CodaLab competition.


drl-gec's Issues

Environment interaction interface

An interface to test our GEC environments.

  • State transitions (state, action, reward, new state, done)
  • Environment resetting
  • Environment rendering
  • Random action samples
  • Update reward weights
  • Update sampler weights
  • Save interaction

Implement DDQN algorithm

  • Create a simple Replay Buffer
  • Implement the Q-loss function
  • Implement target network weight updates
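
The three items above can be sketched in plain Python as shown below. This is a minimal illustration under stated assumptions, not the project's code: the class and function names, the list-based Q-values and weights, and the soft (Polyak) form of the target update are all choices made for this example. The Q-loss would then be the squared difference between the online Q-value of the taken action and this target.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size buffer; the oldest transitions are dropped first."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


def ddqn_target(reward, done, next_q_online, next_q_target, gamma=0.99):
    """Double DQN target: the online network picks the next action,
    the target network evaluates it."""
    if done:
        return reward
    best_action = max(range(len(next_q_online)), key=next_q_online.__getitem__)
    return reward + gamma * next_q_target[best_action]


def soft_update(target_weights, online_weights, tau=0.005):
    """Move each target weight a fraction tau towards the online weight."""
    return [tau * w + (1.0 - tau) * t for t, w in zip(target_weights, online_weights)]


buffer = ReplayBuffer(capacity=2)
for step in range(3):
    buffer.add(state=step, action=0, reward=0.0, next_state=step + 1, done=False)
target = ddqn_target(1.0, False, next_q_online=[0.1, 0.9], next_q_target=[0.5, 0.2], gamma=0.5)
```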

Random Sampler

Create a weighted random sampler that generates biased random actions.

Features:

  • Limit the number of non-keep labels in an action (done implicitly using the weights)
  • Generate sample weights based on label weights
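
A weighted sampler of this kind can be sketched with the standard library alone. The function name, the label strings, and the weights are hypothetical. The point is only that giving $KEEP a much larger weight than the other labels keeps the number of non-keep labels in a sampled action low without enforcing an explicit limit.

```python
import random


def sample_action(num_tokens, labels, label_weights, rng=None):
    """Draw one label per token, biased by label_weights."""
    rng = rng or random.Random()
    return rng.choices(labels, weights=label_weights, k=num_tokens)


labels = ["$KEEP", "$DELETE", "$REPLACE_is"]
# A heavy weight on $KEEP makes most sampled labels no-ops.
action = sample_action(6, labels, [0.9, 0.05, 0.05], rng=random.Random(0))
# With zero weight on everything else, only $KEEP can be drawn.
all_keep = sample_action(4, labels, [1.0, 0.0, 0.0])
```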

Create OpenAI Gym Environment

Text-based OpenAI Gym Environment

Data Format

A JSON file with source and list of target sentences.

  • Script to convert M2 into JSON
  • Data cleaning:
    • Normalize characters (both source and references)
    • Remove emojis (not needed: the M2 file does not contain any emojis)
    • Fix typos using pyspellchecker (source sentences only; the references are assumed to be correct)
  • Data filtering:
    • Number of tokens in the original sentences
    • References ending properly
    • Mean similarity between the original and reference sentences
  • Generate metadata [Optional]
    • Number of sentences
    • Number of different references per sentence
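
The filtering steps above could look roughly like the sketch below. The `source` / `references` keys, the default thresholds, and the use of difflib's token-level ratio as the similarity measure are assumptions made for this example; the project may compute similarity differently.

```python
from difflib import SequenceMatcher


def mean_similarity(source, references):
    """Average token-level similarity between the source and each reference."""
    if not references:
        return 0.0
    ratios = [
        SequenceMatcher(None, source.split(), reference.split()).ratio()
        for reference in references
    ]
    return sum(ratios) / len(ratios)


def keep_example(example, min_len=5, max_len=50, min_sim=0.8):
    """Apply the token-count filter and the mean-similarity filter."""
    num_tokens = len(example["source"].split())
    if not min_len <= num_tokens <= max_len:
        return False
    return mean_similarity(example["source"], example["references"]) >= min_sim


example = {"source": "This are a sentence .", "references": ["This is a sentence ."]}
similarity = mean_similarity(example["source"], example["references"])
kept = keep_example(example, min_len=3, min_sim=0.5)
```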

Episode

  • Reset
    • Select the state
      • Source and target sentences
  • Take a step
    • Calculate the reward
      • Generate intermediate tokens
      • Compute GLEU scores with target sentences
      • Subtract \epsilon from the reward
    • Obtain the new state
    • Check episode termination
      • N-steps attempted
      • Token length condition
      • All keep actions condition
    • Return new state, reward and done
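
The episode logic above can be sketched as a small Gym-style class. This is a toy under stated assumptions, not the project's environment: the class name, the default limits, taking the best score over the references, and the pluggable `score_fn` (a stand-in for GLEU) are all hypothetical. Labels follow GECToR's naming style.

```python
KEEP = "$KEEP"


def apply_labels(tokens, labels):
    """Apply one edit label per token and return the new token list."""
    output = []
    for token, label in zip(tokens, labels):
        if label == "$DELETE":
            continue
        if label.startswith("$REPLACE_"):
            output.append(label[len("$REPLACE_"):])
        elif label.startswith("$APPEND_"):
            output.extend([token, label[len("$APPEND_"):]])
        else:  # $KEEP or an unknown label leaves the token unchanged
            output.append(token)
    return output


class ToyGECEnv:
    def __init__(self, source, references, score_fn, max_steps=5, max_tokens=50, epsilon=0.01):
        self.source, self.references, self.score_fn = source, references, score_fn
        self.max_steps, self.max_tokens, self.epsilon = max_steps, max_tokens, epsilon

    def reset(self):
        self.tokens, self.steps = self.source.split(), 0
        return list(self.tokens)

    def step(self, labels):
        self.steps += 1
        new_tokens = apply_labels(self.tokens, labels)
        # Best score over the references, minus a small per-step penalty.
        score = max(self.score_fn(new_tokens, ref.split()) for ref in self.references)
        reward = score - self.epsilon
        done = (
            self.steps >= self.max_steps                      # N steps attempted
            or not 0 < len(new_tokens) <= self.max_tokens     # token length condition
            or all(label == KEEP for label in labels)         # all-keep action
        )
        self.tokens = new_tokens
        return list(new_tokens), reward, done


def exact_match(hypothesis, reference):
    """Placeholder score; a real setup would use GLEU here."""
    return float(hypothesis == reference)


env = ToyGECEnv("This are a sentence .", ["This is a sentence ."], exact_match, max_steps=3, epsilon=0.1)
state = env.reset()
state, reward, done = env.step([KEEP, "$REPLACE_is", KEEP, KEEP, KEEP])
```

An agent would keep calling `step` until `done` is true, for example after it emits an all-keep action to signal that no further edits are needed.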

Reward function

$$r(s_t, a_t) = r_{\text{gleu}} + r_{\text{delay}} + r_{\text{invalid-label}}$$
