danielkelshaw / concretedropout

PyTorch implementation of 'Concrete Dropout'

Home Page: https://arxiv.org/abs/1705.07832

License: MIT License

Python 100.00%
statistical-machine-learning dropout-probability dropout neural-network

ConcreteDropout

PyTorch implementation of Concrete Dropout

This repository implements the method described in the Concrete Dropout paper. The code exposes a simple PyTorch interface so that the module can be integrated into existing code with ease.

  • Python 3.6+
  • MIT License

Overview:

Obtaining reliable uncertainty estimates with dropout typically requires a grid-search over various dropout probabilities - for larger models this can be computationally prohibitive. The Concrete Dropout paper proposes a dropout variant in which the dropout probability is learned during training; the authors report improved performance and better-calibrated uncertainty estimates.

Concrete Dropout optimises the dropout probability through gradient descent, minimising an objective with respect to that parameter. Dropout can be viewed as an approximating distribution, q(w), to the posterior over the weights. Using this interpretation it is possible to add a regularisation term to the loss function which is dependent on the KL divergence, KL[q(w)||p(w)]; this ensures that the approximate posterior does not deviate too far from the prior, p(w). As is often the case, the KL divergence is computationally intractable, so an approximation is developed - details of this can be seen in equations [2-4] in the paper.
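The regularisation term described above can be sketched in a few lines of plain Python. This is a minimal sketch under a stated assumption: that the approximation for a single layer takes the form of a squared weight norm scaled by (1 - p), minus the layer's input dimension K times the Bernoulli entropy H(p), up to additive constants. The names `bernoulli_entropy`, `kl_regulariser`, `input_dim` and `length_scale` are hypothetical and are not taken from this repository; consult equations [2-4] in the paper for the exact expression.

```python
import math


def bernoulli_entropy(p):
    """Entropy H(p) of a Bernoulli variable with probability p, for 0 < p < 1."""
    return -p * math.log(p) - (1.0 - p) * math.log(1.0 - p)


def kl_regulariser(weights, p, input_dim, length_scale=1.0):
    """Assumed approximate KL term for one layer, up to additive constants.

    weights      : flat list of the layer's weight values (hypothetical layout)
    p            : dropout probability of the layer
    input_dim    : number of input features K of the layer
    length_scale : prior length-scale l (assumed hyperparameter)
    """
    squared_norm = sum(w * w for w in weights)
    # Weight-norm term, scaled down as more units are dropped.
    weight_term = (length_scale ** 2) * (1.0 - p) / 2.0 * squared_norm
    # Entropy term, scaled by the number of input features.
    entropy_term = input_dim * bernoulli_entropy(p)
    return weight_term - entropy_term
```

Because the entropy is largest at p = 0.5, subtracting it penalises dropout probabilities that collapse towards 0 or 1.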

In typical dropout the dropout mask is sampled from a Bernoulli distribution - unfortunately this discrete distribution does not play well with the re-parameterisation trick, which is required to calculate the derivative of the objective with respect to the dropout probability. To allow the derivative to be calculated, a continuous relaxation of the discrete Bernoulli distribution is used - specifically the Concrete distribution relaxation. This has a simple parameterisation which reduces to a sigmoid applied to transformed uniform noise, as seen in equation [5].

Through use of the Concrete relaxation it is now possible to compute the derivatives of the objective with help from the re-parameterisation trick and optimise the dropout probability through gradient descent.
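The relaxation and its differentiability can be illustrated without any deep learning framework. The following is a minimal sketch in plain Python, assuming the relaxed drop indicator is a sigmoid of the sum of the logit of the dropout probability and the logit of uniform noise, divided by a temperature. The function names, the `EPS` constant and the default temperature of 0.1 are assumptions made for illustration, not this repository's API. With the noise `u` held fixed, the sample is a smooth function of `p`, which is what lets gradient descent update the dropout probability; in PyTorch the same expression written with tensors would be differentiated by autograd rather than by finite differences.

```python
import math
import random

EPS = 1e-7  # keeps log() away from zero


def logit(x):
    return math.log(x + EPS) - math.log(1.0 - x + EPS)


def concrete_drop_prob(p, u, temperature=0.1):
    """Relaxed drop indicator in [0, 1] for dropout probability p and noise u ~ U(0, 1)."""
    return 1.0 / (1.0 + math.exp(-(logit(p) + logit(u)) / temperature))


def relaxed_dropout(x, p, rng, temperature=0.1):
    """Apply the relaxed mask to a list of activations and rescale by 1 / (1 - p)."""
    out = []
    for value in x:
        z = concrete_drop_prob(p, rng.random(), temperature)
        out.append(value * (1.0 - z) / (1.0 - p))
    return out


def finite_difference_grad(p, u, temperature=0.1, h=1e-5):
    """Central-difference estimate of d(drop indicator)/dp with the noise u held fixed."""
    hi = concrete_drop_prob(p + h, u, temperature)
    lo = concrete_drop_prob(p - h, u, temperature)
    return (hi - lo) / (2.0 * h)
```

At low temperature most samples sit close to 0 or 1, so the relaxed mask behaves much like a Bernoulli mask while remaining differentiable in `p`.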

Example:

An example of ConcreteDropout has been implemented in mnist_example.py - this example can be run with:

python3 mnist_example.py

[Figure: MNIST results]

References:

@misc{gal2017concrete,
    title={Concrete Dropout},
    author={Yarin Gal and Jiri Hron and Alex Kendall},
    year={2017},
    eprint={1705.07832},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}
Code by Yarin Gal, author of the paper.
PyTorch implementation of Concrete Dropout
Made by Daniel Kelshaw
