Distance Matters for Improving Performance Estimation Under Covariate Shift

Code for "Distance Matters for Improving Performance Estimation Under Covariate Shift", ICCV Workshop on Uncertainty Quantification 2023, Roschewitz & Glocker.

License: MIT License

Topics: covariate-shift, dataset-shift, image-classification, performance-estimation, uncertainty-quantification


Mélanie Roschewitz & Ben Glocker.
Accepted at ICCV - Workshop on Uncertainty Quantification for Computer Vision 2023.

If you like this repository, please consider citing our work:

@inproceedings{roschewitz2023distance,
  title={Distance Matters For Improving Performance Estimation Under Covariate Shift},
  author={Roschewitz, M{\'e}lanie and Glocker, Ben},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops},
  pages={4549--4559},
  year={2023}
}

Abstract

Performance estimation under covariate shift is a crucial component of safe AI model deployment, especially for sensitive use cases. Recently, several solutions were proposed to tackle this problem, most leveraging model predictions or softmax confidence to derive accuracy estimates. However, under dataset shift, confidence scores may become ill-calibrated if samples are too far from the training distribution. In this work, we show that taking into account distances of test samples to their expected training distribution can significantly improve performance estimation under covariate shift. Precisely, we introduce a "distance-check" to flag samples that lie too far from the expected distribution, to avoid relying on their untrustworthy model outputs in the accuracy estimation step. We demonstrate the effectiveness of this method on 13 image classification tasks, across a wide range of natural and synthetic distribution shifts and hundreds of models.

[Figure 1]
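
The snippet below is a minimal sketch of the distance-check idea, not the implementation in this repository: it uses an average k-NN distance in feature space as the distance measure and, purely for illustration, counts flagged (far-away) samples as incorrect when averaging confidences into an accuracy estimate. All names (estimate_accuracy_with_distance_check, distance_threshold, k) are ours, not the codebase's.

# Minimal sketch of a distance-check for confidence-based accuracy estimation.
# Illustration of the idea only, not the code used in this repository.
import numpy as np
from sklearn.neighbors import NearestNeighbors


def estimate_accuracy_with_distance_check(train_feats, test_feats, test_conf,
                                          distance_threshold, k=10):
    """Estimate test accuracy from max-softmax confidences, treating samples
    that lie too far from the training features as untrustworthy.

    train_feats: (n_train, d) training-set feature embeddings
    test_feats:  (n_test, d)  test-set feature embeddings
    test_conf:   (n_test,)    max softmax confidence per test sample
    distance_threshold: samples with larger average k-NN distance are flagged
    """
    nn = NearestNeighbors(n_neighbors=k).fit(train_feats)
    dists, _ = nn.kneighbors(test_feats)          # distances to the k nearest training samples
    knn_dist = dists.mean(axis=1)                 # average k-NN distance per test sample
    in_support = knn_dist <= distance_threshold   # distance check: close enough to training data?

    # Plain confidence-based estimate: average confidence over all test samples.
    naive_estimate = test_conf.mean()

    # Distance-checked estimate: trust confidences only for in-support samples;
    # here, flagged samples are simply counted as misclassified (one simple choice).
    checked_estimate = np.where(in_support, test_conf, 0.0).mean()
    return naive_estimate, checked_estimate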

This repository contains all the necessary code to reproduce our model training, evaluation and plots. The paper can be found here.

Overview

The repository is divided into the following sub-folders:

  • evaluation contains the most important part of this codebase, defining all necessary tools for accuracy estimation. In particular, evaluation_confidence_based.py runs the accuracy-estimation benchmark and plotting_notebook.ipynb reproduces the plots and aggregated results.

  • classification contains all the necessary code to define and train models, as well as the code to load specific experimental configurations. The configs/general subfolder contains all training configurations used in this work. Our code uses PyTorch Lightning, and the main classification module is defined in classification_module.py (a generic sketch of such a module is shown after this list). The main entry point for training is train_all_models_for_dataset.py, which trains all models used in the paper for a given task. All outputs are placed in [REPO_ROOT]/outputs by default.

  • data_handling contains all the code related to data loading and augmentations.
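
For orientation, this is what a minimal PyTorch Lightning classification module typically looks like. It is a generic sketch under our own naming, backbone and hyperparameters, not the classification_module.py shipped in this repository.

# Generic sketch of a PyTorch Lightning classification module, for orientation only.
# The actual module used in this repository lives in classification/classification_module.py.
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torchvision.models import resnet18


class SketchClassificationModule(pl.LightningModule):
    def __init__(self, num_classes: int, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.model = resnet18(num_classes=num_classes)  # illustrative backbone choice

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        acc = (self(x).argmax(dim=1) == y).float().mean()
        self.log("val_acc", acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)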

Prerequisites

  1. Start by creating our conda environment from the environment_full.yml file at the root of the repository.
  2. Make sure you update the paths to your datasets in default_paths.py.
  3. Make sure the root directory of the repository is in your PYTHONPATH environment variable (an alternative is sketched after this list).
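
If you would rather not export PYTHONPATH globally, a common alternative is to prepend the repository root to sys.path at the top of your own launcher scripts. This is a small illustrative sketch, not something the repository sets up for you; adjust the path to your clone.

# Illustrative alternative to exporting PYTHONPATH: prepend the repository root
# to sys.path before importing anything from the codebase.
import sys
from pathlib import Path

REPO_ROOT = Path("/path/to/distance_matters_performance_estimation")  # adjust to your clone
sys.path.insert(0, str(REPO_ROOT))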

Ready to go!

Step-by-step example

In this section, we walk you through all the steps needed to reproduce the experiments for Living17. The procedure is identical for all other experiments; you only need to change the dataset you want to use.

Assuming your current work directory is the root of the repository:

  1. Train all models for this dataset (this will take a few days!): python classification/train_all_models_for_dataset.py --dataset living17
  2. You are then ready to run the evaluation benchmark with python evaluation/evaluation_confidence_based.py --dataset living17.
  3. The outputs can be found in the outputs/{DATASET_NAME}/{MODEL_NAME}/{RUN_NAME} folder. There you will find metrics.csv, which contains all predictions and errors for all models for this dataset (a small loading example is sketched after this list).
  4. If you then want to reproduce the plots and the aggregated results over all models, as in Table 1 of the paper, you will need to run the evaluation/plotting_notebook.ipynb notebook.
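
If you want to post-process the results programmatically, the snippet below is one way to gather the per-model metrics files into a single pandas DataFrame. It relies only on the folder structure described above; the actual columns of metrics.csv should be checked against the files the benchmark writes.

# Illustrative only: collect the metrics files written by the evaluation benchmark.
# Check the actual metrics.csv header before relying on any particular column.
from pathlib import Path
import pandas as pd

outputs_root = Path("outputs/living17")
frames = []
for csv_path in outputs_root.glob("*/*/metrics.csv"):    # outputs/{DATASET}/{MODEL}/{RUN}/metrics.csv
    df = pd.read_csv(csv_path)
    df["model"] = csv_path.parent.parent.name             # model name taken from the folder structure
    frames.append(df)

all_metrics = pd.concat(frames, ignore_index=True)
print(all_metrics.head())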
