
Knowledge Distillation

Knowledge Distillation describes a general framework for transferring the knowledge of a large, complex model into a smaller, simpler one. The main purpose of doing this is to allow for faster inference (e.g. on a smartphone) without significantly compromising performance (accuracy, squared error, etc.).

Code based on the paper:

Distilling the Knowledge in a Neural Network
Geoffrey Hinton, Oriol Vinyals, Jeff Dean, NIPS 2014 Deep Learning Workshop
[pdf]

Soft labels are regularizers

One of the main problems in machine learning is achieving good generalization on unseen data. One requirement for this is a model with low variance, which avoids overfitting and produces smooth decision functions (classification) or predictions (regression). That's why it is common to use complex ensemble models in which multiple learners are combined to produce the final prediction (boosting, bagging), or to apply regularization techniques such as dropout. But these models tend to need extra memory and computation time, not only during training but also at evaluation time.

Neural network. Test accuracy 90%

This is a very simple example of a learned decision boundary for a binary classification problem with high noise and overlap between the classes. The boundary is learned by a neural network with 2 hidden layers of 128 units each. The network has enough parameters to capture the underlying decision function, but not so much freedom that it overfits the high noise in the data.

One could argue that once we have this learned function, produced by a "complex" model, we can use it to train another, simpler model to reproduce it. We basically want our simple model to "overfit" this function, since we know it generalizes well.
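As a minimal sketch of this idea (assuming a noisy two-class dataset generated with scikit-learn's make_moons; the repository's notebooks may generate their data differently), the teacher network and its soft labels could look like:

```python
# Minimal sketch of the "teacher" for the 2D toy problem.
# The dataset (make_moons with heavy noise) is an assumption for illustration.
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

X, y = make_moons(n_samples=2000, noise=0.35, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Two hidden layers of 128 units each, as described above.
teacher = MLPClassifier(hidden_layer_sizes=(128, 128), max_iter=2000, random_state=0)
teacher.fit(X_train, y_train)
print("teacher test accuracy:", teacher.score(X_test, y_test))

# Soft labels: the teacher's predicted probability of class 1 for each training point.
soft_labels = teacher.predict_proba(X_train)[:, 1]
```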

We picked a decision tree as our simple model.

  • On the left, we trained a classification tree using the original hard labels (0/1) and let it grow indefinitely and overfit.
  • On the right, we trained a regression tree (since classification trees cannot be trained with soft targets), using as targets the class probabilities produced by the neural network. Again, no regularization (pruning) was applied.
Decision Tree trained with hard labels. Test accuracy 86.5%
Decision Tree trained with soft labels. Test accuracy 90%

The first decision tree, of course, doesn't perform well: it is affected by the high noise in the dataset, and its decision boundary exhibits artifacts that lead to low test accuracy.

But we see how the second tree, without any regularization, converged to the same decision function as the regularized neural network and achieved the same test performance.
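Continuing the sketch above, the two student trees could be fit as follows (the unpruned trees and the 0.5 threshold on the regressed probability are assumptions for illustration):

```python
# Sketch of the two student trees, continuing from the teacher sketch above.
# A regression tree is used for the soft targets because a classification tree
# cannot be fit on class probabilities directly.
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Left: unpruned classification tree on the original hard labels (free to overfit).
tree_hard = DecisionTreeClassifier(random_state=0)
tree_hard.fit(X_train, y_train)
print("hard-label tree accuracy:", tree_hard.score(X_test, y_test))

# Right: unpruned regression tree on the teacher's class-1 probabilities.
tree_soft = DecisionTreeRegressor(random_state=0)
tree_soft.fit(X_train, soft_labels)

# Threshold the regressed probability at 0.5 to recover class predictions.
y_pred_soft = (tree_soft.predict(X_test) > 0.5).astype(int)
print("soft-label tree accuracy:", (y_pred_soft == y_test).mean())
```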

Neural Network Knowledge Distillation on MNIST data

We apply the above ideas to the MNIST dataset. The "teacher" model that produces the soft labels is a neural network with 32 convolutional filters of size 3x3 and 2 wide, fully-connected hidden layers of 1200 units each, with each hidden layer followed by a dropout layer. The "student" model is a much smaller network with 2 hidden layers of only 10 units each and no convolution or dropout layers.

A ridiculous visualization of the knowledge distillation method
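A minimal sketch of the two architectures using Keras is shown below; the layer ordering, dropout rate, and activations are assumptions, and the repository's notebooks may differ in detail.

```python
# Sketch of the teacher and student architectures described above (Keras).
# Dropout rate and activations are assumptions for illustration.
from tensorflow import keras
from tensorflow.keras import layers

teacher = keras.Sequential([
    layers.Conv2D(32, (3, 3), activation="relu", input_shape=(28, 28, 1)),
    layers.Flatten(),
    layers.Dense(1200, activation="relu"),
    layers.Dropout(0.5),
    layers.Dense(1200, activation="relu"),
    layers.Dropout(0.5),
    layers.Dense(10, activation="softmax"),
])

student = keras.Sequential([
    layers.Flatten(input_shape=(28, 28, 1)),
    layers.Dense(10, activation="relu"),
    layers.Dense(10, activation="relu"),
    layers.Dense(10, activation="softmax"),
])
```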

Model                     Accuracy    Test errors
Teacher                   97.5%       248
Student on hard labels    86.3%       1366
Student on soft labels    93.0%       696

The small model achieves only 86.3% accuracy using hard labels. Astonishingly, the same network trained on soft labels achieves 93% accuracy, making about 50% fewer errors on the test data.
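The distillation step itself could be sketched as below, continuing from the architecture sketch above. The optimizer, number of epochs, and the use of plain (temperature-free) soft targets are assumptions; Hinton et al. additionally soften the teacher's outputs with a temperature parameter.

```python
# Minimal sketch of the distillation step: fit the teacher on the hard labels,
# then fit the student on the teacher's predicted class probabilities.
import numpy as np
from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train[..., None].astype("float32") / 255.0
x_test = x_test[..., None].astype("float32") / 255.0

# 1. Train the teacher on the hard (integer) labels.
teacher.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
teacher.fit(x_train, y_train, epochs=5, batch_size=128, verbose=0)

# 2. Use the teacher's output distribution as soft targets for the student.
soft_targets = teacher.predict(x_train, batch_size=256)

# 3. Train the student to match the soft targets with cross-entropy.
student.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
student.fit(x_train, soft_targets, epochs=10, batch_size=128, verbose=0)

# Evaluate the student against the true test labels.
y_pred = np.argmax(student.predict(x_test, batch_size=256), axis=1)
print("student (soft labels) test accuracy:", (y_pred == y_test).mean())
```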
