learning-to-unjumble

learning to unjumble as a pretraining objective for RoBERTa

GLUE Results Using RoBERTa with Jumbled Token Discrimination Loss

jumbling probability = 0.15, peak lr = 1e-4, steps = 1000

jumbling probability = 0.30, peak lr = 1e-4, steps = 1000

GLUE Results Using RoBERTa with Masked Token Discrimination Loss [BASELINE]

masking probability = 0.15, peak lr = 1e-4, steps = 1000

Additional GLUE Results

Pretrained RoBERTa

Pretrained ELECTRA

Training and Evaluation Command Lines

Train with MLM loss

# make sure transformers version is 2.7.0
!pip install transformers==2.7.0

# use %cd (not !cd) so the directory change persists in the notebook
%cd ./unjumble

"""
# use wikidump data or wikitext data
"""
# download data
from torchtext.datasets import WikiText103
WikiText103.download('./data')

# run roberta training with MLM loss
!python run_language_modeling.py \
--output_dir ./models/roberta_mlm \
--model_type roberta \
--mlm \
--do_train \
--do_eval \
--save_steps 2000 \
--per_gpu_train_batch_size 8 \
--evaluate_during_training \
--train_data_file data/wikitext-103/wikitext-103/wiki.train.tokens \
--line_by_line \
--eval_data_file data/wikitext-103/wikitext-103/wiki.test.tokens \
--model_name_or_path roberta-base
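
For reference, --mlm trains with the standard masked-language-modeling corruption: roughly 15% of positions are selected, and of those 80% become the mask token, 10% become a random token, and 10% are left unchanged. A minimal sketch of that scheme (the script's own mask_tokens also handles special tokens and padding):

import torch

def mask_tokens(inputs, mask_token_id, vocab_size, mlm_probability=0.15):
    # select the positions the model must predict
    labels = inputs.clone()
    masked_indices = torch.bernoulli(
        torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked_indices] = -100  # ignored by the cross-entropy loss

    # 80% of selected positions -> mask token
    replaced = torch.bernoulli(
        torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[replaced] = mask_token_id

    # 10% -> a random vocabulary token (half of the remaining 20%)
    randomized = torch.bernoulli(
        torch.full(labels.shape, 0.5)).bool() & masked_indices & ~replaced
    inputs[randomized] = torch.randint(
        vocab_size, labels.shape, dtype=torch.long)[randomized]

    # the final 10% keep their original token
    return inputs, labels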

Train with jumbled token discrimination loss

# make sure transformers version is 2.7.0
!pip install transformers==2.7.0

%cd ./unjumble

# set paths to the downloaded wikidump data
TRAIN_DATA_PATH=../../data/wikidump/train.txt
VAL_DATA_PATH=../../data/wikidump/val.txt

# run roberta training with jumbled token-modification-discrimination head
!python run_language_modeling.py \
--output_dir ../models/roberta_token_discrimination \
--tensorboard_log_dir ../tb/roberta_token_discrimination \
--model_type roberta \
--model_name_or_path roberta-base \
--token_discrimination \
--do_train \
--gradient_accumulation_steps 64 \
--save_steps 50 \
--max_steps 1000 \
--weight_decay 0.01 \
--warmup_steps 100 \
--learning_rate 5e-5 \
--per_gpu_train_batch_size 16 \
--per_gpu_eval_batch_size 16 \
--train_data_file $TRAIN_DATA_PATH \
--eval_data_file $VAL_DATA_PATH \
--jumble_probability 0.15 \
--line_by_line \
--logging_steps 1 \
--do_eval \
--eval_all_checkpoints
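
The objective here is discriminative rather than generative: some token positions are shuffled within the sequence, and a per-token head predicts which positions were disturbed. The repo's collator is not reproduced here; a minimal sketch of the idea, with binary labels (1 = this position's token changed):

import torch

def jumble_tokens(input_ids, jumble_probability=0.15):
    jumbled = input_ids.clone()
    labels = torch.zeros_like(input_ids)
    for i in range(input_ids.size(0)):
        # pick the positions that participate in the shuffle
        idx = (torch.rand(input_ids.size(1)) < jumble_probability)
        idx = idx.nonzero(as_tuple=True)[0]
        if idx.numel() < 2:
            continue  # nothing to shuffle
        perm = idx[torch.randperm(idx.numel())]
        jumbled[i, idx] = input_ids[i, perm]
        # label only positions whose token actually changed
        labels[i] = (jumbled[i] != input_ids[i]).long()
    return jumbled, labels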

Train with POS-based jumbled token discrimination loss

# make sure transformers version is 2.7.0
!pip install transformers==2.7.0

%cd ./unjumble

# set paths to the downloaded wikidump data
TRAIN_DATA_PATH=../../data/wikidump/train.txt
VAL_DATA_PATH=../../data/wikidump/val.txt

# run roberta training with jumbled token-modification-discrimination head
# --pos restricts jumbling to nouns and adjectives (POS-based jumbling)
!python run_language_modeling.py \
--output_dir ../models/roberta_token_discrimination \
--tensorboard_log_dir ../tb/roberta_token_discrimination \
--model_type roberta \
--model_name_or_path roberta-base \
--token_discrimination \
--pos \
--do_train \
--gradient_accumulation_steps 64 \
--save_steps 50 \
--max_steps 1000 \
--weight_decay 0.01 \
--warmup_steps 100 \
--learning_rate 5e-5 \
--per_gpu_train_batch_size 16 \
--per_gpu_eval_batch_size 16 \
--train_data_file $TRAIN_DATA_PATH \
--eval_data_file $VAL_DATA_PATH \
--jumble_probability 0.15 \
--line_by_line \
--logging_steps 1 \
--do_eval \
--eval_all_checkpoints
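
With --pos, only nouns and adjectives are candidates for shuffling. A sketch of how the eligible word positions might be found with NLTK (the tagger choice is an assumption here, not necessarily what the repo uses):

import nltk  # assumes nltk.download('averaged_perceptron_tagger') has been run

def pos_eligible_positions(words):
    # keep only nouns (NN*) and adjectives (JJ*) as shuffle candidates
    tags = nltk.pos_tag(words)
    return [i for i, (_, tag) in enumerate(tags)
            if tag.startswith('NN') or tag.startswith('JJ')]

words = "the quick brown fox jumps over the lazy dog".split()
print(pos_eligible_positions(words))  # e.g. [1, 2, 3, 7, 8] (tagger-dependent)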

Train with masked token discrimination loss

# make sure transformers version is 2.7.0
!pip install transformers==2.7.0

%cd ./unjumble

# set paths to the downloaded wikidump data
TRAIN_DATA_PATH=../../data/wikidump/train.txt
VAL_DATA_PATH=../../data/wikidump/val.txt

# run roberta training with masked token-modification-discrimination head
# note the two flags that differ from the jumbling runs:
# --mask_token_discrimination and --mask_probability
!python run_language_modeling.py \
--output_dir ../models/roberta_MASK_token_discrimination \
--tensorboard_log_dir ../tb/roberta_MASK_token_discrimination \
--model_type roberta \
--model_name_or_path roberta-base \
--mask_token_discrimination \
--do_train \
--gradient_accumulation_steps 64 \
--save_steps 50 \
--max_steps 1000 \
--weight_decay 0.01 \
--warmup_steps 100 \
--learning_rate 5e-5 \
--per_gpu_train_batch_size 16 \
--per_gpu_eval_batch_size 16 \
--train_data_file $TRAIN_DATA_PATH \
--eval_data_file $VAL_DATA_PATH \
--mask_probability 0.15 \
--line_by_line \
--logging_steps 1 \
--do_eval \
--eval_all_checkpoints
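
This baseline keeps the discriminative head but corrupts with the mask token instead of shuffling: a random fraction of tokens is replaced by the mask token, and the head predicts which positions were masked. A minimal sketch, again not the repo's exact collator:

import torch

def mask_for_discrimination(input_ids, mask_token_id, mask_probability=0.15):
    # label 1 for every position that gets masked, 0 elsewhere
    labels = torch.bernoulli(
        torch.full(input_ids.shape, mask_probability)).long()
    masked = input_ids.clone()
    masked[labels.bool()] = mask_token_id
    return masked, labels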

Running on Prince

# Load these modules every time you log in
module purge
module load anaconda3/5.3.1
module load cuda/10.0.130
module load gcc/6.3.0

# Activate your environment

NETID=aa7513

source activate /scratch/${NETID}/nlu_projects/env

# Clone or pull the repo under /scratch, then submit the training job

cd /scratch/${NETID}/

sbatch run_training.sbatch

Compute GLUE scores

# make sure transformers version is 2.8.0
!pip install transformers==2.8.0

%cd ./compute_glue_scores

GLUE_DIR=../data/glue
TASK_NAME=QNLI  # specify GLUE task

# download GLUE data
!python download_glue_data.py --data_dir $GLUE_DIR --tasks $TASK_NAME

# specify the model directory
# the model directory may be a checkpoint directory
# it should contain config.json, merges.txt, pytorch_model.bin,
# special_tokens_map.json, tokenizer_config.json, training_args.bin, vocab.json
# it SHOULD NOT contain optimizer.pt and scheduler.pt
MODEL_DIR=../models/roberta_token_discrimination

OUTPUT_DIR=../models/glue/$TASK_NAME

# run glue
!python run_glue.py \
    --model_type roberta \
    --model_name_or_path $MODEL_DIR \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --data_dir $GLUE_DIR/$TASK_NAME \
    --max_seq_length 128 \
    --per_gpu_eval_batch_size 64 \
    --per_gpu_train_batch_size 64 \
    --learning_rate 2e-5 \
    --num_train_epochs 3 \
    --output_dir $OUTPUT_DIR \
    --overwrite_output_dir
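
If MODEL_DIR points at a raw training checkpoint, strip the optimizer and scheduler states first. A small helper sketch (the function name and paths are illustrative):

import os
import shutil

def prepare_checkpoint_for_glue(checkpoint_dir, out_dir):
    # copy everything run_glue.py needs, skipping the training state
    os.makedirs(out_dir, exist_ok=True)
    for name in os.listdir(checkpoint_dir):
        if name in ('optimizer.pt', 'scheduler.pt'):
            continue  # the notes above require these to be absent
        shutil.copy(os.path.join(checkpoint_dir, name), out_dir)

# usage (paths illustrative):
# prepare_checkpoint_for_glue(
#     '../models/roberta_token_discrimination/checkpoint-1000',
#     '../models/glue_ready_checkpoint')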

Contributors

aa7513, bbloch18, hjw9673, subhadarship

