nn-compression's Introduction

nn-compression

A Pytorch implementation of Neural Network Compression (pruning, quantization, encoding/decoding)

Most of the work in this repo is done better in distiller. However, distiller has not implemented channel pruning or coding yet. With the coding in this repo, you can save the model with a much smaller actual memory footprint.

Pruning

Neural Network Pruning reduces the number of nonzero parameters and thus the amount of computation (FLOPs).

Vanilla Pruning

Deep Compression uses the vanilla pruning method: it prunes the parameters with the least importance, at one of the following granularities (a minimal sketch follows the list).

  • Elementwise Pruning: prune the elements with the smallest magnitude

  • Kernelwise Pruning: prune the 2D kernels with the smallest L1 (default) / L2 norm

  • Filterwise Pruning: prune the 3D filters with the smallest L1 (default) / L2 norm
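
To make the three granularities concrete, here is a minimal sketch of how an L1-based importance score and binary mask could be computed for a convolution weight of shape (out_channels, in_channels, kH, kW). The make_mask helper below is hypothetical and only illustrates the idea; it is not this repo's implementation.

# granularity sketch (illustrative only, not this repo's API)
import torch

def make_mask(weight, sparsity, granularity='element'):
    # importance score at the chosen granularity
    if granularity == 'element':
        importance = weight.abs()                                    # per-element magnitude
    elif granularity == 'kernel':
        importance = weight.abs().sum(dim=(2, 3), keepdim=True)      # L1 norm of each 2D kernel
    elif granularity == 'filter':
        importance = weight.abs().sum(dim=(1, 2, 3), keepdim=True)   # L1 norm of each 3D filter
    else:
        raise ValueError(granularity)
    # keep everything above the k-th smallest importance
    k = max(1, int(sparsity * importance.numel()))
    threshold = importance.flatten().kthvalue(k).values
    return (importance > threshold).float().expand_as(weight)

weight = torch.randn(16, 8, 3, 3)
mask = make_mask(weight, sparsity=0.5, granularity='filter')
weight.data.mul_(mask)  # zero out the pruned filters

The VanillaPruner below takes care of this bookkeeping for you from a rule list.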

# vanilla pruner usage

from modules.prune import VanillaPruner

rule = [
        ('0.weight', 'element', [0.3, 0.5], 'abs'),
        ('1.weight', 'kernel', [0.4, 0.6], 'default'),
        ('2.weight', 'filter', [0.5, 0.7], 'l2norm')
    ]

pruner = VanillaPruner(rule=rule)
"""
:param rule: str, path to the rule file, each line formats
                  'param_name granularity sparsity_stage_0, sparsity_stage_1, ...'
             list of tuple, [(param_name(str), granularity(str),
                              sparsity(float) or [sparsity_stage_0(float), sparsity_stage_1, ...],
                              fn_importance(optional, str or function))]
             'granularity': str, choose from ['element', 'kernel', 'filter']
             'fn_importance': str, choose from ['abs', 'l1norm', 'l2norm', 'default']
"""

stage = 0

for epoch in range(0, 90):
    if epoch == 0:
        pruner.prune(model=model, stage=stage, update_masks=True)
        best_prec1 = validate(val_loader, model, criterion, epoch)
    
    # in train function
    for i, (input, target) in enumerate(train_loader):
        output = model(input)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        pruner.prune(model=model, stage=stage, update_masks=False)
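
The rule above lists two sparsity stages per parameter (e.g., [0.3, 0.5]), while the loop keeps stage = 0 throughout. If you want to move to the second, higher sparsity level partway through training, one possible approach — a sketch only, with an arbitrary epoch boundary not prescribed by this repo — is to bump the stage and recompute the masks:

# advancing the pruning stage mid-training (illustrative; epoch boundary chosen arbitrarily)
for epoch in range(0, 90):
    if epoch == 45:
        stage = 1                                                   # switch to the second sparsity level
        pruner.prune(model=model, stage=stage, update_masks=True)   # recompute masks at the new sparsity
    # ... train as above, calling pruner.prune(..., update_masks=False) after each optimizer step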

Channel Pruning

Channel Pruning is another family of neural network pruning methods. It reduces the number of output channels of convolutional and fully-connected layers, so it can directly speed up inference.

Channel Pruning takes 2 steps:

  1. Channel Selection: select the channels with the least impact and prune them
  2. Parameter Reconstruction: reconstruct the parameters of the next layer so that its output feature map best approximates the one produced before pruning

These two steps are conducted layer by layer.

# channel pruning usage

def prune_channel(sparsity, module, next_module, fn_next_input_feature, input_feature,
                  method='greedy', cpu=True):
    """
    channel pruning core function
    :param sparsity: float, pruning sparsity
    :param module: torch.nn.module, module of the layer being pruned
    :param next_module: torch.nn.module, module of the next layer to the one being pruned
    :param fn_next_input_feature: function, function to calculate the input feature map for next_module
    :param input_feature: torch.(cuda.)Tensor, input feature map of the layer being pruned
    :param method: str
        'greedy': greedily select, one at a time, the channel whose removal perturbs the next feature map the least
        'lasso': select the channels to prune by LASSO regression
        'random': randomly select channels
    :param cpu: bool, whether to run on CPU, which allows a larger reconstruction batch size
    :return:
        void
    """

A detailed example is shown here.
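
As a rough illustration of how prune_channel might be driven — the toy network, layer wiring, and calibration batch below are assumptions made for this sketch, not the repo's example — the following prunes the output channels of the first convolution and reconstructs the second:

# hypothetical channel-pruning driver (names and wiring are illustrative only)
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(16, 32, 3, padding=1), nn.ReLU())
module, next_module = model[0], model[2]      # layer being pruned and the layer after it

# input feature map of the layer being pruned, taken from a calibration batch
input_feature = torch.randn(8, 3, 32, 32)

# how the next layer's input is produced from the pruned layer's output (a ReLU sits in between here)
fn_next_input_feature = lambda x: torch.relu(x)

prune_channel(sparsity=0.5, module=module, next_module=next_module,
              fn_next_input_feature=fn_next_input_feature,
              input_feature=input_feature, method='greedy', cpu=True)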

Quantization

Neural Network Quantization represents the parameters with fewer bits.

Vanilla Quantization

There are several ways to quantize neural network parameters:

  • Fixed-point Quantization: the most common way; it uses (i+f) bits to represent a number, with i bits for the integer part and f bits for the fractional part.

  • Uniform/Linear Quantization: the quantization centroids lie uniformly in the range of parameter values, i.e., the quantization step equals $(max - min) / k$, where $k$ is the number of quantization levels (a small sketch follows this list).

  • K-Means Quantization: the quantization centroids are computed by K-Means clustering.
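
For intuition, here is a minimal sketch of uniform/linear quantization for a single tensor. It is illustrative only, not this repo's Quantizer, and uses the step-size convention that makes both the minimum and maximum exact centroids.

# uniform/linear quantization sketch (illustrative only)
import torch

def linear_quantize(w, bit_length=4):
    k = 2 ** bit_length                        # number of quantization levels
    w_min, w_max = w.min(), w.max()
    step = (w_max - w_min) / (k - 1)           # spacing so both min and max are centroids
    labels = torch.round((w - w_min) / step)   # index of the nearest centroid
    return w_min + labels * step               # snap every value to its centroid

w = torch.randn(64, 64)
w_q = linear_quantize(w, bit_length=4)         # at most 16 distinct values remain

The Quantizer below applies the chosen method per parameter according to a rule list.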

# vanilla quantizer usage

from modules.quantize import Quantizer

rule = [
        ('0.weight', 'k-means', 4, 'k-means++'),
        ('1.weight', 'fixed_point', 6, 1),
    ]

quantizer = Quantizer(rule=rule, fix_zeros=True)
"""
:param rule: str, path to the rule file, each line formats
                'param_name method bit_length initial_guess_or_bit_length_of_integer'
             list of tuple,
                [(param_name(str), method(str), bit_length(int),
                  initial_guess(str)_or_bit_length_of_integer(int))]
:param fix_zeros: whether to fix zeros when quantizing
"""

for epoch in range(0, 90):
    # in the train loop
    
    # in train function
    for i, (input, target) in enumerate(train_loader):
        output = model(input)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        quantizer.quantize(model=model, update_labels=True, re_quantize=False)
        """
        :param update_labels: bool, whether to re-allocate the param elements
                                    to the latest centroids when using k-means
        :param re_quantize: bool, whether to re-quantize the param when using k-means
        """

Coding

Coding is the last step of compressing a neural network in Deep Compression:

  • Fixed-point Coding: not actually a coding method; it is included in case you want to save the model in fixed-point form.

  • Vanilla (Linear) Coding: uses $\log_2(N)$ bits to represent each of the N floating-point values in the codebook, i.e., a parameter matrix contains only N possible values (a back-of-the-envelope example follows this list).

  • Huffman Coding: uses Huffman coding to represent the N floating-point values in the codebook.
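
As a back-of-the-envelope example (illustrative numbers, not measured from this repo): a layer quantized to N = 16 centroids needs only log2(16) = 4 bits per stored weight under vanilla coding, versus 32 bits for float32:

# rough storage estimate for vanilla (linear) coding (illustrative only)
import math

num_weights = 1_000_000                                  # weights in the layer
codebook_size = 16                                       # N distinct values after quantization
bits_vanilla = num_weights * math.log2(codebook_size)    # 4 bits per weight
bits_float32 = num_weights * 32                          # uncompressed float32 storage
print(bits_float32 / bits_vanilla)                       # 8.0x smaller, before Huffman coding

Huffman coding can reduce the average bits per weight further when the centroid distribution is skewed.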

# coding codec usage (encode)

import torch
from modules.coding import Codec

rule = [
        ('0.weight', 'huffman', 0, 0, 4),
        ('1.weight', 'fixed_point', 6, 1, 4)
    ]

codec = Codec(rule=rule)
"""
:param rule: str, path to the rule file, each line formats
                'param_name coding_method bit_length_fixed_point bit_length_fixed_point_of_integer_part
                 bit_length_of_zero_run_length'
             list of tuple,
                [(param_name(str), coding_method(str), bit_length_fixed_point(int),
                 bit_length_fixed_point_of_integer_part(int), bit_length_of_zero_run_length(int))]
"""

encoded_model = codec.encode(model=model)

torch.save({'state_dict': encoded_model.state_dict()}, 'encode.pth.tar', pickle_protocol=4)
# coding codec usage (decode)

import torch
from modules.coding import Codec

checkpoint = torch.load('encode.pth.tar')

model = Codec.decode(model=model, state_dict=checkpoint['state_dict'])  # initial model is created before

torch.save({'state_dict': model.state_dict()}, 'decode.pth.tar')
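
To check that encoding actually shrinks the checkpoint on disk, you can compare file sizes with the standard library (the paths match the example above):

# compare checkpoint sizes on disk
import os

for path in ('encode.pth.tar', 'decode.pth.tar'):
    print(path, os.path.getsize(path) / 1e6, 'MB')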

References

@article{han2015deep,
  title={Deep compression: Compressing deep neural networks with pruning, trained quantization and huffman coding},
  author={Han, Song and Mao, Huizi and Dally, William J},
  journal={arXiv preprint arXiv:1510.00149},
  year={2015}
}
@inproceedings{han2015learning,
  title={Learning both weights and connections for efficient neural network},
  author={Han, Song and Pool, Jeff and Tran, John and Dally, William},
  booktitle={Advances in neural information processing systems},
  pages={1135--1143},
  year={2015}
}
@article{luo2017thinet,
  title={Thinet: A filter level pruning method for deep neural network compression},
  author={Luo, Jian-Hao and Wu, Jianxin and Lin, Weiyao},
  journal={arXiv preprint arXiv:1707.06342},
  year={2017}
}
@inproceedings{he2017channel,
  title={Channel pruning for accelerating very deep neural networks},
  author={He, Yihui and Zhang, Xiangyu and Sun, Jian},
  booktitle={International Conference on Computer Vision (ICCV)},
  volume={2},
  number={6},
  year={2017}
}

nn-compression's Issues

I think the function channel_selection has a mistake

Hi, I think the function channel_selection in slender.prune.channel may have a mistake.
In the paper, the next_input_features and next_output_features are used when channels are selected (the paper says X is taken after the ReLU and Y before the ReLU).
This means the next_input_features should be split along the channel dimension.
But in your code the output_features are split along channels before the nonlinearity and then passed through it, which differs from the paper's description.
Is this a mistake, or is it intended?

How to reproduce your Deep Gradient Compression?

Hello, I am learning about gradient compression and was glad to read your paper. However, I am confused about how to implement distributed training in PyTorch with sparse tensors. Did you use sparse tensors in your paper's project, and which framework did you use in your experiments?

Pruning ResNets

Could you please explain how to prune ResNets using nn-compression?
For example, what goes into next_module, etc., for ResNet-18?

3D convolution support

Hi, I'm currently looking at your code. It doesn't produce a bug on 3D convolutions, but I wondered whether it would eventually work (one epoch in a 3D-convolution setting is usually pretty long!).

Have you received any feedback about this? Do you see any reason it shouldn't work, especially for quantization?
