
Calibrating the Rigged Lottery: Making All Tickets Reliable

Figure: Test accuracy and ECE value at different sparsities.
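ECE (expected calibration error) measures the gap between a model's confidence and its actual accuracy. As a rough illustration, the sketch below implements the common equal-width binned estimator in plain Python. The bin count of 15 and the function name are assumptions for illustration; they are not taken from this repository's evaluation code.

```python
def expected_calibration_error(confidences, predictions, labels, n_bins=15):
    """Binned ECE: weighted mean of |accuracy - confidence| over equal-width bins.

    confidences: top-class probability for each sample, in (0, 1]
    predictions: predicted class index for each sample
    labels:      true class index for each sample
    n_bins:      number of equal-width confidence bins (15 is an assumed,
                 commonly used value, not this repository's setting)
    """
    n = len(confidences)
    if n == 0:
        return 0.0
    ece = 0.0
    for b in range(n_bins):
        lower, upper = b / n_bins, (b + 1) / n_bins
        # Samples whose confidence falls in the half-open bin (lower, upper].
        in_bin = [i for i, c in enumerate(confidences) if lower < c <= upper]
        if not in_bin:
            continue
        accuracy = sum(predictions[i] == labels[i] for i in in_bin) / len(in_bin)
        avg_conf = sum(confidences[i] for i in in_bin) / len(in_bin)
        # Weight each bin's calibration gap by the fraction of samples in it.
        ece += len(in_bin) / n * abs(accuracy - avg_conf)
    return ece


# Two samples predicted with 90% confidence, only one of them correct.
print(expected_calibration_error([0.9, 0.9], [1, 0], [1, 1]))
```

Both samples land in the same bin, where accuracy is 50% but average confidence is 90%, so the ECE equals that gap (0.4). An over-confident model shows exactly this pattern; a perfectly calibrated one scores 0.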

This repository contains the official PyTorch implementation of the paper:
Calibrating the Rigged Lottery: Making All Tickets Reliable
Bowen Lei, Ruqi Zhang, Dongkuan Xu, Bani Mallick

Abstract: Although sparse training has been successfully used in various resource-limited deep learning tasks to save memory, accelerate training, and reduce inference time, the reliability of the resulting sparse models remains unexplored. Previous research has shown that deep neural networks tend to be over-confident, and we find that sparse training exacerbates this problem. Calibrating sparse models is therefore crucial for reliable prediction and decision-making. In this paper, we propose a new sparse training method that produces sparse models with improved confidence calibration. In contrast to previous research, which uses only one mask to control the sparse topology, our method uses two masks: a deterministic mask and a random mask. The former efficiently searches for and activates important weights by exploiting the magnitudes of weights and gradients, while the latter brings better exploration and finds more appropriate weight values through random updates. Theoretically, we prove that our method can be viewed as a hierarchical variational approximation of a probabilistic deep Gaussian process. Extensive experiments on multiple datasets, model architectures, and sparsities show that our method reduces ECE values by up to 47.8% while maintaining or even improving accuracy, with only a slight increase in computation and storage burden.
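To make the two-mask idea concrete, here is a minimal, framework-free sketch on flat Python lists. It assumes a RigL-style drop/grow rule for the deterministic mask (drop the smallest-magnitude active weights, grow the inactive weights with the largest gradient magnitude), an independent Bernoulli sample for the random mask, and an elementwise product to combine them. All function names and the keep_prob parameter are hypothetical; this is not the CigL implementation in this repository, which operates on PyTorch tensors.

```python
import random


def rigl_style_update(weights, grads, mask, drop_fraction):
    """One drop/grow step of a deterministic mask, in the style of RigL.

    Drops the active weights with the smallest magnitude and regrows the same
    number of inactive connections with the largest gradient magnitude, so
    the number of active weights stays constant.
    """
    active = [i for i, m in enumerate(mask) if m]
    inactive = [i for i, m in enumerate(mask) if not m]
    n_swap = min(int(drop_fraction * len(active)), len(inactive))
    # Drop: smallest |w| among currently active weights.
    dropped = sorted(active, key=lambda i: abs(weights[i]))[:n_swap]
    # Grow: largest |g| among currently inactive weights.
    grown = sorted(inactive, key=lambda i: abs(grads[i]), reverse=True)[:n_swap]
    new_mask = list(mask)
    for i in dropped:
        new_mask[i] = 0
    for i in grown:
        new_mask[i] = 1
    return new_mask


def random_mask(size, keep_prob, rng):
    """Hypothetical Bernoulli mask: each entry is kept with probability keep_prob."""
    return [1 if rng.random() < keep_prob else 0 for _ in range(size)]


def combined_mask(det_mask, rand_mask):
    """A weight is used only if both masks keep it (elementwise product)."""
    return [d * r for d, r in zip(det_mask, rand_mask)]


weights = [0.9, -0.1, 0.5, 0.0, 0.0, 0.0]
grads = [0.0, 0.0, 0.0, 0.8, -0.05, 0.3]
det = rigl_style_update(weights, grads, [1, 1, 1, 0, 0, 0], drop_fraction=0.5)
rand = random_mask(len(det), keep_prob=1 - 1 / 10, rng=random.Random(0))
print(det, combined_mask(det, rand))
```

In this example the active weight with the smallest magnitude (index 1) is dropped and the inactive weight with the largest gradient (index 3) is grown, so the deterministic mask keeps its density. Setting keep_prob to 1 - 1/DPNUM (here DPNUM=10) is an assumed reading of the README's "random mask's sparsity = 1/DPNUM", not a confirmed detail of the code.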

Requirements

  • Python 3.8 and PyTorch 1.7.0+ (GPU support preferable).
  • Then run make install.

Example Code

Train ResNet-50 with CigL on CIFAR10

make cifar10.ERK.RigL DENSITY=0.1 DPNUM=100 LR=0.1 DECAY=0.6 BS=200 SEED=0

Train ResNet-50 with CigL on CIFAR100

make cifar100.ERK.RigL DENSITY=0.2 DPNUM=10 LR=0.1 DECAY=0.6 BS=200 SEED=0
  • DENSITY: the density of the deterministic mask; its sparsity is 1 - DENSITY.
  • DPNUM: sets the random mask's sparsity, which is 1/DPNUM.
  • LR: the initial learning rate.
  • DECAY: the decay rate of the piecewise constant learning-rate schedule.
  • BS: the epoch at which the weight & mask averaging procedure begins.
  • SEED: the random seed.
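The following sketch spells out the arithmetic behind these variables in plain Python: the two sparsity levels, a piecewise constant learning-rate decay driven by LR and DECAY, and a running average that only starts at epoch BS. The milestone epochs and the RunningAverage helper are hypothetical; the actual schedule and averaging procedure are defined in the repository's makefiles and training code.

```python
def mask_sparsities(density, dpnum):
    """Sparsity levels implied by DENSITY and DPNUM, per the list above."""
    return {
        "deterministic_sparsity": 1.0 - density,  # sparsity = 1 - density
        "random_mask_sparsity": 1.0 / dpnum,      # sparsity = 1/DPNUM
    }


def piecewise_constant_lr(epoch, base_lr, decay, milestones):
    """Multiply base_lr by `decay` once for every milestone already reached.

    `milestones` is a hypothetical list of epochs; the epochs actually used
    by the makefiles are not specified in this README.
    """
    n_passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * decay ** n_passed


class RunningAverage:
    """Hypothetical uniform running average that starts at epoch `start` (BS).

    A simplified stand-in for the weight & mask averaging procedure; the
    repository's exact procedure may differ.
    """

    def __init__(self, start):
        self.start = start
        self.count = 0
        self.avg = None

    def update(self, epoch, values):
        if epoch < self.start:
            return  # averaging has not started yet
        self.count += 1
        if self.avg is None:
            self.avg = list(values)
        else:
            # Incremental mean: avg += (x - avg) / count
            self.avg = [a + (v - a) / self.count for a, v in zip(self.avg, values)]


print(mask_sparsities(density=0.1, dpnum=100))
print([piecewise_constant_lr(e, 0.1, 0.6, milestones=[100, 150]) for e in (0, 120, 180)])
```

With the CIFAR10 settings (DENSITY=0.1, DPNUM=100), the deterministic mask is 90% sparse and the random mask's sparsity is 1%.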

Modify makefiles/cifar10.mk or makefiles/cifar100.mk to use different model architectures and sparse training methods.

Citation

@inproceedings{
lei2023calibrating,
title={Calibrating the Rigged Lottery: Making All Tickets Reliable},
author={Bowen Lei and Ruqi Zhang and Dongkuan Xu and Bani Mallick},
booktitle={International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=KdwnGErdT6}
}

Credits

This code builds on [Reproducibility Challenge] RigL.
