
Pruning LLMs by Weights and Activations

Official PyTorch implementation of Wanda (Pruning by Weights and Activations), as presented in our paper:

A Simple and Effective Pruning Approach for Large Language Models
Mingjie Sun*, Zhuang Liu*, Anna Bair, J. Zico Kolter (* indicates equal contribution)
Carnegie Mellon University, Meta AI Research and Bosch Center for AI
Paper - Project page

@article{sun2023wanda,
  title={A Simple and Effective Pruning Approach for Large Language Models}, 
  author={Sun, Mingjie and Liu, Zhuang and Bair, Anna and Kolter, J. Zico},
  year={2023},
  journal={arXiv preprint arXiv:2306.11695}
}

Compared to magnitude pruning, which removes weights solely based on their magnitudes, our pruning approach Wanda removes weights on a per-output basis, scoring each weight by the product of its magnitude and the corresponding input activation norm.
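To make the metric concrete, below is a minimal PyTorch sketch of per-output pruning for a single linear layer. It is a simplified illustration under assumed tensor shapes, not the repository's implementation:

import torch

def wanda_prune_layer(W: torch.Tensor, X: torch.Tensor, sparsity: float = 0.5) -> torch.Tensor:
    # W: weight matrix, shape (out_features, in_features)
    # X: calibration activations, shape (num_tokens, in_features)
    act_norm = X.norm(p=2, dim=0)   # per-input-channel L2 norm, shape (in_features,)
    score = W.abs() * act_norm      # Wanda score |W_ij| * ||X_j||_2, broadcast over output rows
    # Compare scores within each output row and zero out the lowest-scoring weights.
    num_prune = int(W.shape[1] * sparsity)
    prune_idx = torch.argsort(score, dim=1)[:, :num_prune]
    mask = torch.ones_like(W, dtype=torch.bool)
    mask.scatter_(1, prune_idx, False)
    return W * mask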

Update

  • (9.22.2023) Add support for LLaMA-2.
  • (9.22.2023) Add code to reproduce the ablation study on OBS weight update in the paper.
  • (10.6.2023) Add support for the weight update analysis in the ablation study. Feel free to try it out!
  • (10.6.2023) Add support for zero-shot evaluation.
  • (10.20.2023) Add code for pruning OPT models.
  • (10.23.2023) Add code for LoRA fine-tuning.

Setup

Installation instructions can be found in INSTALL.md.

Usage

The scripts directory contains all the bash commands to replicate the main results (Table 2) in our paper.

Below is an example command for pruning LLaMA-7B with Wanda, to achieve unstructured 50% sparsity.

python main.py \
    --model decapoda-research/llama-7b-hf \
    --prune_method wanda \
    --sparsity_ratio 0.5 \
    --sparsity_type unstructured \
    --save out/llama_7b/unstructured/wanda/ 

We provide a quick overview of the arguments:

  • --model: The identifier for the LLaMA model on the Hugging Face model hub.
  • --cache_dir: Directory for loading or storing LLM weights. The default is llm_weights.
  • --prune_method: We have implemented three pruning methods, namely [magnitude, wanda, sparsegpt].
  • --sparsity_ratio: The fraction of weights to prune (e.g., 0.5 for 50% sparsity).
  • --sparsity_type: Specifies the type of sparsity [unstructured, 2:4, 4:8].
  • --use_variant: Whether to use the Wanda variant, default is False.
  • --save: Specifies the directory where the result will be stored.

For structured N:M sparsity, set the argument --sparsity_type to "2:4" or "4:8". An illustrative command is provided below:

python main.py \
    --model decapoda-research/llama-7b-hf \
    --prune_method wanda \
    --sparsity_ratio 0.5 \
    --sparsity_type 2:4 \
    --save out/llama_7b/2-4/wanda/ 

Pruning LLaMA-2

For LLaMA-2 models, replace --model with meta-llama/Llama-2-7b-hf (take 7b as an example):

python main.py \
    --model meta-llama/Llama-2-7b-hf \
    --prune_method wanda \
    --sparsity_ratio 0.5 \
    --sparsity_type unstructured \
    --save out/llama2_7b/unstructured/wanda/

LLaMA-2 results (WikiText perplexity; LLaMA-2-34B had not been released as of 9.22.2023):

sparsity           method      llama2-7b   llama2-13b   llama2-70b
-                  dense            5.12         4.57         3.12
unstructured 50%   magnitude       14.89         6.37         4.98
unstructured 50%   sparsegpt        6.51         5.63         3.98
unstructured 50%   wanda            6.42         5.56         3.98
4:8                magnitude       16.48         6.76         5.58
4:8                sparsegpt        8.12         6.60         4.59
4:8                wanda            7.97         6.55         4.47
2:4                magnitude       54.59         8.33         6.33
2:4                sparsegpt       10.17         8.32         5.40
2:4                wanda           11.02         8.27         5.16

Ablation on OBS weight update

To reproduce the weight update analysis, we provide our implementation of this ablation. All commands can be found in the corresponding script in the scripts directory.

for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter
do
    CUDA_VISIBLE_DEVICES=0 python main.py \
        --model decapoda-research/llama-7b-hf \
        --sparsity_ratio 0.5 \
        --sparsity_type unstructured \
        --prune_method ${method} \
        --save out/llama_7b_ablation/unstructured/
done

Here ablate_{mag/wanda}_{seq/iter} means that we use magnitude pruning or Wanda to obtain the pruned mask at each layer, and then apply the weight update procedure in either a sequential or an iterative style, every 128 input channels. For details, please see Section 5 of our paper.
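As a rough control-flow sketch of the difference between the two styles (not the repository's code): obs_update below is a hypothetical stand-in for the actual OBS-style weight update, which reconstructs the remaining weights using second-order information as in SparseGPT; only the ordering of mask computation and updates is the point here.

import torch

def obs_update(W, mask, X, cols):
    # Hypothetical placeholder for the OBS-style weight update over the
    # column block `cols`; the real update uses Hessian information.
    return W * mask

def sequential_style(W, X, score, sparsity=0.5, block=128):
    # Sequential: fix the full pruning mask up front from the metric,
    # then sweep the weight update over blocks of 128 input channels.
    k = int(W.shape[1] * sparsity)
    mask = torch.ones_like(W, dtype=torch.bool)
    mask.scatter_(1, torch.argsort(score, dim=1)[:, :k], False)
    for start in range(0, W.shape[1], block):
        W = obs_update(W, mask, X, slice(start, start + block))
    return W

def iterative_style(W, X, score, sparsity=0.5, block=128):
    # Iterative: alternate mask selection and weight update, handling one
    # block of 128 input channels at a time.
    mask = torch.ones_like(W, dtype=torch.bool)
    for start in range(0, W.shape[1], block):
        cols = slice(start, start + block)
        blk = score[:, cols]
        k = int(blk.shape[1] * sparsity)
        mask[:, cols].scatter_(1, torch.argsort(blk, dim=1)[:, :k], False)
        W = obs_update(W, mask, X, cols)
    return W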

Zero-Shot Evaluation

For evaluating zero-shot tasks, we modify the EleutherAI LM Harness framework so that it can evaluate pruned LLMs. We provide the modified repo at this link. Make sure to download, extract, and install this custom lm_eval package from source.

For reproducibility, we used commit df3da98 on the main branch. All tasks were evaluated with task version 0, except for BoolQ, which uses task version 1.

At a high level, our modification adds two arguments, pretrained_model and tokenizer, to this function. We can then call the simple_evaluate function from our codebase to evaluate sparse pruned LLMs. To evaluate zero-shot tasks in addition to the WikiText perplexity, pass the --eval_zero_shot argument.
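For reference, here is a hedged sketch of how the modified entry point might be called from Python. Apart from the two added arguments (pretrained_model and tokenizer), the adapter name, task list, and remaining arguments are illustrative assumptions, not the exact call in our codebase:

from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval import evaluator  # the modified lm_eval package linked above

model_name = "meta-llama/Llama-2-7b-hf"
pruned_model = AutoModelForCausalLM.from_pretrained(model_name)  # prune it first, e.g. via main.py
tokenizer = AutoTokenizer.from_pretrained(model_name)

results = evaluator.simple_evaluate(
    model="hf-causal-experimental",       # assumed adapter name for HF causal LMs
    tasks=["boolq", "rte", "hellaswag",   # zero-shot tasks reported in the paper
           "winogrande", "arc_easy", "arc_challenge", "openbookqa"],
    num_fewshot=0,                        # zero-shot
    pretrained_model=pruned_model,        # added argument: evaluate this in-memory model
    tokenizer=tokenizer,                  # added argument: its tokenizer
)
print(results["results"])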

Speedup Evaluation

The pruning speed of each method is measured as the accumulated time spent on pruning each layer, excluding forward passes.
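Concretely, the measurement amounts to something like the following sketch, where layers and prune_layer are hypothetical placeholders:

import time

def prune_layer(layer):
    # Hypothetical placeholder for the per-layer pruning computation.
    pass

layers = range(32)  # e.g., the 32 transformer blocks of LLaMA-7B

total = 0.0
for layer in layers:
    tic = time.time()
    prune_layer(layer)              # only the pruning computation is timed
    total += time.time() - tic      # calibration forward passes are excluded
print(f"cumulative pruning time: {total:.2f}s")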

For inference speedup with structured sparsity, we refer the reader to this blog post, which describes the semi-structured (2:4) sparsity support in PyTorch >= 2.1. You can switch between the CUTLASS and cuSPARSELt kernels here.
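As a quick illustration, here is a minimal sketch of PyTorch's semi-structured sparsity API (PyTorch >= 2.1, Ampere or newer GPU). The weight below is synthetic, standing in for one pruned with --sparsity_type 2:4, and the backend toggle shown is one way to force CUTLASS over cuSPARSELt:

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

# Select the CUTLASS kernel instead of cuSPARSELt.
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# A synthetic 2:4-sparse weight: two zeros in every group of four elements.
W = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()   # shape (128, 128)
W_sparse = to_sparse_semi_structured(W)

x = torch.rand(128, 64).half().cuda()
y = torch.mm(W_sparse, x)   # dispatches to the 2:4 sparse matmul kernel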

Finally, for pruning image classifiers, see the image_classifiers directory for details.

Acknowledgement

This repository is built upon the SparseGPT repository.

License

This project is released under the MIT license. Please see the LICENSE file for more information.

Questions

Feel free to discuss papers/code with us through issues/emails!

mingjies at cs.cmu.edu
liuzhuangthu at gmail.com
