This project forked from madrylab/modeldiff

ModelDiff: A Framework for Comparing Learning Algorithms

Home Page: http://gradientscience.org/modeldiff

License: MIT License

This repository contains the code for ModelDiff, a framework for feature-based comparisons of ML models trained with two different learning algorithms:

ModelDiff: A Framework for Comparing Learning Algorithms
Harshay Shah*, Sung Min Park*, Andrew Ilyas*, Aleksander Madry
Paper: https://arxiv.org/abs/2211.12491
Blog post: http://gradientscience.org/modeldiff/

@inproceedings{shah2022modeldiff,
  title={ModelDiff: A Framework for Comparing Learning Algorithms},
  author = {Harshay Shah and Sung Min Park and Andrew Ilyas and Aleksander Madry},
  booktitle = {ArXiv preprint arXiv:2211.12491},
  year = {2022}
}

Overview

The figure above summarizes our algorithm comparison framework, ModelDiff.

  • First, our method computes datamodel representations for each algorithm (part A) and then computes residual datamodels (part B) to identify directions (in training set space) that are specific to each algorithm.
  • Then, we run PCA on the residual datamodels (part C) to find a set of distinguishing training directions: weighted combinations of training examples that disparately impact predictions of models trained with different algorithms. Each distinguishing direction surfaces a distinguishing subpopulation, from which we infer a testable distinguishing transformation (part D) that significantly impacts predictions of models trained with one algorithm but not the other.
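The two steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not the repository's implementation. It assumes each algorithm's datamodels are stored as an array of shape (n_test, n_train), and that a residual datamodel is obtained by normalizing both datamodels and projecting the second algorithm's datamodel out of the first's. The function names `residual_datamodels` and `distinguishing_directions`, and the choice to mean-center before PCA, are hypothetical.

```python
import numpy as np


def residual_datamodels(theta_a, theta_b, eps=1e-12):
    """Remove from each row of theta_a its component along the matching row of theta_b.

    theta_a, theta_b: arrays of shape (n_test, n_train), one datamodel per test example.
    Assumption: rows are normalized to unit length before projecting.
    """
    a = theta_a / (np.linalg.norm(theta_a, axis=1, keepdims=True) + eps)
    b = theta_b / (np.linalg.norm(theta_b, axis=1, keepdims=True) + eps)
    # Per-example inner product between the two normalized datamodels.
    coeff = np.sum(a * b, axis=1, keepdims=True)
    return a - coeff * b


def distinguishing_directions(residuals, k):
    """Return the top-k principal directions (in training-set space) of the residuals."""
    centered = residuals - residuals.mean(axis=0, keepdims=True)
    # Rows of vt are orthonormal directions, ordered by decreasing singular value.
    _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
    explained_variance = singular_values[:k] ** 2 / (len(residuals) - 1)
    return vt[:k], explained_variance
```

Each returned direction assigns one weight per training example. Under this sketch, the training examples with the largest absolute weights would be the natural candidates to inspect when looking for a distinguishing subpopulation.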

In our paper, we apply ModelDiff to three case studies that compare models trained with/without standard data augmentation, with/without ImageNet pre-training, and with different SGD hyperparameters. As shown below, in all three cases, our framework allows us to pinpoint concrete ways in which the two algorithms being compared differ:

Getting started

  1. Clone the repo: git clone git@github.com:MadryLab/modeldiff.git

  2. Our code relies on the FFCV Library. To install this library along with other dependencies including PyTorch, follow the instructions below:

        conda create -n ffcv python=3.9 cupy pkg-config compilers libjpeg-turbo opencv pytorch torchvision cudatoolkit=11.3 numba -c pytorch -c conda-forge
        conda activate ffcv
    
        cd <REPO-DIR>
        pip install -r requirements.txt
    
  3. Set up the datasets. We use CIFAR-10 (torchvision), Waterbirds (WILDS), and Living17 (BREEDS). Also, change the DATA_DIR path in src/data/datasets.py to the parent directory of the ImageNet data.

  4. Our framework uses datamodel representations to identify distinguishing features. Download pre-computed datamodels for all three case studies from here and unzip them into datamodels/
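If you prefer to script the unzip step in item 4, the standard-library sketch below is one way to do it. The helper name `extract_datamodels` is hypothetical, as is any archive file name you pass to it. The only detail taken from the instructions above is the target directory datamodels/.

```python
import zipfile
from pathlib import Path


def extract_datamodels(archive_path, target_dir="datamodels"):
    """Extract a downloaded archive into target_dir and list the extracted files."""
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(target)
    # Return relative paths so the result is easy to compare against what the notebooks expect.
    return sorted(
        p.relative_to(target).as_posix() for p in target.rglob("*") if p.is_file()
    )
```

Calling the helper from the repository root with the path of the downloaded archive returns the list of extracted files, which makes it easy to confirm that the contents landed in datamodels/.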

That's it! Now you can run notebooks (one corresponding to each case study in analysis/), or take a look at our scripts (in counterfactuals/) that evaluate the average treatment effect of distinguishing feature transformations.
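To make the quantity evaluated in counterfactuals/ concrete, here is a generic sketch of an average treatment effect computation. It is not taken from the repository's scripts. It assumes you already have per-model outputs (for example, confidence in a target class) on the same examples before and after applying a feature transformation, stored as arrays of shape (n_models, n_examples). The function name `average_treatment_effect` is hypothetical.

```python
import numpy as np


def average_treatment_effect(outputs_original, outputs_transformed):
    """Mean change in model output caused by the transformation.

    Both arguments have shape (n_models, n_examples); entry [i, j] is the output
    of model i on example j, without and with the transformation applied.
    """
    original = np.asarray(outputs_original, dtype=float)
    transformed = np.asarray(outputs_transformed, dtype=float)
    if original.shape != transformed.shape:
        raise ValueError("original and transformed outputs must have the same shape")
    # Average over both models and examples.
    return float(np.mean(transformed - original))
```

Computing this value separately for models trained with each algorithm gives two numbers to compare. A transformation is distinguishing in the sense described above when the effect is large for one algorithm and close to zero for the other.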
