
bert_sequence_tagger

Description

BERT sequence tagger that accepts a list of tokens as input (produced not by a BPE tokenizer but by any "general-purpose" tokenizer such as NLTK or Stanford) and produces tagged results in IOB format.

Basically, you can do:

from bert_sequence_tagger import SequenceTaggerBert, ModelTrainerBert

seq_tagger = SequenceTaggerBert(...)  # initialize the model for training or load a trained one
# ... train model with ModelTrainerBert

seq_tagger.predict([['We', 'are', 'living', 'in', 'New', 'York', 'city', '.'],
                    ['Satya', 'Narayana', 'Nadella', 'is', 'an', 'engineer', 'and', 'business', 'executive', '.']])

Result:

([['O', 'O', 'O', 'O', 'I-LOC', 'I-LOC', 'O', 'O'],
  ['I-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O']],
 [10.09477, 10.004749])
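
The first element of the returned tuple contains the predicted tag sequences; the second contains a numeric score for each input sentence. Pairing input tokens with their predicted tags is then straightforward:

tokens = ['We', 'are', 'living', 'in', 'New', 'York', 'city', '.']
tags, scores = seq_tagger.predict([tokens])
for token, tag in zip(tokens, tags[0]):
    print(token, tag)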

Training a BERT model involves many caveats, including but not limited to:

  • Proper masking of the input.
  • Proper padding of the input.
  • Loss masking (masking the loss of padded tokens and of BPE suffixes); see the sketch after this list.
  • Adding special tokens such as [CLS] and [SEP] to the beginning and end of a sequence.
  • Annealing the learning rate, as well as keeping track of the best model.
  • Proper calculation of the validation/training loss (taking masked tokens and masked loss elements into account).
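
For illustration, here is a minimal sketch (not the library's internal code) of the BPE alignment and loss masking mentioned above: each word is split into word pieces, and only the first piece keeps a real label and contributes to the loss.

from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

words = ['Satya', 'Narayana', 'Nadella']
labels = ['I-PER', 'I-PER', 'I-PER']

bpe_tokens, bpe_labels, loss_mask = [], [], []
for word, label in zip(words, labels):
    pieces = tokenizer.tokenize(word)  # a word may split into several BPE pieces
    bpe_tokens.extend(pieces)
    bpe_labels.extend([label] + ['X'] * (len(pieces) - 1))  # 'X' marks BPE suffixes
    loss_mask.extend([1] + [0] * (len(pieces) - 1))  # suffixes are excluded from the loss

print(bpe_tokens, bpe_labels, loss_mask)

The real input additionally receives [CLS]/[SEP] tokens, padding up to max_len, and an attention mask, as listed above.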

Pytorch_transformers provides a good PyTorch implementation of BertForTokenClassification; however, it lacks code for proper training of sequence tagging models. Noticeable effort is required to convert tokenized text into an input suitable for BERT with which you can achieve SOTA results.

This library does this work for you: it takes a tokenized input and performs BPE tokenization, padding, and all the other preparation needed to feed the input to BERT. It also provides a trainer that can achieve the best performance for BERT models. See the example below for the CoNLL-2003 dataset. A more detailed example in a Jupyter notebook is here.

Example

from bert_sequence_tagger import SequenceTaggerBert, BertForTokenClassificationCustom
from pytorch_transformers import BertTokenizer

from bert_sequence_tagger.bert_utils import get_model_parameters, prepare_flair_corpus
from bert_sequence_tagger.bert_utils import make_bert_tag_dict_from_flair_corpus 
from bert_sequence_tagger.model_trainer_bert import ModelTrainerBert
from bert_sequence_tagger.metrics import f1_entity_level, f1_token_level

from pytorch_transformers import AdamW, WarmupLinearSchedule


import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger('sequence_tagger_bert')


# Loading corpus ############################

from flair.datasets import ColumnCorpus

data_folder = './conll2003'
corpus = ColumnCorpus(data_folder, 
                      {0 : 'text', 3 : 'ner'},
                      train_file='eng.train',
                      test_file='eng.testb',
                      dev_file='eng.testa')
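
ColumnCorpus expects one token per line with whitespace-separated columns and an empty line between sentences; the column mapping above takes the token from column 0 and the NER tag from column 3. A fragment of eng.train looks like this:

U.N. NNP I-NP I-ORG
official NN I-NP O
Ekeus NNP I-NP I-PER
heads VBZ I-VP O
for IN I-PP O
Baghdad NNP I-NP I-LOC
. . O O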


# Creating model ############################

batch_size = 16
n_epochs = 4
model_type = 'bert-base-cased'
bpe_tokenizer = BertTokenizer.from_pretrained(model_type, do_lower_case=False)

idx2tag, tag2idx = make_bert_tag_dict_from_flair_corpus(corpus)

model = BertForTokenClassificationCustom.from_pretrained(model_type, 
                                                         num_labels=len(tag2idx)).cuda()

seq_tagger = SequenceTaggerBert(bert_model=model, bpe_tokenizer=bpe_tokenizer, 
                                idx2tag=idx2tag, tag2idx=tag2idx, max_len=128,
                                batch_size=batch_size)


# Training ############################

train_dataset = prepare_flair_corpus(corpus.train)
val_dataset = prepare_flair_corpus(corpus.dev)

optimizer = AdamW(get_model_parameters(model), lr=5e-5, betas=(0.9, 0.999), 
                  eps=1e-6, weight_decay=0.01, correct_bias=True)

n_iterations_per_epoch = (len(corpus.train) + batch_size - 1) // batch_size  # ceil division
n_steps = n_iterations_per_epoch * n_epochs
lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=int(0.1 * n_steps),
                                    t_total=n_steps)  # warm up over the first 10% of steps

trainer = ModelTrainerBert(model=seq_tagger, 
                           optimizer=optimizer, 
                           lr_scheduler=lr_scheduler,
                           train_dataset=train_dataset, 
                           val_dataset=val_dataset,
                           validation_metrics=[f1_entity_level],
                           batch_size=batch_size)

trainer.train(epochs=n_epochs)

# Testing ############################

test_dataset = prepare_flair_corpus(corpus.test)
_, __, test_metrics = seq_tagger.predict(test_dataset, evaluate=True, 
                                         metrics=[f1_entity_level, f1_token_level])
print(f'Entity-level f1: {test_metrics[1]}')
print(f'Token-level f1: {test_metrics[2]}')

# Predicting ############################
seq_tagger.predict([['We', 'are', 'living', 'in', 'New', 'York', 'city', '.']])
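
If you need entity spans rather than per-token tags, a small post-processing helper can group consecutive tags into spans. The following is a sketch assuming IOB-style output as shown above; the seqeval package from the requirements provides similar functionality:

def tags_to_spans(tags):
    # Group consecutive tags of the same entity type into (start, end, type) spans.
    spans, start, ent_type = [], None, None
    for i, tag in enumerate(tags + ['O']):  # the 'O' sentinel flushes the last span
        label = tag.split('-', 1)[-1] if tag != 'O' else None
        if label != ent_type or tag.startswith('B-'):
            if ent_type is not None:
                spans.append((start, i, ent_type))
            start, ent_type = i, label
    return spans

print(tags_to_spans(['O', 'O', 'O', 'O', 'I-LOC', 'I-LOC', 'O', 'O']))
# [(4, 6, 'LOC')]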

Installation

pip install git+https://github.com/IINemo/bert_sequence_tagger.git

Requirements

  • torch
  • tensorflow
  • pytorch_transformers
  • flair (optional, for reading CoNLL-formatted files)
  • seqeval (optional, for evaluation)
  • sklearn (optional, for evaluation)

Cite

@inproceedings{shelmanov2019bibm,
    title={Active Learning with Deep Pre-trained Models for Sequence Tagging of Clinical and Biomedical Texts},
    author={Artem Shelmanov and Vadim Liventsev and Danil Kireev and Nikita Khromov and Alexander Panchenko and Irina Fedulova and Dmitry V. Dylov},
    booktitle={Proceedings of International Conference on Bioinformatics & Biomedicine (BIBM)},
    year={2019}
}

TODO

  • Remove the dependency on tensorflow
  • Make ModelTrainer more general
