
Llama_tw

Prepare data

While we cannot release our data due to our policy, we do provide our data preprocessing scripts. You can follow these steps:

  1. Execute the data/split_data.py script to obtain a subset from the original dataset. Here's an example of splitting a 1B subset from the Traditional Chinese Dataset.
python3 split_data.py --dataset_path pretrain_cht_test.jsonl --output_path pretrain_cht_test_1B.jsonl --split_length 1000000000
  2. Run the data/preprocess_dataset.py script to convert the JSONL file into a HuggingFace dataset (a minimal sketch of this conversion is shown after the command below).
python3 preprocess_dataset.py --dataset_path pretrain_cht_test_1B.jsonl --output_path pretrain_cht_test_1B
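
For reference, the JSONL-to-dataset conversion in step 2 can be approximated with the datasets library. This is a minimal sketch, not the exact logic of preprocess_dataset.py, and it assumes the same --dataset_path and --output_path arguments as the command above.

import argparse

from datasets import load_dataset

# Minimal sketch: load a JSONL file and save it as a HuggingFace dataset on disk.
# The actual preprocess_dataset.py may also tokenize or pack the text.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", required=True)
parser.add_argument("--output_path", required=True)
args = parser.parse_args()

dataset = load_dataset("json", data_files=args.dataset_path, split="train")
dataset.save_to_disk(args.output_path)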

Continual Pretraining

For our continual pretraining, we use the DeepSpeed integration with the HuggingFace Trainer. Here is the tutorial.

  • Disclaimer: We do not use flash attention in our continual pretraining because of tensor core and instruction limitations on the V100.

Please begin by reviewing the configs/pretrain/llama_2.py file. Once you have checked the configuration, you can pretrain your own Llama model by executing the llama_pretrain.py script.
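
For context, the DeepSpeed integration mentioned above typically boils down to passing a DeepSpeed JSON config through TrainingArguments. The sketch below only illustrates that wiring; the batch size, precision flags, and the ds_config.json path are placeholders rather than the settings actually used in configs/pretrain/llama_2.py.

from datasets import load_from_disk
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Load the base model and the preprocessed dataset from the data preparation step.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
train_dataset = load_from_disk("./data/pretrain_cht_test_1B")

training_args = TrainingArguments(
    output_dir="./results/llama-2-7b-chat-zh1B",
    per_device_train_batch_size=1,        # placeholder value
    gradient_accumulation_steps=8,        # placeholder value
    fp16=True,                            # V100 GPUs do not support bf16
    deepspeed="ds_config.json",           # placeholder path to a DeepSpeed config
)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()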

Default

Here is an example of pretraining Llama-2-7b-chat on the 1B Traditional Chinese Dataset.

python -m torch.distributed.run --nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES \
--node_rank $SLURM_PROCID --master_addr $MASTER_ADDR --master_port $MASTER_PORT llama_pretrain.py \
--model_name meta-llama/Llama-2-7b-chat-hf --dataset_path ./data/pretrain_cht_test_1B \
--run_name llama-2-7b-chat-zh1B --output_dir ./results/llama-2-7b-chat-zh1B 

Freeze layers

Here is an example of pretraining Llama-2-7b-chat on the 1B Traditional Chinese Dataset while freezing its first 10 layers.

python -m torch.distributed.run --nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES \
--node_rank $SLURM_PROCID --master_addr $MASTER_ADDR --master_port $MASTER_PORT llama_pretrain.py \
--model_name meta-llama/Llama-2-7b-chat-hf --dataset_path ./data/pretrain_cht_test_1B \
--freeze_layers 0 1 2 3 4 5 6 7 8 9 \
--run_name llama-2-7b-chat-zh1B-freeze-first-10 --output_dir ./results/llama-2-7b-chat-zh1B-freeze-first-10
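
Internally, a --freeze_layers option like the one above usually just disables gradients for the listed decoder layers. The snippet below is a minimal sketch of that idea; the freeze_layers list is illustrative, not the script's actual argument handling.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Disable gradient updates for the first 10 decoder layers; all remaining
# parameters keep requires_grad=True and are trained as usual.
freeze_layers = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
for idx in freeze_layers:
    for param in model.model.layers[idx].parameters():
        param.requires_grad = False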

Freeze modules

To freeze the weights of specific modules, add a few extra lines to the llama_pretrain.py script. Here is an example of pretraining Llama-2-7b-chat on the 1B Traditional Chinese Dataset while freezing the mlp modules.

First, add the following code to the llama_pretrain.py script.

# Freeze the MLP sub-module in every decoder layer so that only the remaining
# parameters (e.g. attention and embeddings) are updated during training.
for idx in range(len(model.model.layers)):
    for param in model.model.layers[idx].mlp.parameters():
        param.requires_grad = False

After making these changes, proceed to execute the llama_pretrain.py script.

python -m torch.distributed.run --nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES \
--node_rank $SLURM_PROCID --master_addr $MASTER_ADDR --master_port $MASTER_PORT llama_pretrain.py \
--model_name meta-llama/Llama-2-7b-chat-hf --dataset_path ./data/pretrain_cht_test_1B \
--run_name llama-2-7b-chat-zh1B-freeze-mlp --output_dir ./results/llama-2-7b-chat-zh1B-freeze-mlp 

Adapter

We use PEFT to implement continual pretraining for Llama with adapters. You can pretrain your Llama model with either LoRA or IA3; for further information, please check the configs/pretrain/peft.py file. During continual pretraining with an adapter, only the adapter weights are saved, not the base model weights. Here is an example of pretraining Llama-2-7b-chat with a LoRA adapter on the 1B Traditional Chinese Dataset.

python -m torch.distributed.run --nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES \
--node_rank $SLURM_PROCID --master_addr $MASTER_ADDR --master_port $MASTER_PORT llama_pretrain.py \
--model_name meta-llama/Llama-2-7b-chat-hf --dataset_path ./data/pretrain_cht_test_1B \
--use_peft=True --peft_method LORA \
--run_name llama-2-7b-chat-zh1B-lora --output_dir ./results/llama-2-7b-chat-zh1B-lora
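
For reference, attaching a LoRA adapter with PEFT generally looks like the sketch below. The rank, alpha, dropout, and target modules shown are illustrative defaults, not necessarily the values defined in configs/pretrain/peft.py.

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Illustrative LoRA settings; the actual values live in configs/pretrain/peft.py.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable (and saved)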

Caution

When conducting continual pretraining, the add_eos_token parameter of our tokenizer is set to True. If you run inference on a checkpoint saved partway through training, make sure to check your tokenizer_config.json file and set add_eos_token to False for inference.
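
One way to double-check this at inference time is to reload the tokenizer with add_eos_token disabled. A minimal sketch, with a hypothetical checkpoint path:

from transformers import AutoTokenizer

# During continual pretraining the tokenizer is saved with add_eos_token=True.
# For inference, reload it with add_eos_token=False so prompts are not
# terminated by an extra </s> token. The checkpoint path below is hypothetical.
tokenizer = AutoTokenizer.from_pretrained(
    "./results/llama-2-7b-chat-zh1B/checkpoint-1000",
    add_eos_token=False,
)
print(tokenizer.add_eos_token)  # should print False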

Inference

We also provide the inference.py script for running inference with our model. We use vLLM to improve inference speed. You can customize the prompts inside the script to your requirements. Here is an example of using the inference.py script.

python3 inference.py --model_name llama-2-7b-chat-zh1B \
--max_tokens 512 --temperature 0.1 --top_p 0.9 --tensor_parallel_size 8 --seed 42
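
Internally, inference with vLLM follows the usual LLM / SamplingParams pattern. The sketch below mirrors the parameters of the command above; the prompt is only a placeholder.

from vllm import LLM, SamplingParams

# Load the continually pretrained model and sample with the same
# parameters as the inference.py example above.
llm = LLM(model="llama-2-7b-chat-zh1B", tensor_parallel_size=8, seed=42)
sampling_params = SamplingParams(max_tokens=512, temperature=0.1, top_p=0.9)

prompts = ["請用繁體中文介紹台灣。"]  # placeholder prompt; customize as needed
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)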

When running inference with vLLM on a PEFT model, you must first merge the adapter weights into the base model. Here is an example of merging LoRA weights into Llama-2-7b-chat-hf.

python3 merge_model.py --model_name meta-llama/Llama-2-7b-chat-hf \
--peft_model_path ./results/llama-2-7b-chat-zh1B-lora/last \
--merged_model_path llama-2-7b-chat-zh1B-lora-merged
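
The merge step typically combines PeftModel with merge_and_unload, then saves a plain checkpoint that vLLM can load. A minimal sketch using the same paths as the command above:

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model, attach the trained LoRA adapter, and merge the adapter
# weights into the base weights so the result can be loaded as a regular model.
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = PeftModel.from_pretrained(base_model, "./results/llama-2-7b-chat-zh1B-lora/last")
merged_model = model.merge_and_unload()

merged_model.save_pretrained("llama-2-7b-chat-zh1B-lora-merged")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.save_pretrained("llama-2-7b-chat-zh1B-lora-merged")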

Analysis

For detailed instructions, please refer to each README.md file in the analysis folder.

References

We list some repositories that we have referenced.

Contact

If you have any questions, please do not hesitate to contact us at [email protected]
