
Structured-Bayesian-Pruning-pytorch

PyTorch implementation of Structured Bayesian Pruning (NIPS 2017). The authors of the paper provide a TensorFlow implementation. This implementation is built on PyTorch 0.4.

Some preliminary results on MNIST:

| Network | Method | Error (%) | Neurons per Layer |
|---------|----------|-----------|----------------------|
| LeNet-5 | Original | 0.68 | 20 - 50 - 800 - 500 |
| LeNet-5 | SBP | 0.86 | 3 - 18 - 284 - 283 |
| LeNet-5 | SBP* | 1.17 | 8 - 15 - 163 - 81 |

SBP* denotes the results from my implementation; I believe they can be improved with hyperparameter tuning.

As a byproduct of my implementation, I roughly plotted average layerwise sparsity vs. model performance on MNIST. Average layerwise sparsity is not an accurate approximation of the compression rate, but it gives an idea of how the two are related in Structured Bayesian Pruning.
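To make the sparsity numbers concrete, the neuron counts in the table above directly give a per-layer and average layerwise sparsity. A quick computation for the SBP* row (pure Python, using only the values reported in the table):

```python
# Neuron counts per layer from the table above (LeNet-5 on MNIST)
original = [20, 50, 800, 500]   # Original LeNet-5
sbp_star = [8, 15, 163, 81]     # SBP* (this implementation)

# per-layer sparsity = fraction of neurons/channels pruned away
sparsity = [1 - kept / total for kept, total in zip(sbp_star, original)]
avg_sparsity = sum(sparsity) / len(sparsity)

print([round(s, 3) for s in sparsity])  # per-layer sparsity
print(round(avg_sparsity, 3))           # average layerwise sparsity
```

Note that this average treats every layer equally regardless of its size, which is exactly why it is only a rough proxy for the true compression rate.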

The code only contains the experiment to reproduce the MNIST results (the file is LeNet_MNIST.py); however, it can easily be extended to other models or datasets. Here is a simple example of how to add Structured Bayesian Pruning to your own model.

```python
from SBP_utils import SBP_layer
import torch
import torch.nn as nn

batch = 3
input_dim = 5
output_dim = 10

# for a CNN layer, input_dim is the number of channels;
# for a linear layer, input_dim is the number of neurons
linear = nn.Linear(input_dim, output_dim)
sbp_layer = SBP_layer(output_dim)

# perform a forward pass
x = torch.randn(batch, input_dim)
x = linear(x)
y, kl = sbp_layer(x)

# don't forget to add kl to the loss
loss = loss + kl
```
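For convolutional layers, the same `(output, kl)` pattern applies with one multiplicative gate per channel. The sketch below is a simplified, self-contained stand-in, not the repo's actual `SBP_layer`: it uses plain Gaussian multiplicative noise and a placeholder penalty instead of the paper's truncated log-normal parameterization and its KL term, and exists only to show how such a layer slots in after a conv layer.

```python
import torch
import torch.nn as nn

# Simplified stand-in for SBP_layer (illustration only): multiplies each
# channel by Gaussian noise with learnable mean/variance. The real layer
# uses the paper's truncated log-normal noise, and `kl` here is a
# placeholder penalty, not the paper's KL divergence.
class ToySBPLayer(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(num_channels))
        self.log_sigma = nn.Parameter(torch.full((num_channels,), -3.0))

    def forward(self, x):
        sigma = self.log_sigma.exp()
        # one multiplicative noise sample per channel
        theta = 1.0 + self.mu + sigma * torch.randn_like(sigma)
        shape = [1, -1] + [1] * (x.dim() - 2)  # broadcast over batch/H/W
        kl = self.log_sigma.sum().neg()        # placeholder penalty
        return x * theta.view(*shape), kl

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
sbp_layer = ToySBPLayer(16)  # input_dim = number of output channels

x = torch.randn(2, 3, 8, 8)
y, kl = sbp_layer(conv(x))   # y keeps the conv output shape; kl is a scalar
```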


Issues

KL Loss Weights

Thanks for porting the TensorFlow code to PyTorch, it's great work.

https://github.com/necludov/group-sparsity-sbp/blob/master/nets/metrics.py#L49
The original author does not seem to set a KL weight (the per-layer KL losses are summed with a default weight of 1) — or did I miss it?

https://github.com/gaosh/Structured-Bayesian-Pruning-pytorch/blob/master/LeNet.py#L111
You set the KL loss weights to [0.3, 0.3, 0.2, 0.2]; how did you choose them?

I have tried SSL before; there the weights act like regularization parameters, which can influence the final sparsity of the model.
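To make the difference under discussion explicit, here is a small sketch contrasting the two weighting schemes. The per-layer KL values are made-up placeholders; only the weights [0.3, 0.3, 0.2, 0.2] come from the LeNet.py line quoted above.

```python
# Made-up per-layer KL values, purely illustrative
kl_per_layer = [2.0, 1.5, 4.0, 3.0]

# TensorFlow reference behaviour: plain sum (implicit weight 1 per layer)
kl_unweighted = sum(kl_per_layer)

# This repo (LeNet.py): fixed per-layer weights
weights = [0.3, 0.3, 0.2, 0.2]
kl_weighted = sum(w * kl for w, kl in zip(weights, kl_per_layer))
```

Since the KL term acts as a regularizer, scaling each layer's contribution down like this trades off sparsity pressure against accuracy per layer, which matches the observation about SSL above.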

Pruning AlexNet

Hi, thank you for your great work!

I know it's a bit of a long shot, but I was wondering if you had any insights into a strange problem I came across when pruning AlexNet.

Specifically, I'm trying to use this code to prune AlexNet. I've tried a variety of learning rates, but invariably the following happens: the training and testing accuracy increase while the SNR drops towards 1. The layerwise sparsity remains 0 across all layers while SNR > 1. Then, immediately after SNR < 1, the training accuracy plummets to around ~1% and does not recover, even though the testing accuracy remains high.

I was wondering if you had any insights into why this may be happening. I'm waiting until the layerwise sparsity exceeds 0.0 so I can see some pruning, but this comes with a huge, sudden accuracy loss. Am I using the wrong stopping criterion, learning rate, etc.? Any insights into what could be going wrong would be deeply appreciated!
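For context on the behaviour described in this issue: SBP counts a unit as pruned when the signal-to-noise ratio (mean/std) of its multiplicative noise falls below 1, so layerwise sparsity stays exactly 0 until the first unit crosses that threshold. A hedged pure-Python sketch (the SNR values are made up for illustration):

```python
# Made-up per-neuron SNRs for one layer (SNR = mean/std of the
# multiplicative noise); SBP treats a unit as pruned when SNR < 1.
snrs = [0.2, 0.5, 1.3, 2.0, 0.9, 4.1]

pruned = [s < 1.0 for s in snrs]
layer_sparsity = sum(pruned) / len(snrs)
```

This thresholding is why sparsity can jump from 0 to a nonzero value all at once: many units may sit just above SNR = 1 and cross it together, which would also explain a sudden accuracy drop at that moment.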
