composable-sft

A library for parameter-efficient and composable transfer learning for NLP with sparse fine-tunings. This repository is a fork of cambridgeltl/composable-sft.

This is a library for training and applying sparse fine-tunings with torch and transformers. Please refer to our paper Composable Sparse Fine-Tuning for Cross-Lingual Transfer for background.

News

2023/02/08

Released training and evaluation scripts for the NusaX sentiment classification task for Indonesian languages, along with pre-trained language and task SFTs. See MODELS for details of the test results, which generally exceed the baselines in the NusaX paper (though note that the training data and regime differ in our setup). Language SFTs for Buginese, Ngaju and Toba Batak are not yet available due to the lack of suitable Wikipedia corpora.

2023/02/01

LotteryTicketSparseFineTuner can now be applied to Trainer subclasses such as QuestionAnsweringTrainer. See "Training SFTs" for further info.

MultiSourceDataset now supports upsampling - a different upsampling factor can be defined for each source language. E.g.

dataset = MultiSourceDataset(
    {
        'en': english_dataset,
        'swa': swahili_dataset,
    },
    upsampling={'swa': 2}
)

would create a MultiSourceDataset in which each example from swahili_dataset appears twice.

2022/01/13

composable-sft now supports multi-GPU training with DistributedDataParallel. This can be invoked in the same way as for a vanilla transformers Trainer.
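
For example, a distributed run can be launched with torchrun in the usual transformers way (train.py here is a stand-in for any training script, e.g. those under examples/):

torchrun --nproc_per_node=4 train.py <script arguments>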

2021/11/30

New utilities have been added for multi-source training (see Ansell et al. (2021)) and multi-source task SFTs have been released for some tasks. SFTs and example scripts are now also available for question answering, and we have released language SFTs for many new languages (see MODELS). We recommend using multi-source task SFTs where available, as they are substantially better than the single-source SFTs for most languages.

Installation

First, install Python >= 3.9 and PyTorch >= 1.9 (earlier versions may work but haven't been tested), e.g. using conda:

conda create -n sft python=3.10
conda activate sft
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia

Then download and install composable-sft:

git clone https://github.com/cambridgeltl/composable-sft.git
cd composable-sft
pip install -e .

If you wish to use composable-sft alongside the adapter-transformers library, install adapter-transformers first; otherwise, the installer will automatically install the standard (non-adapter) transformers package.
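
For example, to use adapter support, run the following before the pip install -e . step above:

pip install adapter-transformers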

Using pre-trained SFTs

Pre-trained SFTs can be downloaded directly and applied to models as follows:

from transformers import AutoConfig, AutoModelForTokenClassification
from sft import SFT

config = AutoConfig.from_pretrained(
    'bert-base-multilingual-cased',
    num_labels=17,
)

model = AutoModelForTokenClassification.from_pretrained(
    'bert-base-multilingual-cased',
    config=config,
)

language_sft = SFT('cambridgeltl/mbert-lang-sft-bxr-small') # SFT for Buryat
task_sft = SFT('cambridgeltl/mbert-task-sft-pos') # SFT for POS tagging

# Apply SFTs to pre-trained mBERT TokenClassification model
language_sft.apply(model)
task_sft.apply(model)
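
Applying an SFT modifies the model in place, so the adapted model can then be used for inference as usual. Here is a minimal sketch (the input sentence is illustrative, and only predicted label IDs are printed, since the label names depend on the task SFT's training data):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

# Run the SFT-adapted model on an example sentence
inputs = tokenizer('Example sentence', return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits

# Predicted label ID for each token position
print(logits.argmax(dim=-1))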

For a full list of available pre-trained SFTs, see MODELS.

Training SFTs

LotteryTicketSparseFineTuner is a function that wraps the transformers Trainer class (or any of its subclasses) to perform Lottery Ticket Sparse Fine-Tuning (LT-SFT), e.g.

trainer_cls = LotteryTicketSparseFineTuner(QuestionAnsweringTrainer)

The constructor of the resulting trainer_cls class (which is itself a subclass of Trainer/QuestionAnsweringTrainer) takes the following arguments in addition to those of Trainer:

  • sft_args: an SftArguments object which holds hyperparameters relating to SFT training (c.f. transformers TrainingArguments).
  • maskable_params: a list of the names of the model parameters that are eligible for sparse fine-tuning. Parameters of the classification head should be excluded from this list because these should typically be fully fine-tuned. E.g.
maskable_params = [
    n for n, p in model.named_parameters()
    if n.startswith(model.base_model_prefix) and p.requires_grad
]

The following command-line params processed by SftArguments may be useful:

  • ft_params_num/ft_params_proportion - controls the number/proportion of the maskable parameters that will be fine-tuned.
  • full_ft_max_steps_per_iteration/full_ft_max_epochs_per_iteration - controls the maximum number of steps/epochs in the first phase of LT-SFT. Both limits can be set simultaneously.
  • sparse_ft_max_steps_per_iteration/sparse_ft_max_epochs_per_iteration - controls the maximum number of steps/epochs in the second phase of LT-SFT. Both limits can be set simultaneously.
  • full_ft_min_steps_per_iteration/sparse_ft_min_steps_per_iteration - controls the minimum number of steps in the first/second phase of LT-SFT. This takes effect when a maximum number of epochs is set that amounts to fewer steps than the minimum.
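
Putting this together, here is a minimal training sketch. It assumes model and train_dataset have already been constructed, that SftArguments is importable from sft, and that it can be instantiated directly with keyword arguments in the same way as transformers TrainingArguments; the hyperparameter values are purely illustrative.

from transformers import Trainer, TrainingArguments
from sft import LotteryTicketSparseFineTuner, SftArguments

training_args = TrainingArguments(output_dir='output')

# Illustrative values; see the parameter descriptions above
sft_args = SftArguments(
    ft_params_proportion=0.01,
    full_ft_max_epochs_per_iteration=3,
    sparse_ft_max_epochs_per_iteration=3,
)

# Exclude the classification head so that it is fully fine-tuned
maskable_params = [
    n for n, p in model.named_parameters()
    if n.startswith(model.base_model_prefix) and p.requires_grad
]

trainer_cls = LotteryTicketSparseFineTuner(Trainer)
trainer = trainer_cls(
    model=model,
    args=training_args,
    sft_args=sft_args,
    maskable_params=maskable_params,
    train_dataset=train_dataset,
)
trainer.train()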

Example Scripts

Examples of SFT training and evaluation are provided in examples/.

Multi-source Training

In multi-source training, a task SFT is trained on data from several languages. We provide support for this through MultiSourceDataset and MultiSourcePlugin.

MultiSourceDataset combines data from several sources (e.g. different languages) into a single Dataset. Its constructor takes a dict mapping source names to Datasets, e.g.:

from sft import MultiSourceDataset

train_dataset = MultiSourceDataset({
    'en': english_train_dataset,
    'ja': japanese_train_dataset,
})

eval_dataset = MultiSourceDataset({
    'en': english_eval_dataset,
    'ja': japanese_eval_dataset,
})

where english_train_dataset, japanese_train_dataset, etc. are instances of torch.utils.data.Dataset.

MultiSourcePlugin can be applied to a transformers Trainer (or subclass thereof) to allow it to be used in conjunction with MultiSourceDatasets:

from transformers import Trainer
from sft import SFT, LotteryTicketSparseFineTuner, MultiSourcePlugin

english_sft = SFT('cambridgeltl/mbert-lang-sft-en-small')
japanese_sft = SFT('cambridgeltl/mbert-lang-sft-ja-small')
source_sfts = {
    'en': english_sft,
    'ja': japanese_sft,
}

trainer_cls = Trainer
trainer_cls = LotteryTicketSparseFineTuner(trainer_cls)
trainer_cls = MultiSourcePlugin(trainer_cls)
trainer = trainer_cls(
    ..., # standard LotteryTicketSparseFineTuner parameters
    train_dataset=train_dataset, # instance of MultiSourceDataset
    eval_dataset=eval_dataset, # instance of MultiSourceDataset
    source_sfts=source_sfts,
)

Note the use of the optional argument source_sfts, a dict mapping source names to SFTs. If provided, the trainer will apply the SFT corresponding to the source language of each batch (note that each batch consists of examples from a single source).

See the examples for further demonstration of multi-source training.

Citation

If you use this software, please cite the following paper (BibTeX key ansell-etal-2022-composable in the ACL Anthology):

Alan Ansell, Edoardo Ponti, Anna Korhonen, and Ivan Vulić. 2022. Composable Sparse Fine-Tuning for Cross-Lingual Transfer. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1778–1796, Dublin, Ireland. Association for Computational Linguistics.

Contributors

alanansell, parovicm
