Ternary quantization

A PyTorch implementation of Trained Ternary Quantization (https://arxiv.org/abs/1612.01064) for training models with ternary quantized weights.
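
As a rough illustration of what ternary weights look like, here is a minimal sketch in plain Python. It assumes the thresholding rule described in the linked paper: weights above a threshold become +w_p, weights below the negative threshold become -w_n, and everything else becomes 0. The function name `ternarize` and the threshold factor `t` are hypothetical and are not taken from this repo's code.

```python
def ternarize(weights, w_p, w_n, t=0.05):
    """Map each full-precision weight to one of {+w_p, 0, -w_n}.

    Assumption: the threshold is a fraction t of the largest absolute
    weight in the layer. This repo's own rule may differ.
    """
    delta = t * max(abs(w) for w in weights)
    quantized = []
    for w in weights:
        if w > delta:
            quantized.append(w_p)    # large positive weight
        elif w < -delta:
            quantized.append(-w_n)   # large negative weight
        else:
            quantized.append(0.0)    # small weight is zeroed out
    return quantized


layer = [0.8, -0.02, 0.3, -0.6, 0.01]
print(ternarize(layer, w_p=1.0, w_n=0.5))
# prints [1.0, 0.0, 1.0, -0.5, 0.0]
```

In the real model the same mapping would be applied per layer to weight tensors, with w_p and w_n as scaling factors (trainable in the logs/quantized_wp_wn_trainable.txt run).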

Work in progress

  • Train MNIST model in original format (float32)
  • Train MNIST model with quantized weights
  • Add training logs
  • Analyze quantized weights
  • Quantize weights keeping w_p and w_n fixed

Repo Guide

  • A simple model (model_full), defined in model.py, was trained on MNIST with full-precision weights. The trained weights are stored in weights/original.ckpt.
    • The training code is in main_original.py.
  • A copy of that model, loaded with the trained weights (model_to_quantify), was then trained with quantization. The resulting weights are stored in weights/quantized.ckpt.
    • The training code is in main_ternary.py, and the logs are in logs/quantized_wp_wn_trainable.txt.
  • I also tried updating the weights by an equal amount in the direction of their gradients. In other words, I took the sign of every parameter's gradient and scaled it to a small fixed value (0.001), like so: param.grad.data = torch.sign(param.grad.data) * 0.001
    • I got decent results but didn't dig deeper into it. The weights for this model are stored in weights/autoquantize.ckpt.
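
The sign-based update in the last item can be sketched without PyTorch as follows. This is only an illustration of the idea: `sign_gradients` and `sgd_step` are hypothetical helpers, and the plain gradient-descent step with a learning rate of 1.0 is an assumption, since the optimizer used with this experiment is not stated here.

```python
def sign_gradients(grads, magnitude=0.001):
    """Replace each gradient with its sign scaled by a fixed magnitude.

    Mirrors torch.sign(grad) * 0.001: a zero gradient stays zero.
    """
    return [magnitude * ((g > 0) - (g < 0)) for g in grads]


def sgd_step(params, grads, lr=1.0):
    """Plain gradient descent step (assumed optimizer, lr is hypothetical)."""
    return [p - lr * g for p, g in zip(params, grads)]


params = [0.5, -0.2, 0.1]
grads = [3.7, -0.004, 0.0]

# Every parameter with a non-zero gradient moves by the same amount,
# no matter how large or small its gradient was.
new_params = sgd_step(params, sign_gradients(grads))
print(new_params)
```

With these inputs the first parameter decreases by 0.001, the second increases by 0.001, and the third is unchanged because its gradient is zero.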

Notes:

  • The full-precision model reaches an accuracy of 98.8%.
  • The quantized model reaches an accuracy of up to 98.52%.
    • I slightly changed the way the gradients are calculated. Using mean instead of sum in lines 15 and 16 of quantification.py gave better results:
      w_p_grad = (a * grad_data).mean()  # not (a * grad_data).sum()
      w_n_grad = (b * grad_data).mean()  # not (b * grad_data).sum()
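
To make the difference between the two reductions concrete, here is a small plain-Python sketch. It assumes that `a` and `b` are 0/1 masks selecting the weights quantized to the positive and negative levels; the definitions in quantification.py are not shown here, so treat that as a guess. The helper `scale_grads` is hypothetical.

```python
def scale_grads(a, b, grad_data, reduce="mean"):
    """Gradients for w_p and w_n from masked per-weight gradients.

    Assumption: a and b are 0/1 masks for the positive and negative
    ternary levels. With "mean" the result is the "sum" result divided
    by the total number of weights in the layer.
    """
    pos = [ai * g for ai, g in zip(a, grad_data)]
    neg = [bi * g for bi, g in zip(b, grad_data)]
    if reduce == "mean":
        return sum(pos) / len(pos), sum(neg) / len(neg)
    return sum(pos), sum(neg)


a = [1, 0, 1, 0]           # weights mapped to +w_p
b = [0, 1, 0, 0]           # weights mapped to -w_n
grad_data = [0.2, -0.4, 0.6, 0.1]

print(scale_grads(a, b, grad_data, reduce="sum"))
print(scale_grads(a, b, grad_data, reduce="mean"))
```

Since the mean is just the sum divided by the number of elements, switching to mean shrinks the w_p and w_n gradients by the layer size. One possible reading of the better results is that this acts like a much smaller learning rate for the two scaling factors, although that was not verified here.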
