modeloriented / randomforestexplainer

A set of tools to understand what is happening inside a Random Forest

Home Page: https://ModelOriented.github.io/randomForestExplainer/

R 100.00%
random-forest cran

randomforestexplainer's Introduction

randomForestExplainer

A set of tools to understand what is happening inside a Random Forest. A detailed discussion of the package and importance measures it implements can be found here: Master thesis on randomForestExplainer.

Installation

# the easiest way to get randomForestExplainer is to install it from CRAN:
install.packages("randomForestExplainer")

# Or the development version from GitHub:
# install.packages("devtools")
devtools::install_github("ModelOriented/randomForestExplainer")

Vignette

Cheatsheets

Examples

randomforestexplainer's People

Contributors

durszlaczek, hbaniecki, kasiakobylinska, mayer79, olapaluszynska, olivroy, pbiecek, yue-jiang

randomforestexplainer's Issues

Deprecated functions in called R package

Below is a warning thrown when the min_depth_interactions() function is used.

interactions_frame <- min_depth_interactions(rf_fit, vars)

Warning messages:
1: funs() was deprecated in dplyr 0.8.0.
ℹ Please use a list of either functions or lambdas:

Simple named list: list(mean = mean, median = median)

Auto named with tibble::lst(): tibble::lst(mean, median)

Using lambdas list(~ mean(., trim = .2), ~ median(., na.rm = TRUE))

ℹ The deprecated feature was likely used in the randomForestExplainer package.
Please report the issue to the authors.
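
The replacement the warning asks for can be sketched as follows. This is a minimal illustration on a made-up data frame, assuming dplyr is installed; it is not the code inside randomForestExplainer, and the column names g, x and y are hypothetical.

library(dplyr)

# Hypothetical stand-in data; not taken from the package internals.
df <- data.frame(g = c("a", "a", "b"), x = c(1, 2, 3), y = c(4, 5, 6))

# Deprecated since dplyr 0.8.0:
#   df %>% group_by(g) %>% summarise_at(vars(x, y), funs(mean, max))

# Same call with a simple named list of functions, as the warning suggests.
res <- df %>%
  group_by(g) %>%
  summarise_at(vars(x, y), list(mean = mean, max = max))

# Output columns are named <column>_<function>, e.g. x_mean and y_max.
names(res)
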

Cannot extract important variables with `accuracy_decrease`

I have the following importance_frame:

importance_frame <- structure(list(variable = structure(1:20, .Label = c(
  "A", "C",
  "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R",
  "S", "T", "V", "W", "Y"
), class = "factor"), mean_min_depth = c(
  1.9761861386314,
  2.5220853029533, 2.15539883255869, 1.61935396654558, 1.45123463631321,
  1.53296953170083, 1.77518115811586, 1.52151167552988, 1.89182019096144,
  2.14429040818413, 1.26326405034901, 1.93502763567771, 1.26898183744519,
  2.02060547195198, 1.54217481302459, 1.67384650439192, 1.5485857685783,
  2.09727178410599, 2.75747046937195, 2.35864404092358
), times_a_root = c(
  23.4,
  5.5, 13.3, 27.9, 39.3, 31.3, 29.7, 34.2, 24.2, 13, 43, 22.7,
  45.3, 16.8, 31.5, 30.1, 33.5, 19.3, 1.75, 14.6
), no_of_nodes = c(
  68.1,
  32.6, 62.2, 103.2, 103.3, 104.7, 75.6, 105.7, 72.4, 64.6, 118.4,
  73.6, 116.6, 74.5, 104.6, 95.6, 103.2, 60.3, 8.875, 36.1
), no_of_trees = c(
  65.1,
  32.3, 59.8, 96.1, 94.7, 99.9, 74.8, 100.6, 69.4, 62.8, 111.2,
  71.2, 108.3, 72.4, 98.8, 90, 97.6, 58.4, 8.875, 35.9
), p_value = c(
  0.669119230058558,
  0.999999783867775, 0.824720803698331, 0.10305110839386, 0.160596787513604,
  0.141119826647113, 0.52735342045046, 0.162403671879659, 0.713272963278132,
  0.817225145266696, 0.0104446472288876, 0.546649197487473, 0.0330726857615005,
  0.672936592800508, 0.0310135225001855, 0.182169849737794, 0.274905137508873,
  0.873388429679101, 1, 0.999021554764331
), gini_decrease = c(
  0.233831386391386,
  0.0886505361305361, 0.185330422910423, 0.358267377067377, 0.401108053058053,
  0.397634655344655, 0.308835228105228, 0.389097318237318, 0.250707615717616,
  0.191033563103563, 0.476535763125763, 0.249038827838828, 0.47133199023199,
  0.243902473082473, 0.372547632367632, 0.33646759018759, 0.382999447219447,
  0.203790450660451, 0.0253906843156843, 0.133164814074814
), accuracy_decrease = c(
  -0.00445119047619048,
  -0.00289380952380952, -0.00482809523809524, -0.00530904761904762,
  0.0051652380952381, 0.00616785714285714, 0.00289238095238095,
  -0.00079095238095238, -0.00239095238095238, -0.00648809523809524,
  0.00383690476190476, -0.00413857142857143, 0.00331214285714286,
  -0.00290619047619048, -0.00131714285714286, -0.0046781746031746,
  0.00534214285714286, -0.00532571428571429, 0, -0.000374047619047619
)), class = "data.frame", .Names = c(
  "variable", "mean_min_depth",
  "times_a_root", "no_of_nodes", "no_of_trees", "p_value", "gini_decrease",
  "accuracy_decrease"
), row.names = c(NA, -20L), na.action = structure(c(
  80L,
  180L
), .Names = c("80", "180"), class = "omit"))

importance_frame
#>    variable mean_min_depth times_a_root no_of_nodes no_of_trees    p_value
#> 1         A       1.976186        23.40      68.100      65.100 0.66911923
#> 2         C       2.522085         5.50      32.600      32.300 0.99999978
#> 3         D       2.155399        13.30      62.200      59.800 0.82472080
#> 4         E       1.619354        27.90     103.200      96.100 0.10305111
#> 5         F       1.451235        39.30     103.300      94.700 0.16059679
#> 6         G       1.532970        31.30     104.700      99.900 0.14111983
#> 7         H       1.775181        29.70      75.600      74.800 0.52735342
#> 8         I       1.521512        34.20     105.700     100.600 0.16240367
#> 9         K       1.891820        24.20      72.400      69.400 0.71327296
#> 10        L       2.144290        13.00      64.600      62.800 0.81722515
#> 11        M       1.263264        43.00     118.400     111.200 0.01044465
#> 12        N       1.935028        22.70      73.600      71.200 0.54664920
#> 13        P       1.268982        45.30     116.600     108.300 0.03307269
#> 14        Q       2.020605        16.80      74.500      72.400 0.67293659
#> 15        R       1.542175        31.50     104.600      98.800 0.03101352
#> 16        S       1.673847        30.10      95.600      90.000 0.18216985
#> 17        T       1.548586        33.50     103.200      97.600 0.27490514
#> 18        V       2.097272        19.30      60.300      58.400 0.87338843
#> 19        W       2.757470         1.75       8.875       8.875 1.00000000
#> 20        Y       2.358644        14.60      36.100      35.900 0.99902155
#>    gini_decrease accuracy_decrease
#> 1     0.23383139     -0.0044511905
#> 2     0.08865054     -0.0028938095
#> 3     0.18533042     -0.0048280952
#> 4     0.35826738     -0.0053090476
#> 5     0.40110805      0.0051652381
#> 6     0.39763466      0.0061678571
#> 7     0.30883523      0.0028923810
#> 8     0.38909732     -0.0007909524
#> 9     0.25070762     -0.0023909524
#> 10    0.19103356     -0.0064880952
#> 11    0.47653576      0.0038369048
#> 12    0.24903883     -0.0041385714
#> 13    0.47133199      0.0033121429
#> 14    0.24390247     -0.0029061905
#> 15    0.37254763     -0.0013171429
#> 16    0.33646759     -0.0046781746
#> 17    0.38299945      0.0053421429
#> 18    0.20379045     -0.0053257143
#> 19    0.02539068      0.0000000000
#> 20    0.13316481     -0.0003740476

And I tried to get the important variables with the following code:

library(randomForestExplainer)
x_measure <- "gini_decrease"
y_measure <- "accuracy_decrease"
important_variables(importance_frame,
  k = 10,
  measures = c(x_measure, y_measure, size_measure)
)

The error I get is this:

Error in `[.data.frame`(rankings, , measures) : 
  undefined columns selected

How can I fix the issue?
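
One thing that may be worth ruling out: the snippet above never assigns size_measure, and an "undefined columns selected" error is what data frame indexing reports when a requested name is not a column. The sketch below uses a hypothetical helper, missing_measures(), which is not part of randomForestExplainer, to list requested measures that are absent from the frame. It does not claim to identify the actual cause of the reported error.

# Hypothetical helper (not part of randomForestExplainer): report requested
# measures that are not columns of the importance frame.
missing_measures <- function(importance_frame, measures) {
  setdiff(measures, names(importance_frame))
}

# Small stand-in with a few of the column names used in the frame above.
frame <- data.frame(
  variable = c("A", "C"),
  gini_decrease = c(0.23, 0.09),
  accuracy_decrease = c(-0.004, -0.003),
  no_of_nodes = c(68.1, 32.6)
)

# The last name is deliberately wrong to show what the check reports.
measures <- c("gini_decrease", "accuracy_decrease", "size_of_nodes")
missing_measures(frame, measures)

An empty result means every requested measure exists as a column, so the problem would lie elsewhere.
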

Also, what is the meaning of a negative accuracy_decrease?

Threshold values

Hi,
I was wondering if it is possible to compute thresholds or pseudo-thresholds for each variable in the forest, similar to what single trees give. Something like the interaction grid where we sometimes see a clear value where the colors drastically change, but for all variables.

Thank you

`explain_forest` errors, doesn't produce HTML document

Firstly, thank you for this package.

I've discovered a couple of issues.

Using the example from the docs as a reprex:

forest <- randomForest::randomForest(Species ~ ., data = iris, localImp = TRUE)
explain_forest(forest, interactions = TRUE)

For me this produces the error Error in file(con, "w") : cannot open the connection. I'm using randomForestExplainer_0.10.1. Six images are produced in Your_forest_explained_files/figure-html but not an HTML file.

Additionally, when setting the path parameter:

  • The function errors if the given path doesn't exist
  • The path must be absolute, or an error is produced. Would relative paths be possible?
  • In the case that the given directory does exist, and is provided via an absolute path, the function doesn't use that path - it instead creates another directory in the parent directory with the same name as the provided directory with _files appended to the end.

Any resolution of these issues would be greatly appreciated.
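
Until the path handling changes, a possible workaround for the first two points is to create the directory and convert it to an absolute path before calling the function. The helper name prepare_output_dir() is hypothetical, and whether this avoids the "cannot open the connection" error has not been verified here.

# Hypothetical helper (not part of randomForestExplainer): create the target
# directory if needed and return its absolute path.
prepare_output_dir <- function(path) {
  if (!dir.exists(path)) {
    dir.create(path, recursive = TRUE)
  }
  normalizePath(path, winslash = "/", mustWork = TRUE)
}

out_dir <- prepare_output_dir(file.path(tempdir(), "rf_report"))
# explain_forest(forest, interactions = TRUE, path = out_dir)  # not run here
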

Request to add features

Hello,
Is it possible to customize features in the chart obtained with the function "plot_min_depth_distribution" of your package? I'd like to change colors and adjust title.
Thanks.
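
Since the plotting functions in the package return ggplot objects, further layers can usually be added with +. The sketch below uses a hand-built stand-in plot rather than the real output of plot_min_depth_distribution(); its data, aesthetics and the assumption of a discrete fill are illustrative only.

library(ggplot2)

# Stand-in for the object returned by plot_min_depth_distribution(); the data
# and aesthetics here are assumptions for illustration only.
p <- ggplot(
  data.frame(
    variable = c("x1", "x1", "x2"),
    depth = factor(c(0, 1, 0)),
    count = c(10, 5, 8)
  ),
  aes(x = variable, y = count, fill = depth)
) +
  geom_col()

# Change the palette of the (discrete) fill and set a custom title.
p_custom <- p +
  scale_fill_brewer(palette = "Set2") +
  labs(title = "Minimal depth by variable", fill = "Minimal depth")
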

Function explain_forest throws error

Running explain_forest on a model trained on a popular educational data set (German Credit Data) throws the following error:

Quitting from lines 81-82 (Explain_forest_template.Rmd)
Error in `[.data.frame`(rankings, , measures) :
undefined columns selected

Code to reproduce:

library(tidyverse)
library(randomForest)
#> Warning: package 'randomForest' was built under R version 3.4.4
#> randomForest 4.6-14
#> Type rfNews() to see new features/changes/bug fixes.
#>
#> Attaching package: 'randomForest'
#> The following object is masked from 'package:dplyr':
#>
#> combine
#> The following object is masked from 'package:ggplot2':
#>
#> margin
library(randomForestExplainer)
set.seed(123)
credit <- read_csv('http://invidio.drl.pl/files/german_credit.csv')
#> Parsed with column specification:
#> cols(
#> .default = col_character(),
#> default = col_integer(),
#> duration_in_month = col_integer(),
#> credit_amount = col_integer(),
#> installment_as_income_perc = col_integer(),
#> present_res_since = col_integer(),
#> age = col_integer(),
#> credits_this_bank = col_integer(),
#> people_under_maintenance = col_integer()
#> )
#> See spec(...) for full column specifications.

credit <- credit %>%
  mutate_if(is.character, as.factor) %>%
  mutate(default = as.factor(default))
#> Warning: package 'bindrcpp' was built under R version 3.4.4

credit_shuffled <- sample_frac(credit, 1)
n <- nrow(credit_shuffled)
n_train <- round(0.8 * n)
train_indices <- sample(1:n, n_train)
credit_train <- credit_shuffled[train_indices,]
credit_test <- credit_shuffled[-train_indices,]

glimpse(credit_train)
#> Observations: 800
#> Variables: 21
#> $ default 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,...
#> $ account_check_status 0 <= ... < 200 DM, no checking acco...
#> $ duration_in_month 12, 21, 24, 24, 12, 18, 36, 48, 18,...
#> $ credit_history critical account/ other credits exi...
#> $ purpose car (new), business, radio/televisi...
#> $ credit_amount 2366, 1572, 3777, 2197, 1412, 866, ...
#> $ savings 500 <= ... < 1000 DM, .. >= 1000 DM...
#> $ present_emp_since 4 <= ... < 7 years, .. >= 7 years, ...
#> $ installment_as_income_perc 3, 4, 4, 4, 4, 4, 1, 4, 4, 4, 4, 1,...
#> $ personal_status_sex male : divorced/separated, female :...
#> $ other_debtors none, none, none, none, guarantor, ...
#> $ present_res_since 3, 4, 4, 4, 2, 2, 3, 2, 1, 4, 1, 3,...
#> $ property if not A121/A122 : car or other, no...
#> $ age 36, 36, 50, 43, 29, 25, 31, 38, 43,...
#> $ other_installment_plans none, bank, none, none, none, none,...
#> $ housing own, own, own, own, own, own, own, ...
#> $ credits_this_bank 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 1, 1,...
#> $ job management/ self-employed/ highly q...
#> $ people_under_maintenance 1, 1, 1, 2, 1, 1, 2, 1, 2, 1, 1, 1,...
#> $ telephone yes, registered under the customers...
#> $ foreign_worker yes, yes, yes, yes, yes, yes, yes, ...

credit_model <- randomForest(
  default ~ .,
  data = credit_train
)

class(credit_model)
#> [1] "randomForest.formula" "randomForest"
explain_forest(credit_model)
#> processing file: Explain_forest_template.Rmd
#> [1] accuracy_decrease and gini_decrease
#> Quitting from lines 81-82 (Explain_forest_template.Rmd)
#> Error in `[.data.frame`(rankings, , measures): undefined columns selected
Created on 2018-07-12 by the reprex package (v0.2.0).

Slow mutate_if()

The calculations of {randomForest} models lose a significant amount of time in mutate_if() calls. Replacing them by an explicit mutate() is easy, and gives a clear speed-up:

library(randomForest)
library(randomForestExplainer)

set.seed(12)

fit <- randomForest(Sepal.Width~., data = iris)
fit2 <- ranger::ranger(Sepal.Width ~ ., data = iris)

system.time( # 2.1 seconds with mutate_if() -> 0.8 seconds with mutate()
  out <- min_depth_distribution(fit)
)
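
The kind of replacement meant here can be sketched on a toy table. The data frame and its column names are assumptions, not the package's internal objects, and the example only shows that both forms give the same result; it does not reproduce the timings quoted above.

library(dplyr)

# Stand-in for a tree table with one factor column; names are assumptions.
tree <- data.frame(
  split_var = factor(c("Petal.Length", "Petal.Width", NA)),
  split_point = c(2.45, 1.75, NA)
)

# Predicate-based version: is.factor() is evaluated for every column.
via_mutate_if <- mutate_if(tree, is.factor, as.character)

# Explicit version: names the one column to convert.
via_mutate <- mutate(tree, split_var = as.character(split_var))

# Both versions should produce the same table.
all.equal(via_mutate_if, via_mutate)
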

Question about Modification of Plot of Interactions

Hello,
I am working with the function plot_min_depth_interactions() of your package (see the section Variable Interactions here: https://cran.rstudio.com/web/packages/randomForestExplainer/vignettes/randomForestExplainer.html), and I would like to modify the output. By default, this function creates a (very interesting) chart of interactions labelled with the variable names from the data frame. I am trying to replace the names used in my dataset with labels that are more appropriate for presentation. For example, I have an interaction that appears as "reg_geo:educ_lev", which I would like to show as "Region:Education" in the chart only, without changing the rest of my code. I tried scale_x_discrete from ggplot, but I am lost: how am I supposed to do this? Should I supply all variable names and their labels, or just those included in the interaction chart? What order should I follow (that of the most important interactions)? Can you guide me, please? Moreover, can I change the colors? I tried scale_fill_brewer, but it tells me that I am using a continuous value for a discrete variable.
Thanks for your help.
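
A possible approach, sketched on a hand-built stand-in plot: the data frame, its column names and the continuous fill are assumptions about what the interactions chart looks like, and "age:income" is a made-up second interaction. Passing a named vector to scale_x_discrete() relabels only the listed values, so neither the full set of names nor their order needs to be supplied.

library(ggplot2)

# Stand-in data; the real plot's data and aesthetics may differ.
interactions <- data.frame(
  interaction = c("reg_geo:educ_lev", "age:income"),
  mean_min_depth = c(1.2, 1.8),
  occurrences = c(40, 25)
)

p <- ggplot(
  interactions,
  aes(x = interaction, y = mean_min_depth, fill = occurrences)
) +
  geom_col()

p_relabelled <- p +
  # A named vector relabels only the listed values; order does not matter.
  scale_x_discrete(labels = c("reg_geo:educ_lev" = "Region:Education")) +
  # The fill is continuous here, so a continuous scale is needed;
  # scale_fill_distiller() applies a ColorBrewer palette to continuous data.
  scale_fill_distiller(palette = "Blues")

scale_fill_brewer() is meant for discrete fills, which would explain the complaint about a continuous value if the fill in the real chart is numeric.
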

dtplyr 1.0.0

This is a heads up to let you know that I'm planning to release dtplyr 1.0.0 in the near future. This is a complete rewrite of dtplyr so it generates vastly more performant data.table code, but it did require a complete rework of the API so your existing dtplyr code is unlikely to continue to work.

Compatibility with `ranger` package

The ranger package is a modern popular alternative to randomForest, but explainer does not support it :(

I would appreciate it if explainer worked with ranger forests :)

And thanks for a good package!

Speed-up interaction calculations

The interaction calculations are slow. Here is my plan to make it faster:

  • Unify output of getTree() and treeInfo() and adapt code correspondingly. Most functions that distinguish {randomForest} and {ranger} can be removed. Implemented by #41
  • Simplify the code base, e.g., reduce the amount of {dplyr} code.
  • Make interaction calculations faster.

@pbiecek @hbaniecki
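
To illustrate what unifying the two outputs could involve, here is a rough sketch that reshapes a table laid out like randomForest::getTree(..., labelVar = TRUE) into a subset of the columns returned by ranger::treeInfo(). The function as_tree_info() is hypothetical and is not the code from #41; the toy tree is hand-written so the example runs without either package.

# Hypothetical converter (not the code from #41): map getTree()-style columns
# onto a subset of the treeInfo() column layout.
as_tree_info <- function(tree) {
  terminal <- tree[["status"]] == -1
  # getTree() numbers nodes from 1 and uses 0 for "no daughter";
  # treeInfo() numbers nodes from 0 and uses NA for "no child".
  child <- function(x) ifelse(terminal, NA, x - 1)
  data.frame(
    nodeID = seq_len(nrow(tree)) - 1,
    leftChild = child(tree[["left daughter"]]),
    rightChild = child(tree[["right daughter"]]),
    splitvarName = as.character(tree[["split var"]]),
    splitval = ifelse(terminal, NA, tree[["split point"]]),
    terminal = terminal,
    prediction = tree[["prediction"]],
    stringsAsFactors = FALSE
  )
}

# Hand-written toy tree: one root split and two terminal nodes.
toy <- data.frame(
  `left daughter` = c(2, 0, 0),
  `right daughter` = c(3, 0, 0),
  `split var` = c("Petal.Length", NA, NA),
  `split point` = c(2.45, 0, 0),
  status = c(1, -1, -1),
  prediction = c(NA, "setosa", "versicolor"),
  check.names = FALSE
)
as_tree_info(toy)

With a single layout, downstream functions would no longer need separate branches for {randomForest} and {ranger}.
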

fragile approach to getting names of independent variables?

Hello! Thanks so much for this package! I'm learning a ton about making inference from random forest models, and I really appreciate the effort you've put into making this more understandable.

I came across an issue when using your package on a {ranger} model built using {spatialRF} when trying to run randomForestExplainer::plot_predict_interaction(). It seems that the method used by {randomForestExplainer} to get the list of independent variable names is fragile, and can error out if the formula syntax wasn't used to create the {ranger} model.

For instance, with {ranger}, you can build a model like this:

forest_ranger <- ranger::ranger(x = mtcars[, c("mpg", "disp", "hp", "drat", "wt", "qsec", "vs", "am", "gear", "carb")], y = mtcars[, "cyl"])

Which will then error out when trying to run:

plot_predict_interaction(forest_ranger, mtcars, "mpg", "hp")

But it doesn't error out when building the same model using the formula syntax:

forest_ranger <- ranger::ranger(cyl ~ ., data = mtcars)
plot_predict_interaction(forest_ranger, mtcars, "mpg", "hp")

The issue arises in this line in {randomForestExplainer}:

if(as.character(forest$call[[2]])[3] == "."){

The {spatialRF} package doesn't build the {ranger} model using the formula syntax, so randomForestExplainer::plot_predict_interaction() won't work on the resulting model:

forest_ranger <- spatialRF::rf(dependent.variable.name = "cyl", 
                               predictor.variable.names = c("mpg", "disp", "hp", "drat", "wt", "qsec", "vs", "am", "gear", "carb"), 
                               data = mtcars)
plot_predict_interaction(forest_ranger, mtcars, "mpg", "hp")

I documented this issue and my workaround in the repo for {spatialRF}, but I thought I'd add it here too, since the issue seems more relevant to {randomForestExplainer} and how it captures what the independent variables are in a {ranger} model.

It looks like, in a {ranger} model, you can get the independent variables directly from the $forest$independent.variable.names component? Maybe this is a more robust way to capture that info for plot_predict_interaction()?

What do you think?
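
The suggestion could look roughly like this. predictor_names() is a hypothetical helper, not an existing function in {randomForestExplainer}, and the mock object carries only the one field the helper reads so that the example runs without {ranger} installed.

# Hypothetical helper (not part of randomForestExplainer): read predictor
# names from the fitted object instead of parsing forest$call.
predictor_names <- function(forest) {
  if (inherits(forest, "ranger")) {
    vars <- forest$forest$independent.variable.names
    if (is.null(vars)) {
      stop("No stored forest; was the model fitted with write.forest = FALSE?")
    }
    return(vars)
  }
  stop("Unsupported model class: ", class(forest)[1])
}

# Mock object with only the field used above, standing in for a ranger fit.
mock <- structure(
  list(forest = list(independent.variable.names = c("mpg", "disp", "hp"))),
  class = "ranger"
)
predictor_names(mock)

Because this does not inspect the call, it would behave the same for models built with the formula interface, the x/y interface, or a wrapper such as {spatialRF}.
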

Fail to print `measures`

When generating the report with the function explain_forest(), the script fails to print the measures in the section "Compare importance measures".

I guess the problem is here, as the object measures is not assigned anywhere in the document.
