
nvlabs / pac-bayes-conformal-prediction


PAC-Bayes generalization certificates for ICP

License: Other

conformal-prediction pac-bayes uncertainty-quantification


PAC Bayes Generalization Certificates for Inductive Conformal Prediction

This repository contains source code to reproduce the results from the paper:

  • Apoorva Sharma, Sushant Veer, Asher Hancock, Heng Yang, Marco Pavone, Anirudha Majumdar, "PAC-Bayes Generalization Certificates for Inductive Conformal Prediction." In Conf. on Neural Information Processing Systems, volume 36, 2023 (in press)

The source code is split into two Python modules:

  • confpred, which contains generic utilities to turn PyTorch modules into set-predictors via ICP, including the PAC-Bayes algorithm proposed in the paper.
  • confpred_eval, which contains code specific to the evaluation in the paper, e.g. dataset definitions.
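For readers new to ICP, the idea of turning a model into a set-predictor can be sketched as follows. This is a minimal NumPy illustration of standard split (inductive) conformal classification, not the confpred API: the function names are hypothetical, and the nonconformity score 1 - p(true label) is an assumption chosen for simplicity.

```python
import numpy as np

def calibrate_threshold(cal_probs: np.ndarray, cal_labels: np.ndarray, alpha: float) -> float:
    """Split-conformal threshold on the (assumed) score s(x, y) = 1 - p(y | x)."""
    n = len(cal_labels)
    scores = 1.0 - cal_probs[np.arange(n), cal_labels]
    # Rank ceil((n + 1)(1 - alpha)) gives marginal coverage >= 1 - alpha for exchangeable data.
    k = int(np.ceil((n + 1) * (1 - alpha)))
    if k > n:
        return float("inf")  # too few calibration points: every label is included
    return float(np.sort(scores)[k - 1])

def predict_sets(test_probs: np.ndarray, threshold: float) -> list[list[int]]:
    """Prediction set = all labels whose score 1 - p(y | x) is at most the threshold."""
    return [np.flatnonzero(1.0 - p <= threshold).tolist() for p in test_probs]

# Synthetic stand-in for softmax outputs of a trained classifier.
rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(10), size=500)
labels = np.array([rng.choice(10, p=p) for p in probs])
tau = calibrate_threshold(probs[:250], labels[:250], alpha=0.1)
sets = predict_sets(probs[250:], tau)
coverage = np.mean([y in s for y, s in zip(labels[250:], sets)])
print(f"threshold={tau:.3f}, empirical coverage={coverage:.3f}")
```

The threshold is computed once from held-out calibration data; the model itself is untouched, which is what makes ICP applicable to any pretrained predictor.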

Installation

To install the package and dependencies, run the following from the root of the repository, ideally after setting up a virtual environment:

pip install -e .
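To check that the editable install worked, a quick sketch like the following (a hypothetical helper, not part of the repository) reports whether both modules can be found from the active Python environment:

```python
import importlib.util

def check_install(modules=("confpred", "confpred_eval")) -> dict[str, bool]:
    """Report whether each top-level module is importable in the current environment."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}

print(check_install())
```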

Running experiments

After installing the package, enter the experiments directory to run the experiments. Applying ICP involves three steps, which are implemented with separate scripts.

  1. Training a base predictor for a task:
python 0_train_model.py experiment=<EXPT>
  2. Tuning and calibrating a model on calibration data:
python 1_calibrate_model.py experiment=<EXPT> calibrate=<METHOD> calibrate.alpha=<ALPHA> calibrate.delta=<DELTA> calibrate.alpha_hat=<ALPHA_HAT>
  3. Evaluating the calibrated set-valued predictor on test data:
python 2_eval_model.py experiment=<EXPT> calibrate=<METHOD> calibrate.alpha=<ALPHA> calibrate.delta=<DELTA> calibrate.alpha_hat=<ALPHA_HAT>

The evaluation script looks for a predictor calibrated with the same command-line arguments, so make sure they match those used when running 1_calibrate_model.py.
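One way to catch mismatched arguments before running the evaluation is to compare the key=value overrides of the two commands. The helpers below are a hypothetical, standard-library-only sketch; they treat every argument containing = as an override, which is a simplifying assumption, and the numeric values in the example are illustrative.

```python
import shlex

def overrides(argv: list[str]) -> dict[str, str]:
    """Collect key=value overrides from an argv list (anything else is ignored)."""
    return dict(arg.split("=", 1) for arg in argv if "=" in arg)

def mismatched_overrides(calib_argv: list[str], eval_argv: list[str]) -> dict[str, tuple]:
    """Map each override key whose value differs between the two commands to (calib, eval).

    A key missing from one command shows up with None on that side.
    """
    a, b = overrides(calib_argv), overrides(eval_argv)
    return {k: (a.get(k), b.get(k)) for k in sorted(a.keys() | b.keys()) if a.get(k) != b.get(k)}

calib = shlex.split("python 1_calibrate_model.py experiment=mnist calibrate=pacbayes "
                    "calibrate.alpha=0.1 calibrate.delta=0.05 calibrate.alpha_hat=0.07")
evaluate = shlex.split("python 2_eval_model.py experiment=mnist calibrate=pacbayes "
                       "calibrate.alpha=0.1 calibrate.delta=0.1 calibrate.alpha_hat=0.07")
print(mismatched_overrides(calib, evaluate))  # {'calibrate.delta': ('0.05', '0.1')}
```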

<EXPT> specifies the experiment to run:

  • toy runs the 1D regression task
  • mnist runs the corrupted MNIST classification task

<METHOD> can take three values:

  • confpred, for standard, non-optimized ICP;
  • learned, which follows Stutz et al., 2021 to optimize predictors for efficiency on a portion of the data;
  • pacbayes, which implements our method.
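The calibrate.alpha and calibrate.delta arguments suggest a guarantee of the form: with probability at least 1 - delta over the calibration set, test-time coverage is at least 1 - alpha. Assuming that reading, one standard way to achieve it for plain ICP (a training-conditional, PAC-style bound; not necessarily what confpred implements) is to choose the rank k of the calibration score used as the threshold. With n calibration points and continuous scores, the coverage obtained by thresholding at the k-th smallest score follows Beta(k, n + 1 - k), so the failure probability equals P[Binomial(n, 1 - alpha) >= k]. The sketch below finds the smallest k keeping this at most delta; the role of calibrate.alpha_hat is not modeled.

```python
import math
from typing import Optional

def binom_log_pmf(i: int, n: int, p: float) -> float:
    """log P[Binomial(n, p) = i], computed with lgamma to avoid overflow for large n."""
    return (math.lgamma(n + 1) - math.lgamma(i + 1) - math.lgamma(n - i + 1)
            + i * math.log(p) + (n - i) * math.log1p(-p))

def pac_rank(n: int, alpha: float, delta: float) -> Optional[int]:
    """Smallest k such that thresholding at the k-th smallest of n calibration scores
    gives coverage >= 1 - alpha with probability >= 1 - delta (continuous scores,
    no ties, 0 < alpha < 1).

    The failure probability for rank k is P[Binomial(n, 1 - alpha) >= k], which
    shrinks as k grows. Returns None if even k = n is not enough.
    """
    p = 1.0 - alpha
    tail = 0.0  # running value of P[Binomial(n, p) >= k], accumulated from k = n down
    best = None
    for k in range(n, 0, -1):
        tail += math.exp(binom_log_pmf(k, n, p))
        if tail > delta:
            break
        best = k
    return best

n = 1000
k_pac = pac_rank(n, alpha=0.1, delta=0.05)
k_marginal = math.ceil((n + 1) * (1 - 0.1))  # usual marginal-coverage rank, for comparison
print(k_marginal, k_pac)
```

The PAC rank is larger than the marginal rank, so the resulting sets are somewhat larger in exchange for holding with high probability over the particular calibration draw.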

The .vscode/launch.json file contains commands that automatically run the parameter sweeps used to generate the results in the paper. The command-line arguments are overrides to the Hydra configuration in experiments/conf/.
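To run a sweep outside VS Code, the calibrate/evaluate command pairs can be enumerated programmatically. The sketch below uses illustrative grid values that are placeholders, not the ones in .vscode/launch.json, and assumes the base models were already trained with 0_train_model.py.

```python
import itertools
import shlex

# Illustrative grid: placeholders, not the values used in the paper.
METHODS = ["confpred", "learned", "pacbayes"]
ALPHAS = [0.05, 0.1]
DELTAS = [0.01, 0.05]
ALPHA_HATS = [0.05]

def sweep_commands(expt, methods, alphas, deltas, alpha_hats):
    """Return argv lists, each calibration run immediately followed by its evaluation."""
    commands = []
    for method, alpha, delta, alpha_hat in itertools.product(methods, alphas, deltas, alpha_hats):
        shared = [
            f"experiment={expt}",
            f"calibrate={method}",
            f"calibrate.alpha={alpha}",
            f"calibrate.delta={delta}",
            f"calibrate.alpha_hat={alpha_hat}",
        ]
        # Identical overrides for both scripts, so evaluation finds the matching predictor.
        commands.append(["python", "1_calibrate_model.py", *shared])
        commands.append(["python", "2_eval_model.py", *shared])
    return commands

# Print the commands; to execute them from experiments/, pass each argv to
# subprocess.run(argv, check=True) instead.
for argv in sweep_commands("mnist", METHODS, ALPHAS, DELTAS, ALPHA_HATS):
    print(shlex.join(argv))
```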

Visualizing results

The Jupyter notebooks experiments/mnist_results.ipynb and experiments/regression_results.ipynb contain the analysis and plotting code used to generate all the figures for the experiments. To reproduce the results, first run all the commands in the VS Code launch file that correspond to the experiment.

Visualizing theory

The notebook experiments/theory_figs.ipynb contains code to visualize the KL bound derived from our theory.
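For intuition about KL-based bounds, the sketch below computes the binary KL divergence kl(q || p) and inverts it by bisection, which is how PAC-Bayes-kl bounds are usually turned into an explicit upper bound on a risk such as miscoverage. It is a generic illustration: the exact bound in the paper and in theory_figs.ipynb may differ, and the example plugs made-up numbers into the classic Maurer form kl(q || p) <= (KL(Q || P) + ln(2 sqrt(n) / delta)) / n.

```python
import math

def binary_kl(q: float, p: float) -> float:
    """kl(q || p) between Bernoulli(q) and Bernoulli(p), in nats."""
    eps = 1e-12
    p = min(max(p, eps), 1 - eps)  # keep the logs finite at the endpoints
    total = 0.0
    if q > 0:
        total += q * math.log(q / p)
    if q < 1:
        total += (1 - q) * math.log((1 - q) / (1 - p))
    return total

def kl_inv_upper(q: float, bound: float, tol: float = 1e-10) -> float:
    """Largest p in [q, 1] with kl(q || p) <= bound, by bisection (kl is increasing in p >= q)."""
    lo, hi = q, 1.0
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if binary_kl(q, mid) <= bound:
            lo = mid
        else:
            hi = mid
    return lo

# Made-up example: empirical miscoverage 0.08 on n = 1000 points, KL(Q || P) = 5, delta = 0.05.
n, emp, kl_qp, delta = 1000, 0.08, 5.0, 0.05
rhs = (kl_qp + math.log(2 * math.sqrt(n) / delta)) / n
print(f"miscoverage upper bound: {kl_inv_upper(emp, rhs):.4f}")
```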
