Integer-Only Inference for Deep Learning in Native C

A repository containing native C implementations of convolutional neural network (ConvNet) and multi-layer perceptron (MLP) models for integer-only inference. Model parameters are quantized to 8-bit integers, and floating-point arithmetic is replaced with a fixed-point representation.
The repository contains:

  • scripts for training models with PyTorch,
  • post-training quantization of model parameters to 8-bit integers,
  • generation of C source and header files containing the quantized parameters,
  • interfacing the C code for integer-only inference via ctypes.

The ideas presented in this tutorial were used to quantize a deep reinforcement learning model and to write inference-only C code for deploying it on a network interface card (NIC) in Tessler et al. 2021 [1].
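
To make the idea concrete, the snippet below is a minimal sketch (not code from this repository; all names and constants are illustrative) of uniform symmetric quantization and fixed-point requantization. Offline, a float value is mapped to int8 with a scale factor; at inference time, the combined scale of a layer is represented as an integer multiplier plus a right shift, so an accumulator can be rescaled to int8 without any floating-point operations.

#include <stdint.h>

/* Offline step (done in Python in this repository): uniform symmetric
 * quantization of a float value to int8 with a per-tensor scale. */
int8_t quantize(float x, float scale)
{
    float r = x / scale;
    int32_t q = (int32_t)(r >= 0.0f ? r + 0.5f : r - 0.5f);  /* round to nearest */
    if (q > 127) q = 127;
    if (q < -127) q = -127;
    return (int8_t)q;
}

/* Inference step: the float scale ratio is pre-converted to a fixed-point
 * multiplier and shift, so rescaling an int32 accumulator back to int8
 * needs only integer arithmetic (assumes an arithmetic right shift). */
int8_t requantize(int32_t acc, int32_t multiplier, int shift)
{
    int64_t v = ((int64_t)acc * multiplier) >> shift;
    if (v > 127) v = 127;
    if (v < -127) v = -127;
    return (int8_t)v;
}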

Requirements

Quantization is based on Nvidia's pytorch-quantization, which is part of TensorRT.
https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization
pytorch-quantization supports more sophisticated quantization methods than the one presented here. For more details, see Wu et al. 2020 [2].

NOTE: pytorch-quantization requires a GPU and will not work without one.

C-code

The C code has separate files for the MLP and ConvNet models.
It is located in the src directory, where:

  • nn_math - source and header files containing the relevant mathematical functions
  • nn - source and header files containing the layers used to build the neural network models (a sketch of such a layer follows this list)
  • mlp - source and header files containing the MLP architecture used for inference
  • convnet - source and header files containing the ConvNet architecture used for inference
  • mlp_params - source and header files generated via scripts/create_mlp_c_params.py, containing the network weights, scale factors, and other constants for the MLP model
  • convnet_params - source and header files generated via scripts/create_convnet_c_params.py, containing the network weights, scale factors, and other constants for the ConvNet model
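
As an illustration of the style of code in nn and nn_math, here is a minimal sketch of an integer-only fully connected layer. The function and parameter names are illustrative and do not necessarily match the repository; the real layers also handle biases, activations, and other per-layer details.

#include <stdint.h>

/* Illustrative integer-only fully connected layer: int8 inputs and weights,
 * int32 accumulation, then requantization back to int8 with a precomputed
 * fixed-point multiplier and shift (names are not taken from the repository). */
void fc_int8(const int8_t *x, const int8_t *w, int8_t *y,
             int in_dim, int out_dim,
             int32_t out_multiplier, int out_shift)
{
    for (int o = 0; o < out_dim; ++o) {
        int32_t acc = 0;
        for (int i = 0; i < in_dim; ++i)
            acc += (int32_t)x[i] * (int32_t)w[o * in_dim + i];

        /* Rescale the accumulator to the int8 output range. */
        int64_t v = ((int64_t)acc * out_multiplier) >> out_shift;
        if (v > 127) v = 127;
        if (v < -127) v = -127;
        y[o] = (int8_t)v;
    }
}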

Compilation

The repository was tested using gcc.
To compile and generate a shared library that can be called from Python via ctypes, run the following commands:

MLP

gcc -Wall -fPIC -c mlp_params.c mlp.c nn_math.c nn.c  
gcc -shared mlp_params.o mlp.o nn_math.o nn.o -o mlp.so

ConvNet

gcc -Wall -fPIC -c convnet_params.c convnet.c nn_math.c nn.c  
gcc -shared convnet_params.o convnet.o nn_math.o nn.o -o convnet.so
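
The shared libraries are meant to be loaded from Python (see the test scripts below), but the same object files can also be linked into a small standalone C driver for a quick sanity check. The entry-point name and signature below are hypothetical; check mlp.h for the actual interface.

#include <stdint.h>
#include <stdio.h>

/* Hypothetical entry point; the real declaration lives in mlp.h. */
void mlp_inference(const int8_t *input, int8_t *output);

int main(void)
{
    int8_t input[28 * 28] = {0};  /* one flattened, quantized MNIST image */
    int8_t output[10]     = {0};  /* one score per digit class */

    mlp_inference(input, output);

    for (int i = 0; i < 10; ++i)
        printf("class %d: %d\n", i, output[i]);
    return 0;
}

Such a driver could be compiled against the same object files, e.g. gcc -Wall test_driver.c mlp_params.o mlp.o nn_math.o nn.o -o test_driver.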

Scripts

  • src/train_mlp.py and src/train_convnet.py train an MLP/ConvNet model using PyTorch
  • src/quantize_with_package.py quantizes the models using the pytorch-quantization package
  • src/create_mlp_c_params.py and src/create_convnet_c_params.py create the C header and source files with the constants (network parameters, scale factors, and more) required to run the C code (an illustrative excerpt of such a file follows this list)
  • src/test_mlp_c.py and src/test_convnet_c.py run inference on the models, using ctypes to interface the C code from Python
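
For a sense of what the create_*_c_params.py scripts emit, here is a heavily abridged, illustrative excerpt of a generated parameter file; the actual layer names, array shapes, and values are produced by the scripts and will differ.

#include <stdint.h>

/* Illustrative excerpt of a generated parameter file (values are placeholders). */
#define FC1_IN_DIM  784
#define FC1_OUT_DIM 128

/* Quantized int8 weights plus the fixed-point constants used to requantize
 * the layer output back to int8 at inference time. */
static const int8_t  FC1_W[FC1_OUT_DIM * FC1_IN_DIM] = { 12, -37, 5, /* ... */ };
static const int32_t FC1_OUT_MULTIPLIER = 1073741824;  /* placeholder */
static const int32_t FC1_OUT_SHIFT      = 38;          /* placeholder */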

Results on the MNIST dataset

MLP

Training 
Epoch: 1 - train loss: 0.35650 validation loss: 0.20097
Epoch: 2 - train loss: 0.14854 validation loss: 0.13693
Epoch: 3 - train loss: 0.10302 validation loss: 0.11963
Epoch: 4 - train loss: 0.07892 validation loss: 0.11841
Epoch: 5 - train loss: 0.06072 validation loss: 0.09850
Epoch: 6 - train loss: 0.04874 validation loss: 0.09466
Epoch: 7 - train loss: 0.04126 validation loss: 0.09458
Epoch: 8 - train loss: 0.03457 validation loss: 0.10938
Epoch: 9 - train loss: 0.02713 validation loss: 0.09077
Epoch: 10 - train loss: 0.02135 validation loss: 0.09448
Evaluating model on test data
Accuracy: 97.450%
Evaluating integer-only C model on test data
Accuracy: 97.27%

ConvNet

Training
Epoch: 1 - train loss: 0.37127 validation loss: 0.12948
Epoch: 2 - train loss: 0.09653 validation loss: 0.08608
Epoch: 3 - train loss: 0.07089 validation loss: 0.07480
Epoch: 4 - train loss: 0.05846 validation loss: 0.06347
Epoch: 5 - train loss: 0.05044 validation loss: 0.05909
Epoch: 6 - train loss: 0.04567 validation loss: 0.05466
Epoch: 7 - train loss: 0.04071 validation loss: 0.05099
Epoch: 8 - train loss: 0.03668 validation loss: 0.05336
Epoch: 9 - train loss: 0.03543 validation loss: 0.04965
Epoch: 10 - train loss: 0.03164 validation loss: 0.04883
Evaluate model on test data
Accuracy: 98.620%
Evaluating integer-only C model on test data
Accuracy: 98.58%

References

[1] Tessler, C., Shpigelman, Y., Dalal, G., Mandelbaum, A., Kazakov, D. H., Fuhrer, B., Chechik, G., & Mannor, S. (2021). Reinforcement Learning for Datacenter Congestion Control. http://arxiv.org/abs/2102.09337
[2] Wu, H., Judd, P., Zhang, X., Isaev, M., & Micikevicius, P. (2020). Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation. http://arxiv.org/abs/2004.09602
