
Trained Ternary Quantization

PyTorch implementation of Trained Ternary Quantization, a method for replacing the full-precision weights of a neural network with ternary values. I tested it on the Tiny ImageNet dataset, which consists of 64x64 images from 200 classes.

The quantization proceeds roughly as follows (a minimal sketch of steps 2 and 3 follows the list).

  1. Train a model of your choice as usual (or take an already trained model).
  2. Copy all full-precision weights that you want to quantize. Then do the initial quantization:
    in the model, replace them with ternary values {-1, 0, +1} using some heuristic.
  3. Repeat until convergence:
    • Make the forward pass with the quantized model.
    • Compute gradients for the quantized model.
    • Preprocess the gradients and apply them to the copy of the full-precision weights.
    • Requantize the model using the updated full-precision weights.
  4. Throw away the copy of the full-precision weights and use the quantized model.
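
For concreteness, below is a minimal sketch of the per-layer quantization and gradient-preprocessing steps (2 and 3 above), following the TTQ paper. The function names, the threshold heuristic delta = t * max|w| with t = 0.05, and treating the scaling factors w_p and w_n as separately learned scalars are illustrative assumptions, not necessarily the exact code in this repository.

```python
import torch

def quantize(fp_weight, w_p, w_n, t=0.05):
    """Ternarize a full-precision weight tensor.

    t is the threshold hyperparameter (assumed value): entries within
    t * max|w| of zero become 0; larger entries become +w_p,
    more negative ones become -w_n.
    """
    delta = t * fp_weight.abs().max()
    pos = (fp_weight > delta).float()
    neg = (fp_weight < -delta).float()
    return w_p * pos - w_n * neg


def get_grads(grad_quantized, fp_weight, w_p, w_n, t=0.05):
    """Preprocess the gradient w.r.t. the quantized weights into gradients
    for the full-precision copy and for the two scaling factors."""
    delta = t * fp_weight.abs().max()
    pos = (fp_weight > delta).float()
    neg = (fp_weight < -delta).float()
    middle = 1.0 - pos - neg
    # the full-precision copy receives the gradient, rescaled per region
    fp_grad = w_p * pos * grad_quantized \
        + middle * grad_quantized \
        + w_n * neg * grad_quantized
    # each scaling factor accumulates the gradient over its own region
    w_p_grad = (pos * grad_quantized).sum()
    w_n_grad = -(neg * grad_quantized).sum()
    return fp_grad, w_p_grad, w_n_grad
```

In a training loop you would call quantize() before each forward pass, backpropagate through the quantized model, then use get_grads() to update the full-precision copy and the two scales with your optimizer.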

Results

I believe these results can be improved by spending more time on hyperparameter optimization.

| model | accuracy, % | top5 accuracy, % | number of parameters |
| --- | --- | --- | --- |
| DenseNet-121 | 74 | 91 | 7151176 |
| TTQ DenseNet-121 | 55 | 79 | ~7M 2-bit, 89% are zeros |
| small DenseNet | 49 | 75 | 440264 |
| TTQ small DenseNet | 37 | 65 | ~0.4M 2-bit, 65% are zeros |
| SqueezeNet | 52 | 77 | 827784 |
| TTQ SqueezeNet | 36 | 63 | ~0.8M 2-bit, 66% are zeros |

Implementation details

  • I use a pretrained DenseNet-121, but I train SqueezeNet and the small DenseNet from scratch.
  • I modify the SqueezeNet architecture by adding batch normalization and skip connections.
  • I quantize all layers except the first CONV layer, the last FC layer, and all BATCH_NORM layers (one way to select these layers is sketched below).
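
A minimal sketch of splitting parameters into those to ternarize and those kept in full precision. The helper name and the ordering assumption (first multi-dimensional weight is the first conv, last is the final fc) are illustrative, not the repository's exact code.

```python
def split_params(model):
    """Separate weights to ternarize from weights kept in full precision."""
    # conv / fc weight matrices have more than one dimension;
    # biases and batch-norm parameters are 1-d and stay full precision
    weights = [(n, p) for n, p in model.named_parameters() if p.dim() > 1]

    to_quantize, keep_full_precision = [], []
    for i, (name, param) in enumerate(weights):
        # keep the first conv layer and the last fc layer in full precision
        if i == 0 or i == len(weights) - 1:
            keep_full_precision.append(param)
        else:
            to_quantize.append(param)

    keep_full_precision += [p for n, p in model.named_parameters() if p.dim() <= 1]
    return to_quantize, keep_full_precision
```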

How to reproduce results

For example, for small DenseNet:

  1. Download the Tiny ImageNet dataset and extract it to the ~/data folder.
  2. Run python utils/move_tiny_imagenet_data.py to prepare the data.
  3. Go to vanilla_densenet_small/. Run train.ipynb to train the model as usual.
    Or you can skip this step and use model.pytorch_state (a model I already trained).
  4. Go to ttq_densenet_small/.
  5. Run train.ipynb to do TTQ.
  6. Run test_and_explore.ipynb to explore the quantized model.

To use this on your own data, edit utils/input_pipeline.py and change the model architecture in files like densenet.py and get_densenet.py as needed. A minimal sketch of a replacement input pipeline is shown below.
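
A minimal sketch, assuming your images are stored in an ImageFolder-style layout (one subfolder per class). The paths, batch size, and transforms are illustrative assumptions; this is not the repository's actual utils/input_pipeline.py.

```python
import os
import torch
from torchvision import datasets, transforms

def get_loaders(train_dir='~/data/my_dataset/train',
                val_dir='~/data/my_dataset/val',
                batch_size=64):
    # simple augmentation for training, none for validation
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    val_transform = transforms.ToTensor()

    train_set = datasets.ImageFolder(os.path.expanduser(train_dir), train_transform)
    val_set = datasets.ImageFolder(os.path.expanduser(val_dir), val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=batch_size, shuffle=False, num_workers=4)
    return train_loader, val_loader
```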

Requirements

  • PyTorch 0.2, Pillow, torchvision
  • numpy, sklearn, tqdm, matplotlib
