modeloriented / randomforestexplainer
A set of tools to understand what is happening inside a Random Forest
Home Page: https://ModelOriented.github.io/randomForestExplainer/
I have the following importance_frame:
importance_frame <- structure(list(variable = structure(1:20, .Label = c(
"A", "C",
"D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R",
"S", "T", "V", "W", "Y"
), class = "factor"), mean_min_depth = c(
1.9761861386314,
2.5220853029533, 2.15539883255869, 1.61935396654558, 1.45123463631321,
1.53296953170083, 1.77518115811586, 1.52151167552988, 1.89182019096144,
2.14429040818413, 1.26326405034901, 1.93502763567771, 1.26898183744519,
2.02060547195198, 1.54217481302459, 1.67384650439192, 1.5485857685783,
2.09727178410599, 2.75747046937195, 2.35864404092358
), times_a_root = c(
23.4,
5.5, 13.3, 27.9, 39.3, 31.3, 29.7, 34.2, 24.2, 13, 43, 22.7,
45.3, 16.8, 31.5, 30.1, 33.5, 19.3, 1.75, 14.6
), no_of_nodes = c(
68.1,
32.6, 62.2, 103.2, 103.3, 104.7, 75.6, 105.7, 72.4, 64.6, 118.4,
73.6, 116.6, 74.5, 104.6, 95.6, 103.2, 60.3, 8.875, 36.1
), no_of_trees = c(
65.1,
32.3, 59.8, 96.1, 94.7, 99.9, 74.8, 100.6, 69.4, 62.8, 111.2,
71.2, 108.3, 72.4, 98.8, 90, 97.6, 58.4, 8.875, 35.9
), p_value = c(
0.669119230058558,
0.999999783867775, 0.824720803698331, 0.10305110839386, 0.160596787513604,
0.141119826647113, 0.52735342045046, 0.162403671879659, 0.713272963278132,
0.817225145266696, 0.0104446472288876, 0.546649197487473, 0.0330726857615005,
0.672936592800508, 0.0310135225001855, 0.182169849737794, 0.274905137508873,
0.873388429679101, 1, 0.999021554764331
), gini_decrease = c(
0.233831386391386,
0.0886505361305361, 0.185330422910423, 0.358267377067377, 0.401108053058053,
0.397634655344655, 0.308835228105228, 0.389097318237318, 0.250707615717616,
0.191033563103563, 0.476535763125763, 0.249038827838828, 0.47133199023199,
0.243902473082473, 0.372547632367632, 0.33646759018759, 0.382999447219447,
0.203790450660451, 0.0253906843156843, 0.133164814074814
), accuracy_decrease = c(
-0.00445119047619048,
-0.00289380952380952, -0.00482809523809524, -0.00530904761904762,
0.0051652380952381, 0.00616785714285714, 0.00289238095238095,
-0.00079095238095238, -0.00239095238095238, -0.00648809523809524,
0.00383690476190476, -0.00413857142857143, 0.00331214285714286,
-0.00290619047619048, -0.00131714285714286, -0.0046781746031746,
0.00534214285714286, -0.00532571428571429, 0, -0.000374047619047619
)), class = "data.frame", .Names = c(
"variable", "mean_min_depth",
"times_a_root", "no_of_nodes", "no_of_trees", "p_value", "gini_decrease",
"accuracy_decrease"
), row.names = c(NA, -20L), na.action = structure(c(
80L,
180L
), .Names = c("80", "180"), class = "omit"))
importance_frame
#> variable mean_min_depth times_a_root no_of_nodes no_of_trees p_value
#> 1 A 1.976186 23.40 68.100 65.100 0.66911923
#> 2 C 2.522085 5.50 32.600 32.300 0.99999978
#> 3 D 2.155399 13.30 62.200 59.800 0.82472080
#> 4 E 1.619354 27.90 103.200 96.100 0.10305111
#> 5 F 1.451235 39.30 103.300 94.700 0.16059679
#> 6 G 1.532970 31.30 104.700 99.900 0.14111983
#> 7 H 1.775181 29.70 75.600 74.800 0.52735342
#> 8 I 1.521512 34.20 105.700 100.600 0.16240367
#> 9 K 1.891820 24.20 72.400 69.400 0.71327296
#> 10 L 2.144290 13.00 64.600 62.800 0.81722515
#> 11 M 1.263264 43.00 118.400 111.200 0.01044465
#> 12 N 1.935028 22.70 73.600 71.200 0.54664920
#> 13 P 1.268982 45.30 116.600 108.300 0.03307269
#> 14 Q 2.020605 16.80 74.500 72.400 0.67293659
#> 15 R 1.542175 31.50 104.600 98.800 0.03101352
#> 16 S 1.673847 30.10 95.600 90.000 0.18216985
#> 17 T 1.548586 33.50 103.200 97.600 0.27490514
#> 18 V 2.097272 19.30 60.300 58.400 0.87338843
#> 19 W 2.757470 1.75 8.875 8.875 1.00000000
#> 20 Y 2.358644 14.60 36.100 35.900 0.99902155
#> gini_decrease accuracy_decrease
#> 1 0.23383139 -0.0044511905
#> 2 0.08865054 -0.0028938095
#> 3 0.18533042 -0.0048280952
#> 4 0.35826738 -0.0053090476
#> 5 0.40110805 0.0051652381
#> 6 0.39763466 0.0061678571
#> 7 0.30883523 0.0028923810
#> 8 0.38909732 -0.0007909524
#> 9 0.25070762 -0.0023909524
#> 10 0.19103356 -0.0064880952
#> 11 0.47653576 0.0038369048
#> 12 0.24903883 -0.0041385714
#> 13 0.47133199 0.0033121429
#> 14 0.24390247 -0.0029061905
#> 15 0.37254763 -0.0013171429
#> 16 0.33646759 -0.0046781746
#> 17 0.38299945 0.0053421429
#> 18 0.20379045 -0.0053257143
#> 19 0.02539068 0.0000000000
#> 20 0.13316481 -0.0003740476
And I tried to get the important variables with the following code:
library(randomForestExplainer)
x_measure <- "gini_decrease"
y_measure <- "accuracy_decrease"
important_variables(importance_frame,
k = 10,
measures = c(x_measure, y_measure, size_measure)
)
The error I get is this:
Error in `[.data.frame`(rankings, , measures) :
undefined columns selected
How can I fix the issue?
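The message is raised by the data frame subsetting method ([.data.frame) when it is asked for a column name that does not exist. Note that size_measure is not defined in the snippet above, so the measures vector may not contain what was intended. A minimal diagnostic sketch, using a hypothetical helper missing_measures() that is not part of randomForestExplainer, lists requested names that are absent from a frame (here a toy stand-in for importance_frame). It only checks the frame you pass in, not the ranking table built internally, so treat it as a first check rather than a guaranteed fix.

```r
# Hypothetical helper (not part of randomForestExplainer): return the requested
# measure names that are not columns of the given frame.
missing_measures <- function(frame, measures) {
  setdiff(measures, names(frame))
}

# Toy stand-in for importance_frame with only a few of its columns
toy_frame <- data.frame(
  variable = c("A", "C"),
  gini_decrease = c(0.234, 0.089),
  accuracy_decrease = c(-0.0045, -0.0029)
)

size_measure <- "no_of_nodes"  # assumption: whatever size measure was intended
measures <- c("gini_decrease", "accuracy_decrease", size_measure)

missing_measures(toy_frame, measures)  # "no_of_nodes" is not a column here

# Subsetting with a missing name reproduces the same error message
msg <- tryCatch(toy_frame[, measures], error = conditionMessage)
msg
```

In the question's importance_frame, all eight column names shown in the printout are available, so any measure name should be one of those.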
Also, what is the meaning of a negative accuracy_decrease?
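As far as I understand, accuracy_decrease is randomForest's permutation importance (mean decrease in accuracy): the prediction accuracy with a variable's values intact minus the accuracy after those values are randomly shuffled, averaged over trees on out-of-bag data. A negative value therefore means that shuffling the variable did not hurt, and even slightly helped, prediction, which usually indicates the variable carries little useful signal; small values of either sign, like those in the table, are typically within noise. The toy sketch below shows only the arithmetic, with a fixed hand-written rule and a fixed shuffle; it is not randomForest's out-of-bag implementation.

```r
# Toy illustration of permutation importance (not randomForest's OOB code).
y  <- c(1, 1, 0, 0, 1, 0)                # true classes
x1 <- c(0.9, 0.8, 0.2, 0.1, 0.7, 0.3)    # informative variable
x2 <- c(0.1, 0.2, 0.95, 0.3, 0.4, 0.5)   # noise variable

# A fixed "model" that (wrongly) also relies on the noise variable
predict_toy <- function(x1, x2) as.integer(x1 > 0.5 | x2 > 0.9)

accuracy <- function(pred, truth) mean(pred == truth)

acc_original <- accuracy(predict_toy(x1, x2), y)        # 5 of 6 correct
perm <- c(3, 1, 2, 4, 5, 6)                              # a fixed shuffle of x2
acc_permuted <- accuracy(predict_toy(x1, x2[perm]), y)  # 6 of 6 correct

# Decrease = intact accuracy minus shuffled accuracy; negative here because
# shuffling the noise variable happened to remove a wrong prediction
acc_decrease <- acc_original - acc_permuted
acc_decrease
```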
Below is a warning thrown when min_depth_interactions() is used:
interactions_frame <- min_depth_interactions(rf_fit, vars)
Warning messages:
1: funs() was deprecated in dplyr 0.8.0.
ℹ Please use a list of either functions or lambdas:
tibble::lst(): tibble::lst(mean, median)
ℹ The deprecated feature was likely used in the randomForestExplainer package.
Please report the issue to the authors.
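The message asks callers to pass a list of functions instead of funs(). A minimal sketch of that replacement on a toy data frame, assuming dplyr 1.0.0 or later for across(); this is the kind of change needed inside the package, not its actual code:

```r
library(dplyr)

# Toy minimal-depth data: one row per (tree, variable) pair
depths <- data.frame(
  variable = c("a", "a", "b", "b", "b"),
  minimal_depth = c(1, 3, 2, 2, 5)
)

# Deprecated style that triggers the warning:
#   summarise_at(group_by(depths, variable), "minimal_depth", funs(mean, median))
# Replacement: a named list of functions; across() names the results
# "{column}_{function}" when a list is supplied
summary_frame <- depths %>%
  group_by(variable) %>%
  summarise(across(minimal_depth, list(mean = mean, median = median)))

summary_frame
```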
This is a heads up to let you know that I'm planning to release dtplyr 1.0.0 in the near future. This is a complete rewrite of dtplyr so it generates vastly more performant data.table code, but it did require a complete rework of the API so your existing dtplyr code is unlikely to continue to work.
When generating the report with the function explain_forest(), the script fails to print the measures in the section "Compare importance measures".
I guess the problem is here, as the object measures is not assigned anywhere in the document.
Hello! Thanks so much for this package! I'm learning a ton about making inference from random forest models, and I really appreciate the effort you've put into making this more understandable.
I came across an issue when using your package on a {ranger} model built using {spatialRF} when trying to run randomForestExplainer::plot_predict_interaction(). It seems that the method used by {randomForestExplainer} to get the list of dependent variable names is fragile, and can error out if the formula syntax wasn't used to create the {ranger} model.
For instance, with {ranger}, you can build a model like this:
forest_ranger <- ranger::ranger(x = mtcars[, c("mpg", "disp", "hp", "drat", "wt", "qsec", "vs", "am", "gear", "carb")], y = mtcars[, "cyl"])
Which will then error out when trying to run:
plot_predict_interaction(forest_ranger, mtcars, "mpg", "hp")
But it doesn't error out when building the same model using the formula syntax:
forest_ranger <- ranger::ranger(cyl ~ ., data = mtcars)
plot_predict_interaction(forest_ranger, mtcars, "mpg", "hp")
The issue arises in this line in {randomForestExplainer}:
The {spatialRF} package doesn't build the {ranger} model using the formula syntax, so randomForestExplainer::plot_predict_interaction() won't work on the resulting model:
forest_ranger <- spatialRF::rf(dependent.variable.name = "cyl",
predictor.variable.names = c("mpg", "disp", "hp", "drat", "wt", "qsec", "vs", "am", "gear", "carb"),
data = mtcars)
plot_predict_interaction(forest_ranger, mtcars, "mpg", "hp")
I documented this issue and my workaround in the repo for {spatialRF}, but I thought I'd add it here too, since the issue seems more relevant for {randomForestExplainer} and how it captures what the dependent variables are in a {ranger} model.
It looks like, in a {ranger} model, you can get the independent variables directly from the $forest$independent.variable.names component? Maybe this is a more robust way to capture that information for plot_predict_interaction()?
What do you think?
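A minimal sketch of the suggestion, assuming (as stated above) that a fitted ranger object stores predictor names in $forest$independent.variable.names regardless of whether the formula or the x/y interface was used. ranger_predictors() is a hypothetical helper, not part of either package:

```r
# Hypothetical helper: read predictor names from a fitted ranger model.
# Assumes they live in fit$forest$independent.variable.names, which does not
# depend on the model having been fitted with a formula.
ranger_predictors <- function(fit) {
  vars <- fit$forest$independent.variable.names
  if (is.null(vars)) {
    stop("No independent.variable.names found; was the model fitted with write.forest = TRUE?")
  }
  vars
}

# Usage (requires the ranger package):
# forest_ranger <- ranger::ranger(x = mtcars[, -2], y = mtcars[, "cyl"])  # column 2 is cyl
# ranger_predictors(forest_ranger)
```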
Let's move the MI2DataLab/randomForestExplainer package to ModelOriented/randomForestExplainer. All DrWhy projects will be in one repository.
Suggested date for the movement: August 21st
What is the meaning of the columns "tree", "variable" and "minimal_depth"?
https://cran.rstudio.com/web/packages/randomForestExplainer/vignettes/randomForestExplainer.html
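As far as I can tell from the vignette, each row of the min_depth_distribution() output describes one (tree, variable) pair: tree is the index of the tree, variable is a predictor used for splitting in that tree, and minimal_depth is the depth of the shallowest node that splits on it, with the root counting as depth 0. The sketch below computes this for one hand-written tree stored in a simplified node table (an assumption loosely modelled on randomForest::getTree() output, where leaves have child index 0); it is an illustration, not the package's implementation.

```r
# Simplified tree: one row per node; leaves have left = right = 0 and no split variable.
# Assumes children always have larger node ids than their parent (as in getTree()).
tree <- data.frame(
  node = 1:7,
  left = c(2L, 4L, 6L, 0L, 0L, 0L, 0L),
  right = c(3L, 5L, 7L, 0L, 0L, 0L, 0L),
  split_var = c("x1", "x2", "x1", NA, NA, NA, NA),
  stringsAsFactors = FALSE
)

# Depth of every node, starting from the root (node 1) at depth 0
node_depths <- function(tree) {
  depth <- integer(nrow(tree))
  for (i in seq_len(nrow(tree))) {
    for (child in c(tree$left[i], tree$right[i])) {
      if (child > 0) depth[child] <- depth[i] + 1L
    }
  }
  depth
}

# Minimal depth per splitting variable, in the (tree, variable, minimal_depth) shape
tree_min_depth <- function(tree, tree_id) {
  depth <- node_depths(tree)
  is_split <- !is.na(tree$split_var)
  out <- aggregate(
    list(minimal_depth = depth[is_split]),
    by = list(variable = tree$split_var[is_split]),
    FUN = min
  )
  data.frame(tree = tree_id, out)
}

tree_min_depth(tree, tree_id = 1)  # x1 -> 0 (it splits the root), x2 -> 1
```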
Running explain_forest() on a model trained on a popular educational data set (German Credit Data) throws the following error:
Quitting from lines 81-82 (Explain_forest_template.Rmd)
Error in `[.data.frame`(rankings, , measures) :
undefined columns selected
Code to reproduce:
library(tidyverse)
library(randomForest)
#> Warning: package 'randomForest' was built under R version 3.4.4
#> randomForest 4.6-14
#> Type rfNews() to see new features/changes/bug fixes.
#>
#> Attaching package: 'randomForest'
#> The following object is masked from 'package:dplyr':
#>
#> combine
#> The following object is masked from 'package:ggplot2':
#>
#> margin
library(randomForestExplainer)
set.seed(123)
credit <- read_csv('http://invidio.drl.pl/files/german_credit.csv')
#> Parsed with column specification:
#> cols(
#> .default = col_character(),
#> default = col_integer(),
#> duration_in_month = col_integer(),
#> credit_amount = col_integer(),
#> installment_as_income_perc = col_integer(),
#> present_res_since = col_integer(),
#> age = col_integer(),
#> credits_this_bank = col_integer(),
#> people_under_maintenance = col_integer()
#> )
#> See spec(...) for full column specifications.
credit <- credit %>%
mutate_if(is.character, as.factor) %>%
mutate(default = as.factor(default))
#> Warning: package 'bindrcpp' was built under R version 3.4.4
credit_shuffled <- sample_frac(credit, 1)
n <- nrow(credit_shuffled)
n_train <- round(0.8 * n)
train_indices <- sample(1:n, n_train)
credit_train <- credit_shuffled[train_indices,]
credit_test <- credit_shuffled[-train_indices,]
glimpse(credit_train)
#> Observations: 800
#> Variables: 21
#> $ default 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,...
#> $ account_check_status 0 <= ... < 200 DM, no checking acco...
#> $ duration_in_month 12, 21, 24, 24, 12, 18, 36, 48, 18,...
#> $ credit_history critical account/ other credits exi...
#> $ purpose car (new), business, radio/televisi...
#> $ credit_amount 2366, 1572, 3777, 2197, 1412, 866, ...
#> $ savings 500 <= ... < 1000 DM, .. >= 1000 DM...
#> $ present_emp_since 4 <= ... < 7 years, .. >= 7 years, ...
#> $ installment_as_income_perc 3, 4, 4, 4, 4, 4, 1, 4, 4, 4, 4, 1,...
#> $ personal_status_sex male : divorced/separated, female :...
#> $ other_debtors none, none, none, none, guarantor, ...
#> $ present_res_since 3, 4, 4, 4, 2, 2, 3, 2, 1, 4, 1, 3,...
#> $ property if not A121/A122 : car or other, no...
#> $ age 36, 36, 50, 43, 29, 25, 31, 38, 43,...
#> $ other_installment_plans none, bank, none, none, none, none,...
#> $ housing own, own, own, own, own, own, own, ...
#> $ credits_this_bank 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 1, 1,...
#> $ job management/ self-employed/ highly q...
#> $ people_under_maintenance 1, 1, 1, 2, 1, 1, 2, 1, 2, 1, 1, 1,...
#> $ telephone yes, registered under the customers...
#> $ foreign_worker yes, yes, yes, yes, yes, yes, yes, ...
credit_model <- randomForest(
default ~ .,
data = credit_train
)
class(credit_model)
#> [1] "randomForest.formula" "randomForest"
explain_forest(credit_model)
#> processing file: Explain_forest_template.Rmd
#> [1] accuracy_decrease and gini_decrease
#> Quitting from lines 81-82 (Explain_forest_template.Rmd)
#> Error in `[.data.frame`(rankings, , measures): undefined columns selected
Created on 2018-07-12 by the reprex package (v0.2.0).
It would be great if parallel processing capabilities were added to the min_depth_distribution() function. The time required to get the output from min_depth_distribution() is substantial for large, complex models.
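Because minimal depth is computed separately for every tree, the work could in principle be split across trees. A minimal sketch using mclapply() from base R's parallel package; parallel_by_tree() and the dummy per-tree function are hypothetical, and mclapply() relies on forking, so more than one core is not available on Windows:

```r
library(parallel)

# Hypothetical sketch: compute per-tree results in parallel and bind them.
# per_tree is any function of a tree index returning a data frame,
# e.g. a minimal-depth calculation for that single tree.
parallel_by_tree <- function(n_trees, per_tree, cores = 1L) {
  # mclapply() forks worker processes; with cores = 1 it behaves like lapply()
  pieces <- mclapply(seq_len(n_trees), per_tree, mc.cores = cores)
  do.call(rbind, pieces)
}

# Usage with a dummy per-tree function standing in for real work
dummy <- function(i) data.frame(tree = i, variable = "x1", minimal_depth = i %% 3)
res <- parallel_by_tree(4, dummy)
res
```

On Unix-alikes, passing cores = 2 or more spreads the trees over several processes; whether that pays off depends on how expensive each tree is compared with the forking overhead.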
Firstly, thank you for this package.
I've discovered a couple of issues.
Using the example from the docs as a reprex:
forest <- randomForest::randomForest(Species ~ ., data = iris, localImp = TRUE)
explain_forest(forest, interactions = TRUE)
For me this produces the error Error in file(con, "w") : cannot open the connection. I'm using randomForestExplainer_0.10.1. Six images are produced in Your_forest_explained_files/figure-html, but not an HTML file.
Additionally, when setting the path parameter:
_files
appended to the end. Any resolution to these issues would be greatly appreciated.
The interaction calculations are slow. Here is my plan to make them faster:
getTree() and treeInfo() and adapt the code correspondingly. Most functions that distinguish {randomForest} and {ranger} can then be removed. Implemented by #41.
Hi,
I was wondering if it is possible to compute thresholds or pseudo-thresholds for each variable in the forest, similar to what single trees give. Something like the interaction grid, where we sometimes see a clear value at which the colors change drastically, but for all variables.
Thank you
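One possible pseudo-threshold, offered as a hypothetical sketch rather than an existing feature: pool the numeric split points each variable receives across all trees and summarise them, for example by their median. With randomForest such a table could be assembled from getTree(forest, k, labelVar = TRUE), whose "split var" and "split point" columns hold this information; the toy table below only mimics that shape. Split points of categorical variables are encoded as integers and are not meaningful as thresholds.

```r
# Hypothetical split table pooled over several trees (one row per internal node).
# With randomForest, a similar table could be built from getTree(forest, k, labelVar = TRUE).
splits <- data.frame(
  tree = c(1, 1, 2, 2, 3),
  split_var = c("age", "income", "age", "age", "income"),
  split_point = c(30.5, 52000, 29.5, 41.0, 48000),
  stringsAsFactors = FALSE
)

# One summary per variable: median split point and number of splits
pseudo_thresholds <- function(splits) {
  by_var <- split(splits$split_point, splits$split_var)
  data.frame(
    variable = names(by_var),
    median_split = unname(vapply(by_var, median, numeric(1))),
    n_splits = unname(lengths(by_var))
  )
}

pseudo_thresholds(splits)
```

The median is a crude choice; a histogram or density of each variable's split points would show whether there really is one dominant cut-off.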
It would be nice to be able to change the x and y labels, as you can with the title.
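As far as I know, the package's plotting functions return ggplot objects, so axis labels can usually be overridden by adding ggplot2::labs() to the result. A self-contained sketch on a plain ggplot that stands in for a package plot:

```r
library(ggplot2)

# Stand-in for a plot returned by a randomForestExplainer plotting function
toy <- data.frame(variable = c("A", "C", "M"), gini_decrease = c(0.23, 0.09, 0.48))
p <- ggplot(toy, aes(x = variable, y = gini_decrease)) + geom_col()

# Adding labs() replaces the existing axis labels and title
p <- p + labs(x = "Predictor", y = "Mean decrease in Gini", title = "Importance")

# On a package plot (assumed to be a ggplot), the same pattern would be e.g.
# plot_multi_way_importance(importance_frame) + labs(x = "...", y = "...")
p
```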
The ranger package is a modern, popular alternative to randomForest, but the explainer does not support it :(
I would appreciate it if the explainer worked with ranger forests :)
And thanks for a good package!
Hello,
I am working with the function plot_min_depth_interactions() of your package (see section Variable Interactions here https://cran.rstudio.com/web/packages/randomForestExplainer/vignettes/randomForestExplainer.html), and I would like to modify the output. As you know, this function creates a (very interesting) chart of interactions that uses the variable names from the data frame by default. I would like to replace the names I used in the dataset with labels that are more appropriate for presentation (for example, an interaction that appears as "reg_geo:educ_lev" should read "Region:Education" in the chart, but not in the rest of my code). I was trying to use scale_x_discrete from ggplot2, but I am lost: how am I supposed to do this? Should I input all variable names and their labels, or just those included in the interaction chart? What order should I follow (that of the most important interactions)? Can you guide me, please? Moreover, can I change the colors? I tried scale_fill_brewer, but it tells me that I am using a continuous value for a discrete variable.
Thanks for your help.
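A sketch of both changes on a stand-in ggplot, assuming the real chart has a discrete x axis of interaction names and a continuous fill (which the scale_fill_brewer() error suggests). In recent ggplot2 versions, a named labels vector on a discrete scale relabels only the breaks whose names match, so only the interactions to rename need to be listed and their order does not matter; scale_fill_distiller() is the continuous counterpart of scale_fill_brewer(). The column names below are illustrative:

```r
library(ggplot2)

# Stand-in for the interactions chart: discrete x axis of interaction names,
# continuous fill
inter <- data.frame(
  interaction = c("reg_geo:educ_lev", "age:income", "reg_geo:age"),
  mean_min_depth = c(1.2, 1.8, 2.1),
  occurrences = c(40, 25, 10)
)
p <- ggplot(inter, aes(x = interaction, y = mean_min_depth, fill = occurrences)) +
  geom_col()

p <- p +
  # Named vector: only "reg_geo:educ_lev" is relabelled, the others keep their names
  scale_x_discrete(labels = c("reg_geo:educ_lev" = "Region:Education")) +
  # Continuous ColorBrewer palette, unlike the discrete scale_fill_brewer()
  scale_fill_distiller(palette = "Blues", direction = 1)

p
```

On the real plot the same scales would be added to the result of plot_min_depth_interactions(interactions_frame); if the package already sets an x or fill scale, ggplot2 replaces it and prints a message.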
The link to the master thesis in the README does not work, that is:
Master thesis on randomForestExplainer.
Hello,
Is it possible to customize features in the chart obtained with the function "plot_min_depth_distribution" of your package? I'd like to change colors and adjust title.
Thanks.
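Assuming the returned object is a ggplot whose fill encodes minimal depth as a discrete value, colours and title can be overridden by adding a discrete fill scale and labs(). The function also appears to take a main argument for the title; check its help page. A sketch on a stand-in plot:

```r
library(ggplot2)

# Stand-in for plot_min_depth_distribution(): stacked bars per variable,
# filled by minimal depth treated as a discrete (factor) value (assumption)
md <- data.frame(
  variable = rep(c("x1", "x2"), each = 3),
  minimal_depth = factor(rep(0:2, times = 2)),
  count = c(30, 10, 5, 5, 20, 15)
)
p <- ggplot(md, aes(x = variable, y = count, fill = minimal_depth)) +
  geom_col()

p <- p +
  scale_fill_brewer(palette = "YlOrRd", name = "Minimal depth") +  # discrete palette
  labs(title = "Distribution of minimal depth")

p
```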
Hi,
Is there any way to customise the filename of the explain_forest() outputs? Your_forest_explained.html isn't very useful when trying to explore multiple models at once.
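Until the name can be set directly, one workaround is to rename the report after it is generated. A minimal sketch with a hypothetical helper rename_report(), assuming explain_forest() wrote Your_forest_explained.html into the given directory and that the HTML is self-contained (if it references a companion _files folder, that folder would need renaming too):

```r
# Hypothetical helper: rename the generated report to a model-specific name.
rename_report <- function(new_name, dir = ".", old_name = "Your_forest_explained.html") {
  from <- file.path(dir, old_name)
  to <- file.path(dir, new_name)
  if (!file.exists(from)) stop("Report not found: ", from)
  if (!file.rename(from, to)) stop("Could not rename ", from)
  invisible(to)
}

# Usage, right after explain_forest(model_a):
# rename_report("model_a_explained.html")
```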
The calculations for {randomForest} models lose a significant amount of time in mutate_if() calls. Replacing them with an explicit mutate() is easy and gives a clear speed-up:
library(randomForest)
library(ranger)  # needed for ranger() below
library(randomForestExplainer)
set.seed(12)
fit <- randomForest(Sepal.Width ~ ., data = iris)
fit2 <- ranger(Sepal.Width ~ ., data = iris)
# Time min_depth_distribution() on the {randomForest} fit
system.time( # 2.1 seconds with mutate_if() -> 0.8 seconds with mutate()
  out <- min_depth_distribution(fit)
)
What is the meaning of this figure? I couldn't understand it, especially what the arrow is pointing to.
https://cran.rstudio.com/web/packages/randomForestExplainer/vignettes/randomForestExplainer.html