Batch Norm Fusion for PyTorch

About

In this repository, we present a simple implementation of batchnorm fusion for the most popular CNN architectures in PyTorch. The package aims to speed up inference at test time, with an expected speedup of 30%. In the future

How it works

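Once the batchnorm statistics are frozen for inference, both convolution and batchnorm are linear (strictly speaking, affine) operations on the data point x, so their composition can be written in terms of matrix multiplications: T_{bn} * S_{bn} * Conv_W(x). We first apply the convolution to the data, then scale the result, and finally shift it using the parameters and running statistics learned by batchnorm. Since a product of such matrices is again a single affine map, the scale and shift can be folded into the convolution's weights and bias, and the batchnorm layer can be removed.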
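A minimal NumPy sketch of this algebra, assuming the convolution has already been unrolled into a dense affine map y = W x + b (any convolution can be written this way, although real implementations never build W explicitly). With scale = gamma / sqrt(var + eps), S_{bn} is diag(scale) and T_{bn} is a shift by beta - scale * mean, so the product collapses to W' = diag(scale) W and b' = scale * (b - mean) + beta. The sizes and random values are arbitrary placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, eps = 6, 4, 1e-5  # toy sizes; the conv is viewed as a dense affine map

# Conv_W: y = W @ x + b
W = rng.standard_normal((n_out, n_in))
b = rng.standard_normal(n_out)

# Batchnorm running statistics and learned affine parameters, one per output channel
mean = rng.standard_normal(n_out)
var = rng.uniform(0.5, 2.0, n_out)
gamma = rng.standard_normal(n_out)
beta = rng.standard_normal(n_out)


def homogeneous(A, t):
    """Embed the affine map x -> A @ x + t as a matrix acting on [x, 1]."""
    m, n = A.shape
    H = np.zeros((m + 1, n + 1))
    H[:m, :n] = A
    H[:m, n] = t
    H[m, n] = 1.0
    return H


scale = gamma / np.sqrt(var + eps)
conv_h = homogeneous(W, b)                              # Conv_W
S_bn = homogeneous(np.diag(scale), np.zeros(n_out))     # per-channel scale
T_bn = homogeneous(np.eye(n_out), beta - scale * mean)  # per-channel shift

# The three matrices multiply into one affine map: the fused convolution.
fused_h = T_bn @ S_bn @ conv_h
W_fused, b_fused = fused_h[:n_out, :n_in], fused_h[:n_out, n_in]

# Closed form of the same result.
W_closed = scale[:, None] * W
b_closed = scale * (b - mean) + beta

x = rng.standard_normal(n_in)
bn_of_conv = gamma * ((W @ x + b) - mean) / np.sqrt(var + eps) + beta
print(np.allclose(W_fused @ x + b_fused, bn_of_conv))  # True
```

The homogeneous-coordinate form mirrors the T_{bn} * S_{bn} * Conv_W notation literally; in practice only the closed form (W_closed, b_closed) is computed.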

Supported architectures

We support any architecture in which Conv and BN layers are combined in a Sequential module. To optimize your own networks with this tool, just follow this design. For convenience, we wrapped the VGG, ResNet and SENet families to demonstrate how your models can be converted into this format.

  • VGG from torchvision.
  • ResNet family from torchvision.
  • SENet family from pretrainedmodels.
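To make the Sequential requirement concrete, here is a sketch of a recursive pass in the same spirit. It is not this repository's fuse_bn_recursively: the helper name fuse_sequential is hypothetical, it only folds a Conv2d that is immediately followed by a BatchNorm2d inside an nn.Sequential, and it delegates the arithmetic to PyTorch's own torch.nn.utils.fusion.fuse_conv_bn_eval, which requires both layers to be in eval mode.

```python
import copy

import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval


def fuse_sequential(module: nn.Module) -> nn.Module:
    """Hypothetical helper: fold Conv2d -> BatchNorm2d pairs found in nn.Sequential containers."""
    if isinstance(module, nn.Sequential):
        layers, fused, i = list(module), [], 0
        while i < len(layers):
            cur = layers[i]
            nxt = layers[i + 1] if i + 1 < len(layers) else None
            if isinstance(cur, nn.Conv2d) and isinstance(nxt, nn.BatchNorm2d):
                fused.append(fuse_conv_bn_eval(cur, nxt))  # one conv replaces the pair
                i += 2
            else:
                fused.append(fuse_sequential(cur))  # recurse into nested containers
                i += 1
        return nn.Sequential(*fused)  # note: child indices are renumbered
    # Any other module: recurse into its children and swap in the fused versions.
    for name, child in module.named_children():
        setattr(module, name, fuse_sequential(child))
    return module


torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8), nn.ReLU()),
)
# Give the BN layers non-trivial statistics, as a trained model would have.
with torch.no_grad():
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.running_mean.uniform_(-1.0, 1.0)
            m.running_var.uniform_(0.5, 2.0)
            m.weight.uniform_(0.5, 1.5)
            m.bias.uniform_(-0.5, 0.5)
net.eval()  # fuse_conv_bn_eval asserts that neither layer is in training mode

fused_net = fuse_sequential(copy.deepcopy(net)).eval()
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    print(torch.allclose(net(x), fused_net(x), atol=1e-4))
```

A conv/BN pair stored as attributes of a custom module (rather than as neighbours in a Sequential) is left untouched by this pass, which is why the README asks models to follow the Sequential design.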

How to use

import torchvision.models as models
from bn_fusion import fuse_bn_recursively

net = getattr(models,'vgg16_bn')(pretrained=True)
net = fuse_bn_recursively(net)
net.eval()
# Make inference with the converted model

TODO

  • Tests.
  • Performance benchmarks.
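Both TODO items can be prototyped in a few lines of PyTorch. The sketch below is a stand-in under stated assumptions, not part of this package: it builds a small hypothetical Conv-BN-ReLU network, folds each pair with torch.nn.utils.fusion.fuse_conv_bn_eval, checks that the outputs agree, and times both versions with time.perf_counter. It makes no claim about the size of the speedup, which depends on hardware, batch size and backend.

```python
import copy
import time

import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval


def conv_bn_relu(c_in, c_out):
    return nn.Sequential(nn.Conv2d(c_in, c_out, 3, padding=1, bias=False),
                         nn.BatchNorm2d(c_out), nn.ReLU())


def make_net():
    net = nn.Sequential(conv_bn_relu(3, 32), conv_bn_relu(32, 64), conv_bn_relu(64, 64))
    with torch.no_grad():  # mimic trained BN statistics
        for m in net.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.running_mean.uniform_(-1.0, 1.0)
                m.running_var.uniform_(0.5, 2.0)
    return net.eval()


def fuse_blocks(net):
    """Replace every (conv, bn, relu) block of make_net() with (fused conv, relu)."""
    fused = copy.deepcopy(net)
    for i in range(len(fused)):
        block = fused[i]
        fused[i] = nn.Sequential(fuse_conv_bn_eval(block[0], block[1]), block[2])
    return fused.eval()


@torch.no_grad()
def mean_latency(model, x, warmup=3, iters=20):
    """Average wall-clock seconds per forward pass after a few warm-up runs."""
    for _ in range(warmup):
        model(x)
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    return (time.perf_counter() - start) / iters


torch.manual_seed(0)
net = make_net()
fused = fuse_blocks(net)
x = torch.randn(8, 3, 64, 64)

with torch.no_grad():  # test: fusion must not change the outputs
    assert torch.allclose(net(x), fused(x), atol=1e-4)

t_ref, t_fused = mean_latency(net, x), mean_latency(fused, x)
print(f"original {t_ref * 1e3:.2f} ms, fused {t_fused * 1e3:.2f} ms, "
      f"speedup {t_ref / t_fused:.2f}x")
```

For a real benchmark, the same two helpers can be pointed at a converted torchvision model instead of the toy network.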

Acknowledgements

Thanks to @ZFTurbo for the idea, discussions and his implementation for Keras.

Contributors

lext

Issues

No effect after “fuse_bn_recursively”

I loaded MobileNetV2, ran it through the fuse_bn_recursively function, and printed the network structures of the two models, but the fused net is the same as the initial one. Is this caused by a mistake on my part?
```python
import copy

from bn_fusion import fuse_bn_recursively
from pytorchcv.model_provider import get_model as ptcv_get_model

if __name__ == '__main__':
    net = ptcv_get_model('mobilenetv2_w1', pretrained=True)

    # Fuse a deep copy: if fuse_bn_recursively modifies its argument in place,
    # comparing `net` with the result would compare the model with itself.
    net1 = fuse_bn_recursively(copy.deepcopy(net))
    net1.eval()

    params1 = dict(net.named_parameters())
    params2 = dict(net1.named_parameters())

    # Count parameters present in only one model (e.g. removed BN weights)
    # plus shared names whose shapes differ.
    diff_cnt = len(params1.keys() ^ params2.keys())
    diff_cnt += sum(params1[name].shape != params2[name].shape
                    for name in params1.keys() & params2.keys())
    print("diff params:", diff_cnt)
```
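A quick way to tell whether anything was fused is to count BatchNorm layers before and after fusion, and to check which module types own them. The README states that Conv and BN must sit inside a Sequential module; if the owners turn out to be custom block classes, that precondition is not met. In the sketch below, count_bn and bn_owner_types are hypothetical helper names, and ConvBlock only imitates a block that stores its layers as attributes; it is not pytorchcv's actual class.

```python
from collections import Counter

import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def count_bn(model: nn.Module) -> int:
    """Number of BatchNorm layers left in the model."""
    return sum(isinstance(m, BN_TYPES) for m in model.modules())


def bn_owner_types(model: nn.Module) -> Counter:
    """Which module types directly contain a BatchNorm layer."""
    return Counter(type(parent).__name__
                   for parent in model.modules()
                   for child in parent.children()
                   if isinstance(child, BN_TYPES))


class ConvBlock(nn.Module):
    """Hypothetical custom block: conv and bn are attributes, not a Sequential."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.activ = nn.ReLU()

    def forward(self, x):
        return self.activ(self.bn(self.conv(x)))


model = nn.Sequential(
    ConvBlock(3, 16),
    nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU()),
)
print(count_bn(model))              # 2
print(dict(bn_owner_types(model)))  # {'ConvBlock': 1, 'Sequential': 1}
```

Calling count_bn on net and net1 from the script above shows directly whether any BN layer was removed, independent of parameter shapes.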

Is this inference-only, or can it be used for training as well?

First of all, thanks for sharing this code :)

Is this code expected to be used purely in inference or can it be used for training as well?

BTW, my main motivation for using batchnorm fusion isn't run speed; it's reducing memory requirements.
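The arithmetic suggests an answer to the training question: folding uses the BN running statistics, which is what BN applies in eval mode, whereas in training mode BN normalizes with the statistics of the current batch, so the fused layer no longer matches (and the running statistics would no longer be updated). Fusion is therefore an inference-time transformation. On memory, the fused layer no longer stores the separate BN weight, bias and running buffers. The sketch below checks both points on one arbitrary layer pair, using PyTorch's torch.nn.utils.fusion.fuse_conv_bn_eval; the layer sizes and the n_values helper are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)
conv = nn.Conv2d(16, 32, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(32)
with torch.no_grad():  # stand-in for trained running statistics
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
pair = nn.Sequential(conv, bn).eval()
fused = fuse_conv_bn_eval(conv, bn)  # both layers are in eval mode here

x = torch.randn(4, 16, 8, 8)
with torch.no_grad():
    eval_match = torch.allclose(pair(x), fused(x), atol=1e-4)
    pair.train()  # BN now normalizes with this batch's mean/var
    train_match = torch.allclose(pair(x), fused(x), atol=1e-4)


def n_values(m):
    """Parameters plus buffers (BN running stats) stored by a module."""
    return sum(t.numel() for t in list(m.parameters()) + list(m.buffers()))


print(eval_match, train_match)          # True False
print(n_values(pair), n_values(fused))  # 4737 4640
```

The stored-value saving is small per layer; whether it matters for a given memory budget depends on the model.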

ResNet3D

In my experiment with ResNet3D-50, "convert_resnet_family" speeds up inference, but "fuse_bn_recursively" does not.
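The report does not say what convert_resnet_family does differently. One thing worth checking in a 3D network is whether the Conv3d/BatchNorm3d pairs are actually fused. The folding arithmetic itself does not depend on the number of spatial dimensions, as the hypothetical helper fold_bn below illustrates (our name, not part of this package): the per-channel scale is simply reshaped to broadcast over a weight tensor of any rank. It assumes an affine BN layer in eval mode and does not handle transposed convolutions, whose weights store the channel axes in a different order.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def fold_bn(conv: nn.Module, bn: nn.Module) -> nn.Module:
    """Hypothetical helper: fold an eval-mode, affine BatchNorm{1,2,3}d into the preceding Conv{1,2,3}d."""
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # gamma / sqrt(var + eps)
    shape = [-1] + [1] * (conv.weight.dim() - 1)             # broadcast over output channels
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused = copy.deepcopy(conv)
    fused.weight = nn.Parameter(conv.weight * scale.reshape(shape))
    fused.bias = nn.Parameter((bias - bn.running_mean) * scale + bn.bias)
    return fused


torch.manual_seed(0)
conv = nn.Conv3d(4, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm3d(8)
with torch.no_grad():  # stand-in for trained BN parameters and statistics
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
pair = nn.Sequential(conv, bn).eval()
fused = fold_bn(conv, bn)

x = torch.randn(2, 4, 6, 8, 8)  # (batch, channels, depth, height, width)
with torch.no_grad():
    print(torch.allclose(pair(x), fused(x), atol=1e-4))
```

If the fused 3D model still shows no speedup, the BN-free graph may simply not be the bottleneck on that hardware; only a measurement can tell.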
