Network Pruning

Assignment 1 for COS598D: System and Machine Learning

In this assignment, you are required to evaluate three advanced neural network pruning methods — SNIP [1], GraSP [2], and SynFlow [3] — and compare them with two baseline pruning methods: random pruning and magnitude-based pruning. In example/singleshot.py, we provide an example of single-shot global pruning without iterative training. In example/multishot.py, we provide an example of multi-shot iterative training. This assignment focuses on the pruning protocol in example/singleshot.py. You are going to explore the various pruning methods across different hyperparameters and network architectures.

References

[1] Lee, N., Ajanthan, T. and Torr, P.H., 2018. Snip: Single-shot network pruning based on connection sensitivity. arXiv preprint arXiv:1810.02340.

[2] Wang, C., Zhang, G. and Grosse, R., 2020. Picking winning tickets before training by preserving gradient flow. arXiv preprint arXiv:2002.07376.

[3] Tanaka, H., Kunin, D., Yamins, D.L. and Ganguli, S., 2020. Pruning neural networks without any data by iteratively conserving synaptic flow. arXiv preprint arXiv:2006.05467.

Additional reading materials:

A recent paper [4] assessed [1-3].

[4] Frankle, J., Dziugaite, G.K., Roy, D.M. and Carbin, M., 2020. Pruning neural networks at initialization: Why are we missing the mark? arXiv preprint arXiv:2009.08576.

Getting Started

First clone this repo, then install all dependencies.

pip3 install -r requirements.txt

(2024 Spring update from Rui: I have verified that the code in this assignment works for Python v3.10.12, torch v2.1.2+cu121, and torchvision v0.16.2+cu121. You might need to use different versions of these packages based on your environment. For any issues, make a post on Ed and I can help take a look.)

How to Run

Run python main.py --help for a complete description of flags and hyperparameters. You can also go to main.py to check all the parameters.

Example: initialize a VGG16, prune it with SynFlow, and train it to a sparsity of 10^-0.5. In the code, sparsity = 10**(-float(args.compression)).

python3 main.py --model-class lottery --model vgg16 --dataset cifar10 --experiment singleshot --pruner synflow --compression 0.5
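The mapping from the compression flag to sparsity can be sanity-checked in a couple of lines (a minimal sketch of the formula above; note that in this codebase "sparsity" means the fraction of weights kept):

```python
def compression_to_sparsity(compression: float) -> float:
    """Fraction of weights kept after pruning: sparsity = 10 ** (-compression)."""
    return 10 ** (-float(compression))

# --compression 0.5 keeps about 31.6% of the weights; --compression 1 keeps 10%.
print(round(compression_to_sparsity(0.5), 4))  # 0.3162
```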

To save the experiment, please add --expid {NAME}. --compression-list and --pruner-list are not available for running the singleshot experiment; you can modify the source code following example/multishot.py to run over a list of parameters. --prune-epochs is also not available, as it does not affect pruning in the singleshot setting.

For magnitude-based pruning, please set --pre-epochs 200. You can reduce the number of pretraining epochs to save some time. The other methods prune before training, so they can use the default setting --pre-epochs 0.

Please use the default batch size, learning rate, and optimizer in the following experiments, and use the default training/testing split. Monitor the training and testing loss, and set a suitable number of training epochs. You may try --post-epoch 100 for CIFAR10 and --post-epoch 10 for MNIST.

If you are using Google Colab, to accommodate its limited resources you can use --pre-epochs 10 for magnitude pruning and --post-epoch 10 for CIFAR10 experiments. State the epoch numbers you used in your report.

Your Tasks

1. Hyper-parameter tuning

Testing on different architectures. Please fill in the results table:

Test accuracy (top 1) of pruned models on CIFAR10 and MNIST (sparsity = 10%). --compression 1 means sparsity = 10^-1.

python main.py --model-class lottery --model vgg16 --dataset cifar10 --experiment singleshot --pruner synflow --compression 1
python main.py --model-class default --model fc --dataset cifar10 --experiment singleshot --pruner synflow --compression 1

Testing accuracy (top 1)

| Data    | Arch  | Rand | Mag | SNIP | GraSP | SynFlow |
| ------- | ----- | ---- | --- | ---- | ----- | ------- |
| CIFAR10 | VGG16 |      |     |      |       |         |
| MNIST   | FC    |      |     |      |       |         |

Tuning the compression ratio. Please fill in the results tables:

Prune models on CIFAR10 with VGG16. Please replace {} with the compression value a, for a in {0.05, 0.1, 0.2, 0.5, 1, 2}, which gives sparsity 10^-a. Feel free to try other sparsity values.

python main.py --model-class lottery --model vgg16 --dataset cifar10 --experiment singleshot --pruner synflow  --compression {}

Testing accuracy (top 1)

| Compression | Rand | Mag | SNIP | GraSP | SynFlow |
| ----------- | ---- | --- | ---- | ----- | ------- |
| 0.05        |      |     |      |       |         |
| 0.1         |      |     |      |       |         |
| 0.2         |      |     |      |       |         |
| 0.5         |      |     |      |       |         |
| 1           |      |     |      |       |         |
| 2           |      |     |      |       |         |

Testing time (inference on the testing dataset)

| Compression | Rand | Mag | SNIP | GraSP | SynFlow |
| ----------- | ---- | --- | ---- | ----- | ------- |
| 0.05        |      |     |      |       |         |
| 0.1         |      |     |      |       |         |
| 0.2         |      |     |      |       |         |
| 0.5         |      |     |      |       |         |
| 1           |      |     |      |       |         |
| 2           |      |     |      |       |         |

To track the running time, you can use timeit, which ships with the Python standard library (no installation needed):

import timeit

start = timeit.default_timer()

# the code whose running time you want to measure

stop = timeit.default_timer()

print('Time: ', stop - start)
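For the testing-time table, the same timer can wrap a full pass over the test set. A minimal sketch, assuming a pruned model and a test data loader from the codebase are in scope (the names `model`, `test_loader`, and `time_inference` are illustrative):

```python
import timeit

import torch


def time_inference(model, test_loader, device="cpu"):
    """Wall-clock seconds for one full pass over the test set (no gradients)."""
    model.eval()
    start = timeit.default_timer()
    with torch.no_grad():
        for inputs, _ in test_loader:
            model(inputs.to(device))
    return timeit.default_timer() - start
```

Run a warm-up pass first if you are timing on a GPU, since the first batch includes one-time setup costs.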

FLOPs

| Compression | Rand | Mag | SNIP | GraSP | SynFlow |
| ----------- | ---- | --- | ---- | ----- | ------- |
| 0.05        |      |     |      |       |         |
| 0.1         |      |     |      |       |         |
| 0.2         |      |     |      |       |         |
| 0.5         |      |     |      |       |         |
| 1           |      |     |      |       |         |
| 2           |      |     |      |       |         |

For better visualization, you are encouraged to convert the above three tables into curves and present them as three figures.

2. The compression ratio of each layer

Report the sparsity of each layer and draw the weight histograms of each layer for each pruner (Rand, Mag, SNIP, GraSP, SynFlow) with the following settings: model = vgg16, dataset = cifar10, compression = 0.5.

A weight histogram is a figure showing the distribution of weight values: the x-axis is the weight value and the y-axis is the count of weights with that value in the layer. Since the weights are floating-point numbers, you need to partition the value range into intervals (bins) and count the weights that fall into each interval. The weight histograms of all layers for one pruning method can be plotted in a single figure (one histogram per layer).

An example of weight histograms for a neural network: https://stackoverflow.com/questions/42315202/understanding-tensorboard-weight-histograms
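Both parts of this task can be sketched as below, assuming you have the pruned model object in hand (e.g. after the singleshot experiment finishes); the function and file names here are illustrative, not part of the codebase:

```python
import matplotlib
matplotlib.use("Agg")  # render to a file; no display needed
import matplotlib.pyplot as plt
import torch

PRUNABLE = (torch.nn.Conv2d, torch.nn.Linear)


def layer_sparsity(model):
    """Fraction of zero-valued weights in each prunable layer."""
    return {name: (m.weight == 0).float().mean().item()
            for name, m in model.named_modules()
            if isinstance(m, PRUNABLE)}


def plot_weight_histograms(model, fname="weight_histograms.png", bins=50):
    """One histogram per prunable layer, stacked in a single figure."""
    layers = [(name, m.weight.detach().cpu().numpy().ravel())
              for name, m in model.named_modules()
              if isinstance(m, PRUNABLE)]
    fig, axes = plt.subplots(len(layers), 1,
                             figsize=(6, 2 * len(layers)), squeeze=False)
    for ax, (name, w) in zip(axes[:, 0], layers):
        ax.hist(w[w != 0], bins=bins)  # drop pruned zeros so they don't dominate
        ax.set_title(name)
    fig.tight_layout()
    fig.savefig(fname)
    plt.close(fig)
```

The pruned (zero) weights are dropped from the histograms so the spike at zero does not dominate the plot; keep them in if you want to visualize the pruning mask instead.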

Bonus (optional)

Report the FLOPs of each layer for each pruner (Rand, Mag, SNIP, GraSP, SynFlow) with the following settings: model = vgg16, dataset = cifar10, compression = 0.5.
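One way to estimate per-layer FLOPs for a pruned network is to count only the nonzero weights and, for convolutions, multiply by the number of output spatial positions. A sketch under those assumptions (multiply-accumulates only; biases and activations are ignored; `layer_flops` is an illustrative name):

```python
import torch


def layer_flops(model, input_shape=(1, 3, 32, 32)):
    """Per-layer multiply-accumulate estimate, counting only nonzero weights
    so that pruning lowers the count. Forward hooks capture output sizes."""
    flops, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            nonzero = int((module.weight != 0).sum())
            if isinstance(module, torch.nn.Conv2d):
                # each surviving weight fires once per output spatial location
                flops[name] = nonzero * output.shape[-2] * output.shape[-1]
            else:  # Linear: one multiply per surviving weight
                flops[name] = nonzero
        return hook

    for name, m in model.named_modules():
        if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
            hooks.append(m.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        model(torch.zeros(input_shape))
    for h in hooks:
        h.remove()
    return flops
```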

3. Explain your results and submit a short report.

Please describe the settings of your experiments and include the required results (described in Tasks 1 and 2). Add captions describing your figures and tables. It is best to also include brief discussions of your results, such as the patterns you observe (what and why), conclusions, and any other observations you want to discuss.

Contributors

xxlya, yushansu, ruipeterpan
