lorenzkuhn / shrinkbench-1

This project forked from jjgo/shrinkbench

PyTorch library to facilitate development and standardized evaluation of neural network pruning methods.

License: MIT License

Python 53.81% Jupyter Notebook 46.19%

ShrinkBench

Open source PyTorch library to facilitate development and standardized evaluation of neural network pruning methods.

Paper

This repo contains the analysis and benchmark results from the paper "What is the State of Neural Network Pruning?".

Installation

First, install the dependencies. This repo depends on:

  • PyTorch
  • Torchvision
  • NumPy
  • Pandas
  • Matplotlib

To install the dependencies:

# Create a python virtualenv or conda env as necessary

# With conda
conda install numpy matplotlib pandas
conda install pytorch torchvision -c pytorch

# With pip
pip install numpy matplotlib pandas torch torchvision
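Once the packages are installed, a quick sanity check is to confirm that the interpreter can find each of them. The sketch below is not part of ShrinkBench, and the helper name missing_dependencies is hypothetical. Note that PyTorch is imported as torch, which is also its pip package name.

```python
import importlib.util

# Import names of the dependencies listed above (PyTorch imports as "torch")
REQUIRED = ["torch", "torchvision", "numpy", "pandas", "matplotlib"]


def missing_dependencies(names=REQUIRED):
    """Return the import names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_dependencies()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All dependencies found")
```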

Then, to install the module itself, clone the repo and add its parent directory to your PYTHONPATH. For example:

git clone git@github.com:JJGO/shrinkbench.git shrinkbench

# Bash
echo "export PYTHONPATH=\"$PWD:\$PYTHONPATH\"" >> ~/.bashrc

# ZSH
echo "export PYTHONPATH=\"$PWD:\$PYTHONPATH\"" >> ~/.zshrc
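The parent directory is what goes on PYTHONPATH because Python resolves import shrinkbench by looking for a shrinkbench/ package inside each directory on its search path. The self-contained sketch below demonstrates this with a throwaway package; the name demo_pkg and the temporary directory are stand-ins, not part of ShrinkBench. Inserting into sys.path at runtime is the in-process equivalent of listing the directory in PYTHONPATH before the interpreter starts.

```python
import importlib
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as parent:
    # Stand-in for the clone: <parent>/demo_pkg/__init__.py
    pkg_dir = Path(parent) / "demo_pkg"
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text("NAME = 'demo_pkg'\n")

    # Put the *parent* directory on the search path, as PYTHONPATH would
    sys.path.insert(0, parent)
    importlib.invalidate_caches()
    demo_pkg = importlib.import_module("demo_pkg")
    print(demo_pkg.NAME)  # prints "demo_pkg"

    # Clean up so the temporary path and module do not linger
    sys.path.remove(parent)
    del sys.modules["demo_pkg"]
```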

Strategies

ShrinkBench facilitates not only the evaluation of pruning methods but also their development. Here is the code for simple implementations of Global Magnitude Pruning and Layerwise Magnitude Pruning. As you can see, it is quite succinct: you only need to implement model_masks, a function that returns the masks for the model's weight tensors. If you want to prune your model layerwise, you only need to implement layer_masks instead. For more examples, see the source code for the provided baselines.

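import numpy as np

from shrinkbench.pruning import LayerPruning, VisionPruning
from shrinkbench.strategies.utils import (fraction_threshold,
                                          fraction_mask,
                                          map_importances,
                                          flatten_importances,
                                          importance_masks)


class GlobalMagWeight(VisionPruning):
    """Global magnitude pruning: one threshold shared by all weight tensors."""

    def model_masks(self):
        # The importance of each weight is its absolute value
        importances = map_importances(np.abs, self.params())
        # Pool importances across layers to pick a single global threshold
        flat_importances = flatten_importances(importances)
        threshold = fraction_threshold(flat_importances, self.fraction)
        # Build a mask per tensor by comparing against the global threshold
        masks = importance_masks(importances, threshold)
        return masks


class LayerMagWeight(LayerPruning, VisionPruning):
    """Layerwise magnitude pruning: each layer is thresholded independently."""

    def layer_masks(self, module):
        params = self.module_params(module)
        # Skip absent parameters (e.g. a layer without bias) before taking magnitudes
        importances = {param: np.abs(value) for param, value in params.items()
                       if value is not None}
        masks = {param: fraction_mask(importance, self.fraction)
                 for param, importance in importances.items()}
        return masks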
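To make the difference between the two strategies concrete, here is a NumPy-only sketch that does not depend on ShrinkBench. The helpers global_magnitude_masks and layer_magnitude_masks are hypothetical stand-ins, and they assume fraction is the fraction of weights to keep (ties at the threshold may keep slightly more), which may not match ShrinkBench's exact conventions.

```python
import numpy as np


def _keep_threshold(values, fraction):
    """Smallest magnitude among the top `fraction` of `values`."""
    flat = np.sort(values.ravel())[::-1]          # descending magnitudes
    k = max(1, int(round(fraction * flat.size)))  # number of weights to keep
    return flat[k - 1]


def global_magnitude_masks(weights, fraction):
    """One threshold computed over all tensors pooled together."""
    importances = {name: np.abs(w) for name, w in weights.items()}
    pooled = np.concatenate([imp.ravel() for imp in importances.values()])
    threshold = _keep_threshold(pooled, fraction)
    return {name: imp >= threshold for name, imp in importances.items()}


def layer_magnitude_masks(weights, fraction):
    """A separate threshold per tensor, so every layer keeps about `fraction`."""
    masks = {}
    for name, w in weights.items():
        imp = np.abs(w)
        masks[name] = imp >= _keep_threshold(imp, fraction)
    return masks


weights = {
    "conv": np.array([[0.1, -0.2], [0.3, -0.4]]),
    "fc": np.array([5.0, -6.0, 7.0, 8.0]),
}
global_masks = global_magnitude_masks(weights, fraction=0.5)
per_layer = layer_magnitude_masks(weights, fraction=0.5)
print({k: m.astype(int).tolist() for k, m in global_masks.items()})
# {'conv': [[0, 0], [0, 0]], 'fc': [1, 1, 1, 1]}
print({k: m.astype(int).tolist() for k, m in per_layer.items()})
# {'conv': [[0, 0], [1, 1]], 'fc': [0, 0, 1, 1]}
```

With the same overall fraction, the global variant removes every weight from the small-magnitude conv tensor, while the layerwise variant removes half of each tensor. That is the practical difference between thresholding the pooled importances and thresholding each layer separately.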

Experiments

See here for a notebook showing how to run pruning experiments and plot their results

Modules

The modules are organized as follows:

Submodule     Description
analysis/     Aggregated survey results over 80 pruning papers
datasets/     Standardized dataloaders for supported datasets
experiment/   Main experiment class with the data loading, pruning, finetuning & evaluation
metrics/      Utils for measuring accuracy, model size, FLOPs & memory footprint
models/       Custom architectures not included in torchvision
plot/         Utils for plotting across the logged dimensions
pruning/      General pruning and masking API
scripts/      Executable scripts for running experiments (see experiment/)
strategies/   Baseline pruning methods, mainly based on magnitude pruning
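As an illustration of the kind of size metric that a pruning benchmark reports, the sketch below applies boolean masks to weight tensors and computes the resulting compression ratio (total parameters divided by nonzero parameters). It is a standalone NumPy sketch with hypothetical helper names (apply_masks, compression_ratio), not the API of metrics/ or pruning/.

```python
import numpy as np


def apply_masks(weights, masks):
    """Zero out pruned entries; tensors without a mask are left untouched."""
    return {name: w * masks[name] if name in masks else w
            for name, w in weights.items()}


def compression_ratio(weights):
    """Total number of parameters divided by the number of nonzero ones."""
    total = sum(w.size for w in weights.values())
    nonzero = sum(np.count_nonzero(w) for w in weights.values())
    if nonzero == 0:
        return float("inf")  # everything was pruned
    return total / nonzero


weights = {"fc.weight": np.array([[0.5, -0.1], [0.05, 2.0]]),
           "fc.bias": np.array([0.3, -0.2])}
# Keep only weight entries with magnitude >= 0.5; the bias is not masked
masks = {"fc.weight": np.abs(weights["fc.weight"]) >= 0.5}
pruned = apply_masks(weights, masks)
print(compression_ratio(pruned))  # 6 params, 4 nonzero -> 1.5
```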

Contributors

jjgo, dblalock, anon-paper-submissions-1982, lorenz-openai