
Early Weight Averaging

Pre-train Large Language Models (LLMs) faster with Early Weight Averaging. For more details, refer to our paper: Early Weight Averaging meets High Learning Rates for LLM Pre-training.

Abstract

Training Large Language Models (LLMs) incurs significant cost, making strategies that accelerate model convergence highly valuable. In our research, we focus on the impact of checkpoint averaging along the trajectory of a training run to enhance both convergence and generalization early in the training process. We observe that models trained with high learning rates benefit more from checkpoint averaging, and this effect is further intensified when checkpoints are sampled with substantial spacing in training steps. Our training method surpasses conventional training and popular checkpoint averaging baselines such as exponential moving average (EMA) and stochastic weight averaging (SWA). We demonstrate the efficacy of our approach by pre-training nanoGPT-2 models of various sizes (small, 125M; medium, 335M; large, 770M) on the OpenWebText dataset, consisting of 9 billion tokens. We also present results for publicly available Pythia LLMs, ranging from 1 billion to 12 billion parameters, trained on the Pile-deduped dataset containing 207 billion tokens.

Data Preparation

Prepare the OpenWebText data following nanoGPT:

$ python data/openwebtext/prepare.py
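Following nanoGPT, this script downloads OpenWebText, tokenizes it with the GPT-2 BPE via tiktoken, and writes train.bin and val.bin. A minimal sketch of how a training loop can read batches from the resulting token stream (block and batch sizes here are illustrative, not the repo's configuration):

import numpy as np
import torch

# memory-map the tokenized stream, as nanoGPT's training loop does
data = np.memmap("data/openwebtext/train.bin", dtype=np.uint16, mode="r")
block_size, batch_size = 1024, 8  # illustrative sizes
ix = torch.randint(len(data) - block_size, (batch_size,))
x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])  # targets are inputs shifted by one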

Training Scripts for nanoGPT-2

Normal Training & EMA

To train a small nanoGPT-2 model (the script also maintains an EMA variant), use the following command:

torchrun --standalone --nproc_per_node=3 train_ema_small.py

Similarly, for medium and large:

torchrun --standalone --nproc_per_node=3 train_ema_medium.py
torchrun --standalone --nproc_per_node=3 train_ema_large.py
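The EMA variant tracked by these scripts is the standard exponential moving average of the weights. A minimal sketch of the update, applied after each optimizer step (the function name and decay value are illustrative, not taken from the repo):

import copy
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # theta_ema <- decay * theta_ema + (1 - decay) * theta
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)

model = torch.nn.Linear(10, 10)   # stand-in for the GPT model
ema_model = copy.deepcopy(model)  # EMA copy initialized from the live weights
# ... call ema_update(ema_model, model) after each optimizer step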

SWA

To train a small nanoGPT-2 model with SWA, use the following command:

torchrun --standalone --nproc_per_node=3 train_swa.py
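SWA maintains a running equal-weight average of checkpoints sampled at a fixed interval. A minimal self-contained sketch using PyTorch's built-in helper (the model, objective, and interval below are stand-ins, not the repo's configuration):

import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(10, 10)  # stand-in for the GPT model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
swa_model = AveragedModel(model)  # keeps a running equal-weight average
swa_interval = 100                # checkpoint spacing, illustrative
for step in range(1, 1001):
    loss = model(torch.randn(8, 10)).pow(2).mean()  # dummy objective for the sketch
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if step % swa_interval == 0:
        swa_model.update_parameters(model)  # fold the current weights into the average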

Loss curves for normal training, EMA, and SWA:

LAWA Checkpoint Averaging

Similarly, to run LAWA (latest weight averaging) on already saved checkpoints:

torchrun --standalone --nproc_per_node=3 lawa.py
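LAWA forms a uniform average of the last few saved checkpoints, sampled with substantial spacing in training steps. A minimal offline sketch, assuming checkpoints are plain state_dicts of floating-point tensors (the paths, count, and spacing are illustrative, not the repo's configuration):

import torch

def lawa_average(ckpt_paths):
    # uniformly average the parameter tensors of the given checkpoints
    avg = None
    for path in ckpt_paths:
        state = torch.load(path, map_location="cpu")
        if avg is None:
            avg = {k: v.float().clone() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / len(ckpt_paths) for k, v in avg.items()}

# e.g., average the last 5 checkpoints saved 1,000 steps apart (hypothetical paths):
# model.load_state_dict(lawa_average([f"ckpt_{s}.pt" for s in range(46000, 51000, 1000)]))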

Please refer to the medium and large configurations of the EMA scripts to run the medium and large versions of SWA and LAWA.

Loss curves for LAWA:

Dependencies

  • PyTorch 2.0
  • transformers
  • datasets
  • tiktoken
  • wandb

Cite

If you find this work helpful, please consider citing us:

@inproceedings{sanyal2023early,
  title={Early Weight Averaging meets High Learning Rates for {LLM} Pre-training},
  author={Sunny Sanyal and Atula Neerkaje and Jean Kaddour and Abhishek Kumar and Sujay Sanghavi},
  booktitle={Workshop on Advancing Neural Network Training: Computational Efficiency, Scalability, and Resource Optimization (WANT@NeurIPS 2023)},
  year={2023},
  url={https://openreview.net/forum?id=I2aJVWHA93}
}

Acknowledgement

The training code is mainly adapted from Sophia and nanoGPT.
