imodels-experiments

Experiments with rule-based models to go along with imodels. This project was forked from yu-group/imodels-experiments.

Home Page: https://csinva.io/imodels/

Scripts for easily comparing different aspects of the imodels package. Contains code to reproduce the experiments for FIGS, hierarchical shrinkage, and G-FIGS.

Documentation

Follow these steps to benchmark a new (supervised) model. If you want to benchmark something like feature importance or unsupervised learning, you will need to make more substantial changes (mostly in 01_fit_models.py).

  1. Write the sklearn-compliant model (__init__, fit, predict, and predict_proba for classifiers) and add it to a local folder or to imodels
  2. Update configs - create a new folder mimicking an existing folder (e.g. config.interactions)
    1. Select which datasets you want by modifying datasets.py (datasets will be automatically downloaded locally later)
    2. Select which models you want by editing a list similar to models.py
  3. Run 01_fit_models.py
    • pass the appropriate command-line args (e.g. model, dataset, config)
    • example command: python 01_fit_models.py --config interactions --classification_or_regression classification --split_seed 0
    • another example: python 01_fit_models.py --config interactions --classification_or_regression classification --model randomforest --split_seed 0
    • to run everything, loop over split_seed and classification_or_regression
    • alternatively, to parallelize over a Slurm cluster, run 01_submit_fitting.py with the appropriate loops
  4. Run 02_aggregate_results.py (which combines the per-dataset output of 01_fit_models.py into a combined.pkl file) for plotting
  5. Put scripts/notebooks into a subdirectory of the notebooks folder (e.g. notebooks/interactions)
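
Step 1 asks for an sklearn-compliant model. As a minimal sketch of the interface (pure Python, no sklearn dependency; the class name and the constant-predictor logic are illustrative, not part of this repo), a classifier only needs __init__, fit, predict, predict_proba, and the get_params/set_params pair:

```python
from collections import Counter

class MajorityClassifier:
    """Toy sklearn-style classifier: always predicts the most frequent class.

    Illustrates the minimal interface the benchmarking scripts expect:
    __init__, fit, predict, and (for classifiers) predict_proba.
    """

    def __init__(self, tie_break="first"):
        # Hyperparameters must be stored under the name they were passed
        # with, so sklearn-style cloning can recover them via get_params.
        self.tie_break = tie_break

    def get_params(self, deep=True):
        return {"tie_break": self.tie_break}

    def set_params(self, **params):
        for k, v in params.items():
            setattr(self, k, v)
        return self

    def fit(self, X, y):
        counts = Counter(y)
        self.classes_ = sorted(counts)
        self.majority_ = counts.most_common(1)[0][0]
        self.priors_ = {c: counts[c] / len(y) for c in self.classes_}
        return self

    def predict(self, X):
        return [self.majority_ for _ in X]

    def predict_proba(self, X):
        row = [self.priors_[c] for c in self.classes_]
        return [row for _ in X]
```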

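The "loop over split_seed and classification_or_regression" step in 3 can be scripted; a minimal sketch (the helper name and seed count are placeholders, and each resulting command would be launched with subprocess.run(cmd, check=True)):

```python
def build_fit_commands(config, n_seeds):
    """Build one 01_fit_models.py invocation per (task, seed) pair.

    Returns argument lists suitable for subprocess.run(cmd, check=True).
    """
    cmds = []
    for task in ["classification", "regression"]:
        for seed in range(n_seeds):
            cmds.append([
                "python", "01_fit_models.py",
                "--config", config,
                "--classification_or_regression", task,
                "--split_seed", str(seed),
            ])
    return cmds
```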
Config

  • When running multiple seeds, we want to aggregate over all keys that are not the split_seed
    • If a hyperparameter is not passed in ModelConfig (e.g. because we are using functools.partial), it cannot be aggregated over seeds later on
      • The extra_aggregate_keys={'max_leaf_nodes': n} is a workaround for this (see configs with partial to understand how it works)
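
To see why a partial hides its hyperparameters, note that arguments frozen with functools.partial live on the partial object rather than in the kwargs the config records, so seed aggregation cannot key on them. A sketch (the MockTree class is an illustrative stand-in, not a repo class):

```python
from functools import partial

class MockTree:
    """Illustrative stand-in for an estimator class used in a ModelConfig."""
    def __init__(self, max_leaf_nodes=None):
        self.max_leaf_nodes = max_leaf_nodes

# Freezing the hyperparameter with partial bakes it into the callable,
# so it never appears among the kwargs that the config records:
frozen = partial(MockTree, max_leaf_nodes=8)
est = frozen()

# The frozen value is only recoverable from the partial object itself...
print(frozen.keywords)  # {'max_leaf_nodes': 8}

# ...which is why extra_aggregate_keys={'max_leaf_nodes': 8} re-declares it
# explicitly, so results can still be grouped across seeds on this key.
```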

Testing

Tests are run via pytest.

Experimental methods

  • Stable rules - finding a stable set of rules across different models

Working methods

FIGS: Fast interpretable greedy-tree sums

📄 Paper, 🔗 Post, 📌 Citation

Fast Interpretable Greedy-Tree Sums (FIGS) is an algorithm for fitting concise rule-based models. Specifically, FIGS generalizes CART to simultaneously grow a flexible number of trees in a summation. The total number of splits across all the trees can be restricted by a pre-specified threshold, keeping the model interpretable. Experiments across a wide array of real-world datasets show that FIGS achieves state-of-the-art prediction performance when restricted to just a few splits (e.g. less than 20).

Example FIGS model. FIGS learns a flexible number of trees; to make a prediction, it sums the result from each tree.
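
The sum-of-trees prediction form can be illustrated with a toy model (hand-built depth-1 stumps; this is only the prediction step, not the FIGS fitting algorithm, which greedily chooses where to add each split):

```python
# Toy FIGS-style model: the prediction is the SUM of several small trees.
# Each "tree" here is a depth-1 stump: (feature_index, threshold, left, right).

def stump_predict(stump, x):
    feat, thresh, left, right = stump
    return left if x[feat] <= thresh else right

def figs_predict(stumps, x):
    # FIGS sums the output of every tree to form the final prediction.
    return sum(stump_predict(s, x) for s in stumps)

# Two hand-built stumps; the total split count (here 2) is the quantity
# FIGS caps with a pre-specified threshold to keep the model interpretable.
model = [
    (0, 0.5, 1.0, 3.0),   # split on feature 0
    (1, 2.0, -1.0, 0.5),  # split on feature 1
]
print(figs_predict(model, [0.2, 5.0]))  # 1.0 + 0.5 = 1.5
```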

Hierarchical shrinkage: post-hoc regularization for tree-based methods

📄 Paper (ICML 2022), 🔗 Post, 📌 Citation

Hierarchical shrinkage is an extremely fast post-hoc regularization method which works on any decision tree (or tree-based ensemble, such as Random Forest). It does not modify the tree structure, and instead regularizes the tree by shrinking the prediction over each node towards the sample means of its ancestors (using a single regularization parameter). Experiments over a wide variety of datasets show that hierarchical shrinkage substantially increases the predictive performance of individual decision trees and decision-tree ensembles.

HS example. HS applies post-hoc regularization to any decision tree by shrinking each node towards its parent.
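
Concretely, HS rewrites the leaf prediction as a telescoping sum along the root-to-leaf path, with each step's contribution damped by the parent node's sample count. A sketch of this formula (the path-of-tuples representation is an illustrative simplification, not the repo's tree data structure):

```python
def hs_predict(path, lam):
    """Hierarchical-shrinkage prediction for one root-to-leaf path.

    path: list of (node_mean, n_samples) tuples from root to leaf.
    lam:  regularization strength; lam=0 recovers the plain leaf mean.

    prediction = root_mean + sum over child steps of
                 (child_mean - parent_mean) / (1 + lam / parent_n)
    """
    pred = path[0][0]  # start from the root's sample mean
    for (parent_mean, parent_n), (child_mean, _) in zip(path, path[1:]):
        pred += (child_mean - parent_mean) / (1 + lam / parent_n)
    return pred

# Root (mean 0.5, 100 samples) -> child (0.8, 40) -> leaf (0.9, 10)
path = [(0.5, 100), (0.8, 40), (0.9, 10)]
print(hs_predict(path, lam=0.0))   # plain leaf mean, 0.9
print(hs_predict(path, lam=10.0))  # pulled back toward the root mean
```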

G-FIGS: Group Probability-Weighted Tree Sums for Interpretable Modeling of Heterogeneous Data

📄 Paper

Machine learning in high-stakes domains, such as healthcare, faces two critical challenges: (1) generalizing to diverse data distributions given limited training data while (2) maintaining interpretability. To address these challenges, G-FIGS effectively pools data across diverse groups to output a concise, rule-based model. Given distinct groups of instances in a dataset (e.g., medical patients grouped by age or treatment site), G-FIGS first estimates group membership probabilities for each instance. Then, it uses these estimates as instance weights in FIGS (Tan et al. 2022) to grow a set of decision trees whose values sum to the final prediction. G-FIGS achieves state-of-the-art prediction performance on important clinical datasets; e.g., holding the level of sensitivity fixed at 92%, G-FIGS increases specificity for identifying cervical spine injury by up to 10% over CART and up to 3% over FIGS alone, with larger gains at higher sensitivity levels. By keeping the total number of rules below 16 in FIGS, the final models remain interpretable, and we find that their rules match medical domain expertise. All code, data, and models are released on GitHub.

G-FIGS 2-step process explained.
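
The two-step process above can be sketched as follows (the probability dicts and the helper name are illustrative stand-ins; in the paper, the membership probabilities come from a fitted classifier and the weighted fit is done with FIGS):

```python
def gfigs_weights(group_probs, target_group):
    """Glue between step 1 and step 2: turn group-membership probabilities
    into instance weights for fitting the target group's model.

    group_probs: one dict per instance mapping group label -> P(group | x_i),
                 e.g. produced by any fitted membership classifier.
    Instances from other groups are pooled in, but down-weighted by how
    little they resemble the target group.
    """
    return [p[target_group] for p in group_probs]

# Three instances; estimated probability each belongs to the "young" group.
probs = [{"young": 0.9, "old": 0.1},
         {"young": 0.2, "old": 0.8},
         {"young": 0.6, "old": 0.4}]
weights = gfigs_weights(probs, "young")
print(weights)  # [0.9, 0.2, 0.6] -- used as sample weights in the FIGS fit
```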

Contributors

aagarwal1996, csinva, keyan3, omerronen, yanshuotan
