
[ICLR 2024] CLEX: Continuous Length Extrapolation for Large Language Models

License: MIT License

Python 99.22% Shell 0.78%
large-language-models rotary-position-embedding long-context-modeling

clex's Introduction

CLEX: Continuous Length Extrapolation for Large Language Models

This repo provides the official implementation of our paper "CLEX: Continuous Length Extrapolation for Large Language Models"

News

  • [2024.1.19] 🔥 Released CLEX-Mixtral-8x7B-32K, CLEX-LLaMA-2-7B-64K, and CLEX-Phi-2-32K (and refactored the code to support different models), all of which support more than 100K context length!
  • [2024.1.16] 🌟 CLEX has been accepted to ICLR 2024!
  • [2023.10.25] 🚀 Released the code of CLEX and the long-context base & chat models trained with CLEX.

Features and Highlights of CLEX

CLEX_diagram

  • Simple and Clear: MINIMAL code and architecture changes. Only one up-and-down projection layer is introduced; NO recurrent memory caching or sparse attention is required.
  • Train Short, Test Long: NO performance drop on sequences 4x–8x longer than the training ones (see here).
  • Continuous Length Extrapolation: Explicitly models the continuous dynamics of the context window size during length extrapolation.

If you have any questions, feel free to contact us. (Emails: [email protected], [email protected])

Model Zoo

| Model Name | Model Type | Starting Point | Train Data | Train Length | Max Test Length | HF Repo |
|---|---|---|---|---|---|---|
| CLEX-LLaMA-2-7B-16K | base | LLaMA-2-7B | RedPajama-Book | 16K | 64K | link |
| CLEX-LLaMA-2-7B-Chat-16K | chat | CLEX-7B-16K | UltraChat | 16K | 64K | link |
| CLEX-LLaMA-2-7B-64K | base | LLaMA-2-7B | RedPajama-Book | 64K | 256K | link |
| CLEX-Phi-2-32K | base | Phi-2-2.7B | LongCorpus-2.5B | 32K | 128K | link |
| CLEX-Mixtral-8x7B-32K | base | Mixtral-8x7B-v0.1 | LongCorpus-2.5B | 32K | >128K | link |
| CLEX-Mixtral-8x7B-Chat-32K | chat | CLEX-Mixtral-8x7B-32K | UltraChat 200k | 32K | >128K | link |

Supported LLMs

  • LLaMA-2
  • Phi-2
  • Mixtral-8x7B
  • Mistral
  • Falcon
  • GPT-NeoX
  • QWen

Usage

Environment Setup

conda create -yn clex python=3.9
conda activate clex

git clone https://github.com/DAMO-NLP-SG/CLEX.git
cd CLEX
pip install -r requirements.txt
# install flash-attn separately
pip install flash-attn --no-build-isolation

Code Snippet for Minimal Usage

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/CLEX-7B-16K", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
  "DAMO-NLP-SG/CLEX-7B-16K",
  torch_dtype=torch.bfloat16,
  trust_remote_code=True,
  use_flash_attention_2=True
)
inputs = tokenizer("What is CLEX?", return_tensors="pt")
sample = model.generate(**inputs, max_length=128)
print(tokenizer.decode(sample[0]))

Inference with Command Line Interface

We replicate the command-line interface of FastChat here. You can use the command below to launch streaming chat with CLEX. The CLEX-7B-Chat-4K model supports input sequence lengths of up to 16K.

python3 serve/cli.py --model-path DAMO-NLP-SG/CLEX-7B-Chat-4K --num-gpu 1

You can also try our web GUI demo here.

LongCorpus-2.5B

We collect a 2.5B-token training dataset from various domains for long-context continual pre-training. The composition of this dataset is as follows (partially inspired by Long-Data-Collection):

| Domain | Proportion | Source |
|---|---|---|
| Book | 40% | RedPajama-Book |
| Arxiv | 20% | RedPajama-Arxiv |
| General | 20% | RedPajama |
| Code | 10% | LCC-Python |
| QA | 5% | Natural Questions |
| Summarization | 5% | BookSum |

We have also curated a test dataset comprising 250 million tokens, mirroring the same composition. The selection criteria ensured that the average n-gram similarity (for n=2, 3, 4) with the training set is below 10%. This threshold effectively excludes all QA and Summarization data, resulting in a test corpus where the distribution of tokens across Book, Arxiv, General, and Code categories follows a ratio of 4:2:2:1, respectively.
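For illustration, the n-gram overlap check described above might look like the following sketch. The helper functions and the way the 10% threshold is applied are our assumptions, not the repo's actual filtering script:

# Sketch: keep a candidate test document only if its average n-gram overlap
# (n = 2, 3, 4) with the training set is below 10%. Illustrative only.
from collections import Counter

def ngrams(tokens, n):
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def avg_ngram_similarity(doc_tokens, train_ngram_sets, ns=(2, 3, 4)):
    # train_ngram_sets: dict mapping n -> set of n-gram tuples from the training corpus.
    # Returns the fraction of the document's n-grams that also appear in the
    # training set, averaged over the chosen values of n.
    sims = []
    for n in ns:
        doc_ngrams = set(ngrams(doc_tokens, n))
        if not doc_ngrams:
            continue
        sims.append(len(doc_ngrams & train_ngram_sets[n]) / len(doc_ngrams))
    return sum(sims) / len(sims) if sims else 0.0

def keep_for_test(doc_tokens, train_ngram_sets, threshold=0.10):
    return avg_ngram_similarity(doc_tokens, train_ngram_sets) < threshold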

Training

To train the long-context LLM with CLEX, run the script scripts/train_lm.sh as follows:

./scripts/train_lm.sh

For training the chat model, run the script scripts/train_chat.sh instead.

Note that we use on-the-fly tokenization, which supports any desired training length without pre-tokenizing. If you use a learning-rate scheduler (e.g., cosine), you may need to specify max_steps in the training arguments; you can estimate it from the training data size, as sketched below.
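A back-of-the-envelope estimate of max_steps from the token budget, training length, and global batch size. All numbers below are placeholders, not the settings used in the paper:

# Rough estimate of max_steps when tokenizing on the fly with a cosine scheduler.
total_tokens = 2_000_000_000        # e.g., a ~2B-token corpus (placeholder)
seq_len = 16_384                    # training length, e.g. 16K (placeholder)
per_device_batch_size = 1
grad_accum_steps = 8
num_gpus = 8

tokens_per_step = seq_len * per_device_batch_size * grad_accum_steps * num_gpus
max_steps = total_tokens // tokens_per_step
print(f"max_steps ≈ {max_steps}")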

Customization

We now support LLaMA-2, Phi-2, and Mixtral-8x7B. If you want to equip your own RoPE-based LLM with CLEX, please follow the three steps below (a toy sketch follows the list):

  1. Init the CLEX layer and acquire the packed cos and sin embeddings of the CLEX-scaled RoPE.
  2. Pass the cos and sin embeddings to the attention layer.
  3. Move the update of past_key_value before applying RoPE. This ensures that all keys are rotated by the same cos and sin embeddings.
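A minimal, self-contained sketch of these three steps on a toy single-head attention module. ToyRotaryEmbedding is only a stand-in for the CLEX layer (its interface here is assumed), and causal masking is omitted for brevity; adapt the structure to your model's actual attention code:

# Toy illustration of the three customization steps; not the repo's implementation.
import math
import torch
import torch.nn as nn

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

class ToyRotaryEmbedding(nn.Module):
    # Stand-in for the CLEX layer: returns packed cos/sin for `seq_len` positions.
    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        t = torch.arange(seq_len, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

class ToySelfAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.q_proj, self.k_proj, self.v_proj = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)
        # Step 1: init the rotary layer (replace with the CLEX layer).
        self.rotary_emb = ToyRotaryEmbedding(dim)

    def forward(self, x, past_key_value=None):
        # x: [batch, new_len, dim]
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Step 3: update the KV cache *before* applying RoPE, so cached and new
        # keys are all rotated by the same cos/sin computed below.
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=1)
            v = torch.cat([past_key_value[1], v], dim=1)
        past_key_value = (k, v)

        # Steps 1-2: acquire the packed cos/sin for the full key length and
        # pass them into the attention computation.
        cos, sin = self.rotary_emb(k.shape[1])              # [total_len, dim]
        q = q * cos[-q.shape[1]:] + rotate_half(q) * sin[-q.shape[1]:]
        k = k * cos + rotate_half(k) * sin

        attn = torch.softmax(q @ k.transpose(1, 2) / math.sqrt(q.shape[-1]), dim=-1)
        return attn @ v, past_key_value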

Evaluation

Language Modelling

Here are the evaluation PPLs of the base models trained with CLEX. Training and evaluation are performed on a 2B-token subset of the RedPajama-Book corpus, with the training and test sets split 99:1.

| Models | Train Length | Eval. (4K) | Eval. (8K) | Eval. (16K) | Eval. (32K) | Eval. (64K) |
|---|---|---|---|---|---|---|
| LLaMA-2-7B | 4K | 6.04 | 20.54 | >100 | >1000 | >1000 |
| CodeLLaMA-7B | 16K | 7.6 | 7.4 | 7.33 | 15.12 | 52.02 |
| Naive FT | 16K | 5.98 | 5.93 | 5.91 | 18.31 | >100 |
| PI | 16K | 5.9 | 5.71 | 5.72 | 6.05 | 8.75 |
| YaRN (s=16) | 16K | 6.5 | 5.71 | 5.73 | 5.99 | 8.51 |
| YaRN (s=32) | 16K | 6.61 | 5.94 | 5.96 | 6.08 | 6.22 |
| CL-Scaling | 16K | 24.99 | 5.86 | 5.87 | 10.56 | 41.09 |
| ALiBi | 4K | 6.34 | 6.39 | 6.41 | 6.5 | 6.51 |
| RandomPos | 4K | 5.88 | >100 | >1000 | >1000 | >1000 |
| CLEX-LLaMA-2-7B-4K | 4K | 5.86 | 5.7 | 5.87 | 14.53 | 30.51 |
| CLEX-LLaMA-2-7B-16K | 16K | 5.88 | 5.68 | 5.52 | 5.55 | 5.64 |
| CLEX-LLaMA-2-13B-4K | 4K | 5.43 | 5.31 | 5.34 | 6.40 | 12.15 |

| Models | Train Length | Eval. (32K) | Eval. (64K) | Eval. (128K) | Eval. (256K) |
|---|---|---|---|---|---|
| CLEX-LLaMA-2-7B | 64K | 5.99 | 5.89 | 6.04 | 5.98 |

CLEX-Phi-2-2.7B and CLEX-Mixtral-8x7B are trained on LongCorpus-2.5B; their evaluation results on its test set are listed below.

| Models | Train Length | Eval. (32K) | Eval. (64K) | Eval. (128K) | Eval. (256K) |
|---|---|---|---|---|---|
| Phi-2-2.7B | 2K | >100 | >100 | >100 | >100 |
| CLEX-Phi-2-2.7B | 32K | 5.11 | 5.17 | 6.55 | - |
| Mixtral-8x7B | 32K | 2.78 | 3.44 | 5.88 | 14.20 |
| CLEX-Mixtral-8x7B | 32K | 2.56 | 2.53 | 2.57 | 3.78 |
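For reference, perplexity at a given evaluation length can be computed roughly as follows. This is a generic chunked-evaluation sketch using standard Hugging Face APIs, not the exact evaluation script used for the tables above:

# Rough sketch: perplexity of a causal LM over fixed-length chunks of a long text.
import math
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "DAMO-NLP-SG/CLEX-7B-16K"   # any CLEX checkpoint
eval_len = 16384                          # evaluation context length, e.g. 16K

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval()

def chunk_perplexity(text):
    ids = tokenizer(text, return_tensors="pt").input_ids[0]
    nlls, n_tokens = [], 0
    for start in range(0, ids.numel() - 1, eval_len):
        chunk = ids[start:start + eval_len].unsqueeze(0)
        with torch.no_grad():
            # labels == input_ids gives the mean next-token NLL over the chunk
            loss = model(chunk, labels=chunk).loss
        nlls.append(loss * (chunk.numel() - 1))
        n_tokens += chunk.numel() - 1
    return math.exp(torch.stack(nlls).sum().item() / n_tokens)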

LongBench

We evaluate the chat models trained with CLEX on LongBench, where the average length of most tasks ranges from 5K to 16K. The baseline results are retrieved from the LongBench leaderboard, except for those marked with †, which we evaluated ourselves. ** denotes methods that need to truncate the input sequence to the training length.

| Model | Train Length | Avg. | Single-Document QA | Multi-Document QA | Summarization | Few-shot Learning | Synthetic Task | Code Completion |
|---|---|---|---|---|---|---|---|---|
| GPT-3.5-Turbo-16K | - | 44.66 | 45.1 | 36.23 | 23.9 | 57.58 | 51 | 54.15 |
| CodeLLaMA-7B | 16K | 33.42 | 32.19 | 21.49 | 20.06 | 57.73 | 8.92 | 60.11 |
| Vicuna-v1.5-7B | 16K | 30.54 | 31.75 | 18.8 | 23.25 | 56.83 | 5.33 | 47.25 |
| LongChat-v1.5-7B | 32K | 31.59 | 28.78 | 20.33 | 22.45 | 50.8 | 13.03 | 54.15 |
| XGen-7B** | 8K | 24.96 | 22.15 | 18.02 | 19.05 | 47.23 | 4.7 | 38.6 |
| InternLM-7B** | 8K | 22.64 | 21.45 | 17.9 | 15.2 | 41.55 | 3.3 | 36.45 |
| Llama2-7B-chat** | 4K | 26.76 | 21.65 | 18.2 | 18.53 | 49.95 | 4.13 | 48.1 |
| Baichuan-13B (ALiBi) | 4K | 13.49 | 18.36 | 6.79 | 9.93 | 11.72 | 1.85 | 32.28 |
| ALiBi-7B-4K | 4K | 9.93 | 7.23 | 5.98 | 7.4 | 5.69 | 0.67 | 32.61 |
| CLEX-7B-Chat-4K | 4K | 32.72 | 29.38 | 20.08 | 23.25 | 56.02 | 9.67 | 57.94 |

InfiniteBench

We also evaluate CLEX-Mixtral-8x7B-Chat-32K on InfiniteBench, a 128K-length benchmark covering various tasks. We compare our CLEX-Mixtral-8x7B-Chat-32K with GPT-4, Claude, Kimi-Chat, YaRN-Mistral-7B, and vanilla Mixtral-8x7B.

| Task Name | GPT-4 | YaRN-Mistral-7B | Kimi-Chat | Claude 2 | CLEX-Mixtral-8x7B-Chat-32K | Mixtral-8x7B-Instruct-v0.1 |
|---|---|---|---|---|---|---|
| Retrieve.PassKey | 100% | 92.71% | 98.14% | 97.80% | 99.72% | 96.78% |
| Retrieve.Number | 100% | 56.61% | 95.42% | 98.14% | 76.10% | 76.61% |
| Retrieve.KV | 89.00% | <5% | 53.60% | 65.40% | <5% | <5% |
| En.Sum | 14.73% | 9.09% | 17.93% | 14.45% | 15.48% | 14.3% |
| En.QA | 22.22% | 9.55% | 16.52% | 11.97% | 15.52% | 16.81% |
| En.MC | 67.25% | 27.95% | 72.49% | 62.88% | 58.96% | 56.77% |
| En.Dia | 8.50% | 7.50% | 11.50% | 46.50% | 9% | <5% |
| Code.Debug | 39.59% | <5% | 18.02% | <5% | 21.32% | <5% |
| Code.Run | 23.25% | <5% | <5% | <5% | <5% | <5% |
| Math.Calc | <5% | <5% | <5% | <5% | <5% | <5% |
| Math.Find | 60.00% | 17.14% | 12.57% | 32.29% | 28% | 26.57% |

Key points:

  • We found that Mixtral-8x7B-Instruct-v0.1 has some extrapolation ability when rope_theta is set to 1e6, following CodeLLaMA.
  • Our CLEX-Mixtral-8x7B-Chat-32K is also trained at 32K but performs better than vanilla Mixtral-8x7B-Instruct-v0.1 on most tasks.
  • Note that we only apply a light ("toy") SFT on UltraChat 200K for one epoch, so many of our model's failure cases may stem from its weak instruction-following ability (verbose or incomplete answers). There is likely considerable room for improvement.

Acknowledgement

We would like to express our gratitude to the following open-source efforts that CLEX benefits from:

  • LLaMA-2: Open Foundation and Fine-Tuned Chat Models
  • FastChat: An Open Platform for Training, Serving, and Evaluating Large Language Models.
  • RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset
  • Pile: An 800GB Dataset of Diverse Text for Language Modeling
  • PG-19: Language Modeling Benchmark
  • UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data, and Models
  • InfiniteBench: 100k+ Long-Context Benchmark for Large Language Models

Citation

If you find our project useful, please star our repo and cite our paper as follows:

@article{damonlpsg2023clex,
  author = {Chen, Guanzheng and Li, Xin and Meng, Zaiqiao and Liang, Shangsong and Bing, Lidong},
  title = {CLEX: Continuous Length Extrapolation for Large Language Models},
  year = 2023,
  journal = {arXiv preprint arXiv:2310.16450},
  url = {https://arxiv.org/abs/2310.16450}
}

clex's People

Contributors

guanzhchen, lixin4ever

clex's Issues

Potential bug of logn_scale

Hi, I just found a potential bug in the logn implementation in this repo. As shown in the linked line, the scale factor is math.log(k_len) / math.log(train_len) for every query. However, according to Su's blog, it should be torch.arange(k_len).log() / math.log(train_len). A reference implementation can also be found in ReRoPE.
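For clarity, the two variants being contrasted look roughly like this (an illustrative sketch; the clamping and the 1-based positions are our assumptions):

# Sketch of the two log-n scaling variants discussed above (illustrative only).
import math
import torch

train_len, k_len = 4096, 16384

# Variant reported above for this repo: a single scalar based on the current
# key length, applied to every query position.
scalar_scale = math.log(k_len) / math.log(train_len)

# Variant from Su's blog / ReRoPE: a per-position scale log(position) / log(train_len).
# Positions start at 1 to avoid log(0); the scale is typically clamped to >= 1 so
# positions within the training length stay unscaled.
positions = torch.arange(1, k_len + 1).float()
per_position_scale = (positions.log() / math.log(train_len)).clamp(min=1.0)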

RedPajama-Book 2B subset?

Hello! Could you tell me where I can get access to the "2B-token subset of RedPajama-Book" dataset if I want to use it for training? I didn't find a corresponding description or link on the official RedPajama website. Thanks!

Questions about the code

Hi, I have several questions about the implementation.

  1. For scaled_inv_freq during validation, as I understand it, it should be scale_inv_freq = self.freq_cached[int(t_val)]; it doesn't need to subtract 1.
  2. In L97, if seq_len < self.max_position_embeddings, scale_factor would be zero, so L104 would encounter a divide-by-zero error. It seems that // should be replaced with / in L97 and L105.
  3. In ODELinear:
    • In L31, why assign alpha = 2 * t - 1 rather than t?
    • In L35, the calculation of delta_ntk_freq is not found in the paper; it is $$-\frac{2i}{d-2} \cdot 10000^{\frac{i}{d}} \cdot \alpha^{\frac{i}{d-2}+1}$$
    • In L40-41, why does x add torch.log(time), and why is time_embed = delta_time / time? I am a bit confused when comparing these with Eq. (14) of the paper.

Parameters for training

Could you please tell me what the parameters for training each model in train_lm.sh are? Thank you!

Unable to load model

Hi,

I am getting the following error when trying to load the model using AutoModelForCausalLM:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/conda/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 526, in from_pretrained
config, kwargs = AutoConfig.from_pretrained(
File "/opt/conda/lib/python3.10/site-packages/transformers/models/auto/configuration_auto.py", line 1099, in from_pretrained
return config_class.from_dict(config_dict, **unused_kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/configuration_utils.py", line 774, in from_dict
config = cls(**config_dict)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/configuration_llama.py", line 160, in __init__
self._rope_scaling_validation()
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/configuration_llama.py", line 180, in _rope_scaling_validation
raise ValueError(
ValueError: rope_scaling must be a dictionary with with two fields, type and factor, got {'max_factor': 16, 'param_factor': 1, 'type': 'clex', 'factor': 1}

and when trying to load it via PhiForCausalLM, I got an error during generate:
File "/opt/conda/envs/clex/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/clex/lib/python3.10/site-packages/flash_attn/bert_padding.py", line 17, in forward
return torch.gather(
RuntimeError: index 17 is out of bounds for dimension 0 with size 7

Can you please guide me on how to set this up properly?

Slow training speed

Hi, I found that the forward and backward passes of odeint are very slow, probably because of the many iterations needed to solve the Neural ODE; the backward pass is similar to an RNN's BPTT. Have you measured the training latency in your experiments? How does it compare to baseline settings such as PI and YaRN?
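As general background (a generic torchdiffeq sketch, not specific to this repo's solver settings), the difference between backpropagating through an unrolled solver and using the adjoint method can be illustrated as follows:

# Illustrative only: backprop through odeint unrolls every solver step (BPTT-like),
# while odeint_adjoint trades extra forward solves for lower memory use.
import torch
from torchdiffeq import odeint, odeint_adjoint

class ODEFunc(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = torch.nn.Linear(dim, dim)

    def forward(self, t, y):
        return torch.tanh(self.net(y))

func, y0 = ODEFunc(64), torch.randn(8, 64)
t = torch.linspace(0.0, 1.0, 20)

y_unrolled = odeint(func, y0, t, method="rk4")          # gradients stored per step
y_adjoint = odeint_adjoint(func, y0, t, method="rk4")   # adjoint backward pass
(y_unrolled[-1].sum() + y_adjoint[-1].sum()).backward()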
