seafloor / random-forest-importances

This project forked from parrt/random-forest-importances

Code to compute permutation and drop-column importances in Python scikit-learn random forests

License: MIT License

CSS 0.06% HTML 33.38% Jupyter Notebook 65.60% Python 0.96%

Feature importances for scikit random forests

By Terence Parr and Kerem Turgutlu. See Explained.ai for more stuff.

The default scikit-learn Random Forest feature importance strategy is the mean decrease in impurity (or Gini importance) mechanism, which is unreliable. To get reliable results, use permutation importance, provided in the rfpimp package in the src dir. Install with:

pip install rfpimp

Description

See Beware Default Random Forest Importances for a deeper discussion of the issues surrounding feature importances in random forests (authored by Terence Parr, Kerem Turgutlu, Christopher Csiszar, and Jeremy Howard).

The mean-decrease-in-impurity importance of a feature is computed by measuring how effective the feature is at reducing uncertainty (classifiers) or variance (regressors) when creating decision trees within random forests. The problem is that this mechanism, while fast, does not always give an accurate picture of importance. Strobl et al. pointed out in Bias in random forest variable importance measures: Illustrations, sources and a solution that “the variable importance measures of Breiman's original random forest method ... are not reliable in situations where potential predictor variables vary in their scale of measurement or their number of categories.”
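To see that bias concretely, here is a minimal sketch on synthetic data (not from the original README) that reads scikit-learn's impurity-based `feature_importances_` for a forest where only `x1` drives the target. The columns `coin` (binary) and `noise` (continuous, high cardinality) are both irrelevant by construction; in runs like this the continuous `noise` column usually collects a noticeably larger impurity score than `coin`, simply because it offers many more split points.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

# Synthetic data (assumption for illustration): y depends only on x1.
rng = np.random.default_rng(0)
n = 1000
X = pd.DataFrame({
    "x1": rng.normal(size=n),
    "coin": rng.integers(0, 2, size=n),  # irrelevant, only two distinct values
    "noise": rng.random(size=n),         # irrelevant, ~n distinct values
})
y = 3 * X["x1"] + rng.normal(scale=0.5, size=n)

rf = RandomForestRegressor(n_estimators=100, random_state=0, n_jobs=-1)
rf.fit(X, y)

# feature_importances_ is scikit-learn's mean decrease in impurity,
# normalized so the values sum to 1.
mdi = pd.Series(rf.feature_importances_, index=X.columns)
print(mdi.sort_values(ascending=False))
```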

A more reliable method is permutation importance, which measures the importance of a feature as follows. Record a baseline accuracy (classifier) or R^2 score (regressor) by passing a validation set or the out-of-bag (OOB) samples through the random forest. Permute the column values of a single predictor feature, pass the same samples back through the random forest, and recompute the accuracy or R^2. The importance of that feature is the difference between the baseline and the permuted score, that is, the drop in overall accuracy or R^2 caused by permuting the column. The permutation mechanism is much more computationally expensive than the mean decrease in impurity mechanism, but the results are more reliable.
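The permutation procedure can be sketched in a few lines. This is not rfpimp's implementation: `permutation_importances_sketch` is a hypothetical helper, the data is synthetic, and it scores a held-out validation set rather than OOB samples. Each feature's importance is the baseline R^2 minus the R^2 after shuffling that one column.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split


def permutation_importances_sketch(model, X_valid, y_valid, metric=r2_score, seed=0):
    """Hypothetical helper: baseline score minus score with one column shuffled."""
    rng = np.random.default_rng(seed)
    baseline = metric(y_valid, model.predict(X_valid))
    imp = {}
    for col in X_valid.columns:
        X_perm = X_valid.copy()
        # Shuffling breaks the link between this column and the target
        X_perm[col] = rng.permutation(X_perm[col].to_numpy())
        imp[col] = baseline - metric(y_valid, model.predict(X_perm))
    return pd.Series(imp).sort_values(ascending=False)


# Synthetic data (assumption): x1 matters most, x2 a little, random not at all
rng = np.random.default_rng(1)
n = 2000
X = pd.DataFrame({"x1": rng.normal(size=n), "x2": rng.normal(size=n),
                  "random": rng.random(size=n)})
y = 3 * X["x1"] + X["x2"] + rng.normal(scale=0.5, size=n)

X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
rf = RandomForestRegressor(n_estimators=100, random_state=0, n_jobs=-1).fit(X_train, y_train)
imp = permutation_importances_sketch(rf, X_valid, y_valid)
print(imp)
```

Expect `x1` to dominate, `x2` to come second, and the `random` column to sit near zero (it can even come out slightly negative).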

Sample code

See the notebooks directory for things like Collinear features and Plotting feature importances.

Here's some sample Python code that uses the rfpimp package contained in the src directory. The data can be found in rent.csv, which is a subset of the data from Kaggle's Two Sigma Connect: Rental Listing Inquiries competition.

from rfpimp import *
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split

# Adjust the path to wherever the repository's notebooks/data/rent.csv lives
df_orig = pd.read_csv("notebooks/data/rent.csv")

# --- Regression: predict (log) price ---
df = df_orig.copy()

# attenuate the effect of outliers in price
df['price'] = np.log(df['price'])

df_train, df_test = train_test_split(df, test_size=0.20)

features = ['bathrooms','bedrooms','longitude','latitude',
            'price']
df_train = df_train[features]
df_test = df_test[features]

X_train, y_train = df_train.drop('price',axis=1), df_train['price']
X_test, y_test = df_test.drop('price',axis=1), df_test['price']
# Add a column of random numbers as a reference for "no importance"
X_train['random'] = np.random.random(size=len(X_train))
X_test['random'] = np.random.random(size=len(X_test))

rf = RandomForestRegressor(n_estimators=100, n_jobs=-1)
rf.fit(X_train, y_train)

imp = importances(rf, X_test, y_test) # permutation importance on the test set
viz = plot_importances(imp)
viz.view()

# --- Classification: predict interest_level ---
df_train, df_test = train_test_split(df_orig, test_size=0.20)
features = ['bathrooms','bedrooms','price','longitude','latitude',
            'interest_level']
df_train = df_train[features]
df_test = df_test[features]

X_train, y_train = df_train.drop('interest_level',axis=1), df_train['interest_level']
X_test, y_test = df_test.drop('interest_level',axis=1), df_test['interest_level']
# Add column of random numbers
X_train['random'] = np.random.random(size=len(X_train))
X_test['random'] = np.random.random(size=len(X_test))

rf = RandomForestClassifier(n_estimators=100,
                            min_samples_leaf=5,
                            n_jobs=-1,
                            oob_score=True)
rf.fit(X_train, y_train)

imp = importances(rf, X_test, y_test, n_samples=-1) # -1: use the entire test set
viz = plot_importances(imp)
viz.view()

Feature correlation

See Feature collinearity heatmap. We can also get the Spearman's rank correlation matrix of the features.
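If you only need the numbers, pandas computes Spearman's rank correlation directly with `DataFrame.corr(method="spearman")`. A minimal sketch on synthetic columns (an assumption standing in for the rental features, and not rfpimp's own heatmap call):

```python
import numpy as np
import pandas as pd

# Synthetic stand-in data (assumption): sqft grows monotonically but
# nonlinearly with bedrooms; noise is unrelated to either.
rng = np.random.default_rng(0)
n = 500
bedrooms = rng.integers(1, 6, size=n)
df = pd.DataFrame({
    "bedrooms": bedrooms,
    "sqft": 300 * bedrooms ** 1.5 + rng.normal(scale=50, size=n),
    "noise": rng.random(size=n),
})

# Spearman = Pearson correlation of the column ranks, so any monotonic
# relationship scores high, not only linear ones.
corr = df.corr(method="spearman")
print(corr.round(2))
```

The `bedrooms`/`sqft` entry comes out close to 1 despite the relationship being nonlinear, while `noise` stays near 0.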

Feature dependencies

The features we use in machine learning are rarely completely independent, which makes interpreting feature importance tricky. We could compute correlation coefficients, but those only identify linear (or, for rank correlations such as Spearman's, monotonic) relationships. One way to at least identify whether a feature, x, depends on other features is to train a model using x as the dependent variable and all other features as independent variables. Because random forests give us an easy out-of-bag (OOB) error estimate, the feature dependence functions rely on random forest models. The OOB R^2 score of that model indicates how easy it is to predict feature x from the other features: the higher the score, the more dependent feature x is.
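A minimal sketch of the dependence idea using scikit-learn directly (not rfpimp's own function; `oob_dependence_sketch` is a hypothetical name and the data is synthetic). For each column it fits a `RandomForestRegressor` with `oob_score=True` that predicts that column from all the others, then reports `oob_score_`, which for regressors is the out-of-bag R^2.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor


def oob_dependence_sketch(X, n_estimators=100, seed=0):
    """Hypothetical helper: OOB R^2 of predicting each column from the others."""
    scores = {}
    for col in X.columns:
        rf = RandomForestRegressor(n_estimators=n_estimators, oob_score=True,
                                   random_state=seed, n_jobs=-1)
        rf.fit(X.drop(columns=col), X[col])
        scores[col] = rf.oob_score_  # R^2 on the out-of-bag samples
    return pd.Series(scores, name="Dependence").sort_values(ascending=False)


# Synthetic data (assumption): b is almost a linear function of a; c is independent
rng = np.random.default_rng(0)
n = 1000
X = pd.DataFrame({"a": rng.normal(size=n), "c": rng.normal(size=n)})
X["b"] = 2 * X["a"] + rng.normal(scale=0.1, size=n)

dep = oob_dependence_sketch(X)
print(dep)
```

Here `a` and `b` should both score near 1 because each is nearly determined by the other, while `c` should score near or below 0.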

You can also get a feature dependence matrix / heatmap, which returns a non-symmetric data frame where each row holds the importance of every other variable for predicting that row's variable when it is used as the model target.
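A dependence matrix of this shape can be approximated with scikit-learn's `sklearn.inspection.permutation_importance` on a held-out split. This is a sketch, not rfpimp's routine: `dependence_matrix_sketch` is hypothetical and the data is synthetic. Row `t` holds the importance of each other column when `t` is the regression target, and the diagonal is left as NaN. Because the model that predicts `a` differs from the model that predicts `b`, the matrix is generally not symmetric.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split


def dependence_matrix_sketch(X, seed=0):
    """Hypothetical helper: row = target column, entries = permutation
    importance (drop in R^2) of each other column for predicting it."""
    cols = list(X.columns)
    D = pd.DataFrame(np.nan, index=cols, columns=cols)
    for target in cols:
        others = X.drop(columns=target)
        X_tr, X_va, y_tr, y_va = train_test_split(
            others, X[target], test_size=0.2, random_state=seed)
        rf = RandomForestRegressor(n_estimators=100, random_state=seed, n_jobs=-1)
        rf.fit(X_tr, y_tr)
        # Default scoring for a regressor is R^2
        result = permutation_importance(rf, X_va, y_va, n_repeats=5, random_state=seed)
        D.loc[target, others.columns] = result.importances_mean
    return D


# Synthetic data (assumption): b depends on a; c is independent
rng = np.random.default_rng(0)
n = 1000
X = pd.DataFrame({"a": rng.normal(size=n), "c": rng.normal(size=n)})
X["b"] = 2 * X["a"] + rng.normal(scale=0.1, size=n)

D = dependence_matrix_sketch(X)
print(D.round(3))
```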
