
Positional Skip-wise Training for Efficient Context Window Extension of LLMs to Extreme Lengths (ICLR 2024)

Home Page: https://arxiv.org/abs/2309.10400

License: MIT License

Shell 2.93% Python 97.07%


PoSE

This repository contains the code for the paper "PoSE: Efficient Context Window Extension of LLMs via Positional Skip-wise Training".

In this work, we introduce Positional Skip-wisE (PoSE) training for efficient adaptation of large language models (LLMs) to extremely long context windows. PoSE decouples the training length from the target context window size by simulating long inputs within a fixed context window, using manipulated position indices during training.

[Figure: PoSE]

Taking context window extension from 2,048 to 8,192 as an example, we partition the original context window of 2,048 tokens into two chunks and adjust the position indices of the second chunk by adding a distinct skipping bias term. These bias terms, as well as the length of each chunk, are altered for each training example, so that the model can adapt to all relative positions of the target context window through fine-tuning.
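The two-chunk scheme described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the repository's `train_preprocess_function_pose`: the function name `pose_position_ids`, its parameters, and the uniform sampling of chunk length and bias are all hypothetical choices made for this example.

```python
import random


def pose_position_ids(train_len, target_len, rng=None):
    """Return manipulated position ids for one training example (two chunks).

    Assumption: chunk length and skipping bias are sampled uniformly;
    the actual sampling in the repository may differ.
    """
    rng = rng or random.Random()
    # Length of the first chunk, resampled for every example.
    # Both chunks are kept non-empty.
    first_len = rng.randint(1, train_len - 1)
    # Skipping bias for the second chunk. The upper bound keeps the
    # largest position id below target_len.
    bias = rng.randint(0, target_len - train_len)
    first_chunk = list(range(first_len))
    second_chunk = [pos + bias for pos in range(first_len, train_len)]
    return first_chunk + second_chunk


ids = pose_position_ids(train_len=2048, target_len=8192, rng=random.Random(0))
print(len(ids), ids[0], ids[-1])
```

The input still contains only 2,048 tokens; only the indices passed to the positional encoding change, so across many examples the model sees relative distances spanning the full 8,192-token target window.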

Notably, by decoupling the fine-tuning length from the target context window, PoSE can theoretically extend the context window infinitely, constrained only by memory usage for inference. With ongoing advancements in efficient inference (e.g., vLLM, Flash Attention), we believe PoSE holds great promise for scaling the context window even further.
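To give a rough sense of the inference-memory constraint, the sketch below estimates the key/value cache size for a single sequence. It assumes LLaMA-7B shapes (32 layers, hidden size 4096), standard multi-head attention, 16-bit values, and batch size 1; the helper name `kv_cache_bytes` is hypothetical, and real memory usage also includes weights and activations.

```python
def kv_cache_bytes(seq_len, n_layers=32, hidden_size=4096, bytes_per_value=2):
    """Estimate KV-cache size for one sequence.

    Each layer stores one key tensor and one value tensor,
    each of shape (seq_len, hidden_size).
    """
    return 2 * n_layers * hidden_size * bytes_per_value * seq_len


for seq_len in (2048, 16384, 131072):
    gib = kv_cache_bytes(seq_len) / 2**30
    print(f"{seq_len:>7} tokens -> {gib:.1f} GiB")
# Under these assumptions: 1.0 GiB, 8.0 GiB, and 64.0 GiB respectively.
```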

🔥 What's New

  • [2024/01/16] Our paper has been accepted by ICLR 2024 as a Poster.
  • [2023/10/26] We made the datasets used in this paper public for reproduction purposes.
  • [2023/10/11] We released all our model checkpoints.
  • [2023/10/10] Updated our paper and code. Improved the writing and added discussion of chunk number and coverage possibility of relative positions in the Appendix. Removed unused code and implemented a minor fix in train_preprocess_function_pose to make the coverage possibility more uniform for large relative positions.
  • [2023/09/22] Included results of PoSE on Baichuan2, further consolidating the effectiveness of our method.
  • [2023/09/19] Our paper and code were released.

⚡ Checkpoints

Context-Extended Versions of LLaMA (originally supports 2k context)

| Model | Context | Interpolation | Link |
|---|---|---|---|
| LLaMA-7B-PoSE-Linear-16k | 16,384 | Linear | download link |
| LLaMA-7B-PoSE-NTK-16k | 16,384 | NTK | download link |
| LLaMA-7B-PoSE-YaRN-16k | 16,384 | YaRN | download link |
| LLaMA-7B-PoSE-Linear-96k | 98,304 | Linear | download link |
| LLaMA-7B-PoSE-YaRN-96k | 98,304 | YaRN | download link |
| LLaMA-7B-PoSE-YaRN-128k | 131,072 | YaRN | download link |

Context-Extended Versions of LLaMA2 (originally supports 4k context)

| Model | Context | Interpolation | Link |
|---|---|---|---|
| LLaMA2-7B-PoSE-Linear-16k | 16,384 | Linear | download link |
| LLaMA2-7B-PoSE-NTK-16k | 16,384 | NTK | download link |
| LLaMA2-7B-PoSE-YaRN-16k | 16,384 | YaRN | download link |

Context-Extended Versions of Baichuan2 (originally supports 4k context)

| Model | Context | Interpolation | Link |
|---|---|---|---|
| Baichuan2-7B-PoSE-Linear-16k | 16,384 | Linear | download link |
| baichuan2-7B-PoSE-NTK-16k | 16,384 | NTK | download link |
| baichuan2-7B-PoSE-YaRN-16k | 16,384 | YaRN | download link |

🔧 Reproduction

To replicate our results, follow these steps to download the code and necessary dependencies:

git clone https://github.com/dwzhu-pku/PoSE.git
cd PoSE
pip install -r requirements.txt

Additionally, as we utilize lm-eval-harness for evaluation on standard benchmarks, please install lm-eval-harness under the helper/ folder.

The datasets are available at this link for reproduction purposes.

Data, Models and Computation Resources

We have conducted experiments with Llama-7B, Llama2-7B, GPT-J-6B, and Baichuan2-7B.

All models are fine-tuned on The Pile dataset. Since this dataset is randomly shuffled, we use only the 00 split for training. We further filter out short inputs and keep 100k samples for fine-tuning, which has proven sufficient for our method.
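The filtering step can be pictured with the short sketch below. It is an illustration only: the helper name `select_long_samples` and the `min_len` threshold are hypothetical, since the exact length cutoff is not stated here, and the inputs are assumed to be already tokenized.

```python
from itertools import islice


def select_long_samples(token_lists, min_len, max_samples=100_000):
    """Keep the first `max_samples` examples with at least `min_len` tokens.

    `min_len` is a hypothetical threshold chosen by the caller.
    """
    long_enough = (tokens for tokens in token_lists if len(tokens) >= min_len)
    # islice stops reading as soon as enough samples have been collected.
    return list(islice(long_enough, max_samples))


# Toy usage: three "documents" of 5, 2, and 9 tokens.
toy = [[1] * 5, [1] * 2, [1] * 9]
kept = select_long_samples(toy, min_len=4, max_samples=2)
print([len(t) for t in kept])
```

Because the split is already shuffled, taking the first qualifying samples in order should behave like a random subset.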

In terms of computation resources, all training is conducted on 8 × 32 GB V100 GPUs, and all evaluations are completed on a single A100.

Training and Evaluation

The scripts under script/ comprehensively cover the commands for training and evaluation.

For training, the key modifications revolve around the position indices of the input text. You can refer to the train_preprocess_function_pose function to understand our proposed method. There are also minor revisions in my_modeling_xxx.py and my-configuration_xxx.py for implementing linear / NTK / YaRN interpolation and for utilizing xformers for efficient training and inference. Note that we use the revised version of YaRN in our experiments, as discussed in the issue "inv_freq seems not calculated right". For example, you can start training Llama for context extension from 2k to 128k (64x) with YaRN interpolation by running the following commands:

cd PoSE
bash script/run_train_skipos.sh 64 yarn
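The `64` in the command above appears to correspond to the extension ratio (128k / 2k). As a rough sketch of how such a ratio is used, the snippet below shows the commonly cited formulations of linear and NTK-aware RoPE interpolation. It is not taken from `my_modeling_xxx.py`; the function names are hypothetical, the base of 10000 and head dimension of 128 are assumed as typical LLaMA values, and YaRN is omitted because it is more involved.

```python
def scale_factor(original_ctx, target_ctx):
    """Ratio between the target and original context windows."""
    return target_ctx / original_ctx


def linear_interpolated_positions(position_ids, scale):
    """Linear interpolation: squeeze positions back into the original range."""
    return [pos / scale for pos in position_ids]


def ntk_scaled_base(base, scale, head_dim):
    """NTK-aware interpolation: enlarge the RoPE base instead of
    rescaling positions."""
    return base * scale ** (head_dim / (head_dim - 2))


def rope_inv_freq(base, head_dim):
    """Inverse frequencies used by rotary position embeddings."""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


scale = scale_factor(2048, 131072)  # 64.0
print(scale)
print(linear_interpolated_positions([0, 131071], scale))  # stays below 2048
print(ntk_scaled_base(10000.0, scale, head_dim=128))
print(len(rope_inv_freq(10000.0, head_dim=128)))
```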

For evaluation, we made no revisions to position indices, so the process remains the same as the common setting. You can run the following commands to evaluate passkey retrieval / ppl / standard benchmarks:

cd PoSE
bash script/run_eval_passkey.sh # for passkey retrieval
bash script/run_eval_ppl.sh # for ppl
bash script/run_lm_eval.sh # for standard benchmarks
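For readers unfamiliar with the first two evaluations, here is a minimal sketch of what they measure. It does not reproduce the scripts above: the prompt wording follows a commonly used passkey-retrieval format and is an assumption, and both helper names are hypothetical.

```python
import math
import random

FILLER = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."


def make_passkey_prompt(n_filler, rng):
    """Hide a random 5-digit passkey at a random depth inside filler text."""
    passkey = str(rng.randint(10000, 99999))
    lines = [FILLER] * n_filler
    insert_at = rng.randint(0, n_filler)
    lines.insert(insert_at, f"The pass key is {passkey}. Remember it. {passkey} is the pass key.")
    prompt = "\n".join(lines) + "\nWhat is the pass key? The pass key is"
    return prompt, passkey


def perplexity(token_log_probs):
    """Perplexity = exp of the mean negative log-likelihood per token."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


prompt, answer = make_passkey_prompt(n_filler=100, rng=random.Random(0))
print(len(prompt.split()), "words; expected answer:", answer)
# A model assigning probability 0.5 to every token has perplexity 2.
print(perplexity([math.log(0.5)] * 4))
```

Passkey accuracy is then the fraction of prompts for which the model's continuation contains the hidden number; increasing `n_filler` lengthens the context being probed.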

📈 Experiment Results

Empirically, we demonstrate that PoSE achieves significant memory and time efficiency:

[Figure: efficiency]

It is compatible across various RoPE-based models and interpolation strategies:

[Figure: widely_compatible]

It can extend the context window to 128k when combined with YaRN interpolation:

[Figure: extremely_long]

And it exhibits only minimal performance degradation on standard benchmarks:

[Figure: standard]

🌟 Citation

If you find this repo helpful, please cite our paper as follows:

@article{zhu2023pose,
  title={Pose: Efficient context window extension of llms via positional skip-wise training},
  author={Zhu, Dawei and Yang, Nan and Wang, Liang and Song, Yifan and Wu, Wenhao and Wei, Furu and Li, Sujian},
  journal={arXiv preprint arXiv:2309.10400},
  year={2023}
}
