
ingredients's People

Contributors

adamizdebski, harel-harmonic, harell, hbaniecki, jakwisn, jmaspons, kasiapekala, kmatusz, konrad-komisarczyk, maksymiuks, mstaniak, pawel99k, pbiecek, sztach, tkonopka, tmikolajczyk, wojciechkretowicz

ingredients's Issues

plotD3 for partial_dependency

It would be great to make plotD3 work for objects of class aggregated_profiles_explainer, which are returned by the partial_dependency function.

Color by label

What should be the default behavior? Compare the four options below. IMO the last plot looks best, but even if coloring by default is a bad idea, I think the third plot should have drwhy_theme colors.

Example:

library("DALEX")
library("ingredients")
titanic <- na.omit(titanic)
model_titanic_glm <- glm(survived == "yes" ~ gender + age + fare,
                         data = titanic, family = "binomial")
explain_titanic_glm <- explain(model_titanic_glm,
                               data = titanic[,-9],
                               y = titanic$survived == "yes")
ale_glm <- accumulated_dependency(explain_titanic_glm, N = 50, variables = "age")
pdp_glm <- variable_response(explain_titanic_glm, variable = "age")

library("randomForest")
model_titanic_rf <- randomForest(survived ~ gender + age + class + embarked +
                                   fare + sibsp + parch, data = titanic)
explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic[,-9],
                              y = titanic$survived)
ale_rf <- accumulated_dependency(explain_titanic_rf)
pdp_rf <- variable_response(explain_titanic_rf, variable = "age")

plot(ale_glm, ale_rf)
plot(ale_glm, ale_rf) +
  theme_drwhy()
plot(ale_glm, ale_rf, color = "label") +
  theme_drwhy()
plot(pdp_glm, pdp_rf)

Error in ceteris_paribus.default(x, data, predict_function = predict_function, : promise already under evaluation: recursive default argument reference or earlier problems?

Running the demo code below gives me the error quoted in the title.

set.seed(1313)
titanic_small <- titanic[sample(1:nrow(titanic), 500), c(1,2,3,6,7,9)]

model_titanic_glm <- glm(survived == "yes" ~ gender + age + fare + class + sibsp,
                         data = titanic_small, family = "binomial")

explain_titanic_glm <- DALEX::explain(model_titanic_glm,
                               data = titanic_small[,-6],
                               y = titanic_small$survived == "yes",
                               label = "glm")

new_observations <- titanic_small[1:4,-6]
rownames(new_observations) <- c("Lisa", "James", "Thomas", "Nancy")

ceteris_paribus <<- ingredients::ceteris_paribus
dime::modelStudio(explain_titanic_glm,
            new_observations,
            facet_dim = c(2,2), N = 200, B = 20, time = 0)

This is my session:

> sessionInfo()
R version 3.6.1 (2019-07-05)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows Server x64 (build 14393)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252    LC_MONETARY=English_United States.1252
[4] LC_NUMERIC=C                           LC_TIME=English_United States.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] DALEX_0.4.4 dime_0.1.1 

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.2        compiler_3.6.1    pillar_1.4.2      ingredients_0.3.3 prettyunits_1.0.2 remotes_2.1.0     tools_3.6.1      
 [8] testthat_2.1.1    digest_0.6.20     pkgbuild_1.0.3    pkgload_1.0.2     memoise_1.1.0     tibble_2.1.3      gtable_0.3.0     
[15] pkgconfig_2.0.2   rlang_0.4.0       cli_1.1.0         rstudioapi_0.10   curl_3.3          dplyr_0.8.3       withr_2.1.2      
[22] desc_1.2.0        fs_1.3.1          devtools_2.1.0    tidyselect_0.2.5  rprojroot_1.3-2   grid_3.6.1        glue_1.3.1       
[29] R6_2.4.0          processx_3.4.1    sessioninfo_1.1.1 purrr_0.3.2       callr_3.3.1       ggplot2_3.2.0     magrittr_1.5     
[36] backports_1.1.4   scales_1.0.0      ps_1.3.0          usethis_1.5.1     assertthat_0.2.1  colorspace_1.4-1  lazyeval_0.2.2   
[43] munsell_0.5.0     crayon_1.3.4     

Thank you!

aspect_importance function

I've added an experimental version of the aspect importance function to a forked repo. The fork contains the new aspect_importance functions, a vignette and tests.

forked repo

ceteris_paribus: number of rows of the result

library(DALEX)
library(randomForest)
data('apartments')
m_rf_ap <- randomForest(m2.price ~., data = apartments[1:500, ], ntree = 10)
ap_explainer <- explain(m_rf_ap, apartments[1:500, ])

n_points <- min(101, length(unique(ap_explainer$data[, "construction.year"])))

ceteris <- ingredients::ceteris_paribus(
  ap_explainer, apartments[10, ], grid_points = n_points,
  variables = "construction.year")[, c("construction.year", "_yhat_", "_label_")]
nrow(ceteris)
n_points

Where is the difference coming from?

Joint Variable Importance

In variable_importance it would be nice to be able to define a group of variables to be perturbed at once. This would be useful, for example, to assess a group of features engineered from one underlying variable.
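
One way to read this request, as a minimal base R sketch rather than the package's implementation: permute all columns of a group with the same row order and measure the change in loss. The helpers `rmse` and `grouped_importance` are hypothetical names introduced here for illustration, and the `lm` model on `mtcars` is only a stand-in.

```r
# Sketch only: grouped permutation importance in base R.
# `rmse` and `grouped_importance` are hypothetical helpers, not part of ingredients.
rmse <- function(observed, predicted) sqrt(mean((observed - predicted)^2))

grouped_importance <- function(model, data, y, groups,
                               predict_function = predict, loss = rmse) {
  baseline <- loss(y, predict_function(model, data))
  sapply(groups, function(vars) {
    permuted <- data
    idx <- sample(nrow(data))
    # Reuse one row order for every variable in the group,
    # so the whole group is perturbed at once.
    permuted[, vars] <- data[idx, vars]
    loss(y, predict_function(model, permuted)) - baseline
  })
}

set.seed(1)
model <- lm(mpg ~ wt + hp + disp + qsec, data = mtcars)
groups <- list(size = c("wt", "disp"), performance = c("hp", "qsec"))
result <- grouped_importance(model, mtcars, mtcars$mpg, groups)
result
```

The result is a named vector with one loss difference per group, so it could be reported next to single-variable importances.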

Fix documentation

E.g. the plot.aggregated_profiles_explainer description says that it is show_aggregated_profiles.

Change the file name plot_aggregated_ceteris_paribus_explainer.R to plot_aggregated_profiles_explainer.R?

The aggregate_profiles description says that it plots, while it actually returns an object of class "aggregated_profiles_explainer", "data.frame".

#' @param only_numerical a logical. If TRUE then only numerical variables will be plotted. If FALSE then only categorical variables will be plotted.

naming of JS files

Maybe featureImportanceMultiPlot.js instead of featureImportance2.js?
Numbers like 2 in names are meaningless.

only_numerical param name (plot.ceteris_paribus_explainer)

While the description is clear, the name is confusing. If it's called only_numerical, shouldn't it plot both categorical and numerical features when set to FALSE? Alternatively, it could be called "feature_type", with "numerical"/"categorical" options.
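
A small sketch of how the proposed "feature_type" argument could behave, using `match.arg` so that only the two named options are accepted. The function `select_by_type` is a hypothetical helper made up for this example; it is not an existing function in ingredients.

```r
# Hypothetical helper illustrating a `feature_type` argument
# with two explicit options instead of a logical flag.
select_by_type <- function(data, feature_type = c("numerical", "categorical")) {
  feature_type <- match.arg(feature_type)
  is_num <- vapply(data, is.numeric, logical(1))
  if (feature_type == "numerical") names(data)[is_num] else names(data)[!is_num]
}

numerical_vars   <- select_by_type(iris)                 # default: first option
categorical_vars <- select_by_type(iris, "categorical")
```

With this design a misspelled option fails early with an error from `match.arg`, instead of silently falling into the other branch.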

'theme_drwhy_colors()' is now deprecated

Hi,

If I want to plot aggregated profiles, e.g.:

library(DALEX)
library(ingredients)
titanic <- na.omit(titanic)
model_titanic_glm <- glm(survived == "yes" ~ gender + age + fare,
                         data = titanic, family = "binomial")
explain_titanic_glm <- explain(model_titanic_glm,
                               data = titanic[,-9],
                               y = titanic$survived == "yes",
                               verbose = FALSE)
pdp_rf_p <- partial_dependency(explain_titanic_glm, N = 50)
plot(pdp_rf_p)

I get the following message:
Warning message:
Please note that 'theme_drwhy_colors()' is now deprecated, it is better to use 'colors_discrete_drwhy()' instead.
I think it is because of the changes in the latest version of the DALEX package. In the plot.aggregated_profiles_explainer function, theme_drwhy_colors occurs twice.

DALEXverse 0.19.8 release summer 2019

Integration

  • readability: vignettes
  • readability: NEWS
  • readability: DESCRIPTION
  • consistency: pkgdown website
  • consistency: entry at DrWhy.AI webpage

assigned: @pbiecek

Code review

  • consistency: names of functions
  • consistency: names of files
  • consistency: names of variables in functions (local and global)
  • length: functions
  • readability: code (comments, constructions)

assigned: @hbaniecki

Feature review

  • readability: documentation (title, description, details)
  • readability: examples (relevant, complete, with comments)
  • reproducibility: tests (code coverage)
  • links to functions: \code

assigned: @AdamIzdebski

Vignette doesn't work

When I try to open a vignette from RStudio, I get an error saying that no vignettes could be found.

Error in description of ceteris paribus

library(randomForest)
model_rm <- randomForest(life_length ~., data = DALEX::dragons, ntree = 200)
explainer <- DALEX::explain(model_rm)
pred_cp <- ingredients::ceteris_paribus(explainer,
                                        new_observation = explainer$data[1,])
ingredients::describe(pred_cp, variables = 'weight')

RandomForest predicts that for the selected instance predicts that , prediction is equal to 1385.405. (...)

bug in feature_importance when using variable groups and data as a matrix

Currently feature_importance works with a non-null list of variable_groups when data is a data frame, but not when data is a matrix. Some models (xgboost) require a matrix, so it is not possible to work around this on the user side. It would be great to make the variable_groups work with matrices as well.

I can adjust the code and add a test.
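
For illustration, a base R sketch of a permutation step that behaves the same for a data frame and a matrix. It relies only on `[`-style indexing, which both types support. `permute_group` is a hypothetical helper, not the function used inside feature_importance.

```r
# Hypothetical helper: permute a group of columns jointly.
# Uses only `[` indexing, so it accepts a data.frame or a matrix.
permute_group <- function(data, variables) {
  idx <- sample(nrow(data))
  data[, variables] <- data[idx, variables]
  data
}

df <- mtcars[, c("mpg", "wt", "hp")]
m  <- as.matrix(df)

set.seed(1)
permuted_df <- permute_group(df, c("wt", "hp"))
set.seed(1)
permuted_m  <- permute_group(m, c("wt", "hp"))
```

The input type is preserved, so a matrix-only model such as xgboost could still be given the permuted data directly.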

argument names in `plotD3`

Each argument of the plotD3() function should be described.
For argument names we use underscores, so scale_height instead of scaleHeight.

In the latest DALEX we should use apartments_test instead of apartmentsTest.

Error when data has one variable

I'm running the following code:

library(ggplot2)
set.seed(17)
x <- runif(1000, -8, 8)
y <- ifelse(x <= 1, (x + 3)^2 - 10 , -5*x + 11)
molnar <- data.frame(x = x, y = y)
ggplot(molnar, aes(x, y)) +
  geom_line(size = 2) +
  theme_bw()

true_model <- function(model, newdata) {
  ifelse(newdata$x <= 1, (newdata$x + 3)^2 - 10 , -5*newdata$x + 11)
}

true_model(list(), data.frame(x = c(0, 10)))

library(DALEX)
molnar_explainer <- explain(list(), molnar, y = y, predict_function = true_model)
predict(molnar_explainer, data.frame(x = c(0, 10)))

plot(ingredients::partial_dependency(molnar_explainer, variables = "x"))
ingredients::ceteris_paribus(molnar_explainer, new_observation = data.frame(x = -6),
                             variables = "x")

And get the error:
Error in new_data[, variable] <- rep(split_points, nrow(data)) :
incorrect number of subscripts on matrix

I guess adding drop = FALSE to the subsetting call would fix the problem (adding another variable also works around it).
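
The suspected mechanism can be reproduced in base R without any package. This is a sketch under the assumption that the data is replicated with row indexing somewhere before the failing line; the variable names below are illustrative only.

```r
# A one-column data frame collapses to a vector under row indexing
# unless drop = FALSE is given.
data <- data.frame(x = c(-6, 0, 2))
rows <- rep(seq_len(nrow(data)), each = 2)

dropped <- data[rows, ]                # plain numeric vector
kept    <- data[rows, , drop = FALSE]  # still a one-column data frame

# Column assignment on the collapsed vector fails.
msg <- tryCatch({
  dropped[, "x"] <- 0
  "no error"
}, error = function(e) conditionMessage(e))

# The same assignment works when the data frame structure is kept.
kept[, "x"] <- 0
msg
```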

Convention

The predict_function passed to ceteris paribus may return many columns (a matrix), while the explainer's predict_function returns only a vector.

Thus passing arguments to ceteris paribus by hand, rather than through the explainer, behaves differently.

missing factorMerger

Hi!

I've noticed that the function factorMerger used to analyse categorical/factor variables was deprecated, and the proposed function does not provide the same results. Is there a reason for this?

many thanks!

tests for aspect_importance

Running the test suite on my machine produces an error from one of the tests in file test_aspect_importance.R, line 57. I am confused because travis reports the package is building correctly. @kasiapekala, can you help? Thanks.

Here are some of the objects within the test.

new_observation

     construction.year surface floor no.rooms district
1002              1978     112     9        4  Mokotow

aspect_importance_ap

            aspects importance
5          district   279.6734
4             floor  -205.1207
2             space  -120.3610
3 construction.year   116.6933

Add non numerical aggregate_profiles

library(DALEX)
library(ingredients)
library("randomForest")
 model_titanic_rf <- randomForest(survived ~ gender + age + class + embarked +
                                    fare + sibsp + parch,  data = titanic)
 model_titanic_rf

 explain_titanic_rf <- explain(model_titanic_rf,
                           data = titanic[,-9],
                           y = titanic$survived)

selected_passangers <- select_sample(titanic, n = 100)
cp_rf <- ceteris_paribus(explain_titanic_rf, selected_passangers)
  • accumulated/conditional/partial_dependency do not have an only_numerical argument:
partial_dependency(explain_titanic_rf, N=50, variables = "gender", only_numerical = FALSE)
  • should aggregate_profiles and *_dependency have only_numerical?
  • this could be a bar plot:
pdp_rf_p <- aggregate_profiles(cp_rf, variables = "gender", type = "partial", only_numerical = FALSE)
plot(pdp_rf_p)

partial_dependency failing when provided an explainer object with model.matrix data

Reproducing the example for using DALEX with xgboost found here: https://pbiecek.github.io/DALEX/articles/DALEX_and_xgboost.html but replacing single_variable with ingredients::partial_dependency() (as the single_variable function itself throws a deprecation warning) produces an error.

library(dplyr)
library(xgboost)
library(DALEX)
library(ingredients)

wine <- breakDown::wine

model_matrix_train <- model.matrix(quality ~ . - 1, wine)
data_train <- xgb.DMatrix(model_matrix_train, label = wine$quality)
param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2,
              objective = "reg:linear")

wine_xgb_model <- xgb.train(param, data_train, nrounds = 50)

explainer_xgb <- explain(
  wine_xgb_model, 
  data = model_matrix_train, 
  y = wine$quality, 
  label = "xgboost"
)

pdp_old <- single_variable(
  explainer_xgb, 
  variable = "alcohol", 
  type = "pdp"
) # works

pdp_new <- partial_dependency(
  explainer_xgb, 
  variables = "alcohol"
) # fails

Results in message:

Error in 1:nrow(new_observation) : argument of length 0
In addition: Warning messages:
1: In profiles$`_label_` <- label : Coercing LHS to a list
2: In new_observation$`_yhat_` <- predict_function(x, new_observation) :
  Coercing LHS to a list

I managed to track this down to

new_observation$`_yhat_` <- predict_function(x, new_observation)
This line converts new_observation to a list.
Thus a couple of lines below, calling 1:nrow(new_observation) results in an error.

I don't know whether partial_dependency is intentionally meant to work only with explainer objects that hold data frames. However, it seems that at least a check of the data type should be added, and a more informative error message should be thrown.
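
A base R sketch of both the coercion and the kind of check suggested above. `assert_data_frame` is a hypothetical helper written for this example; it is not part of ingredients, and the error text is only a suggestion.

```r
# `$<-` on a matrix coerces it to a list (with the warning quoted above).
m <- matrix(1:6, nrow = 3, dimnames = list(NULL, c("a", "b")))
broken <- m
suppressWarnings(broken$`_yhat_` <- c(0.1, 0.2, 0.3))
is.list(broken)

# Hypothetical guard that fails early with a clearer message.
assert_data_frame <- function(data) {
  if (!is.data.frame(data)) {
    stop("`data` must be a data.frame, got an object of class '",
         class(data)[1], "'.", call. = FALSE)
  }
  invisible(data)
}

msg <- tryCatch(assert_data_frame(m), error = function(e) conditionMessage(e))
msg
```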

Session info:

> sessionInfo()
R version 3.4.4 (2018-03-15)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 18.04.2 LTS

Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=bg_BG.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=bg_BG.UTF-8    LC_MESSAGES=en_US.UTF-8    LC_PAPER=bg_BG.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C             LC_MEASUREMENT=bg_BG.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] ingredients_0.3.3 DALEX_0.4         xgboost_0.81.0.1  dplyr_0.8.0.1    

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.0        rstudioapi_0.9.0  magrittr_1.5      tidyselect_0.2.5  munsell_0.5.0     colorspace_1.4-0  lattice_0.20-35  
 [8] R6_2.4.0          rlang_0.3.1       plyr_1.8.4        tools_3.4.4       grid_3.4.4        data.table_1.12.0 gtable_0.2.0     
[15] lazyeval_0.2.1    yaml_2.2.0        assertthat_0.2.0  tibble_2.0.1      crayon_1.3.4      pdp_0.7.0         Matrix_1.2-12    
[22] gridExtra_2.3     purrr_0.3.0       ggplot2_3.1.0     glue_1.3.0        breakDown_0.1.6   stringi_1.3.1     compiler_3.4.4   
[29] pillar_1.3.1      scales_1.0.0      pkgconfig_2.0.2  

win builder NOTES

Examples for describe() take too much time:

** running examples for arch 'i386' ... [73s] NOTE
Examples with CPU or elapsed time > 10s
         user system elapsed
describe   10   0.08    10.1
** running examples for arch 'x64' ... [69s] NOTE
Examples with CPU or elapsed time > 10s
          user system elapsed
describe 10.61   0.13   10.74

feature_importance with more permutations

Hi ModelOriented.

Very useful collection of tools in this package ecosystem. Thank you.

I came across ingredients because of the feature_importance function. It works well based on a single permutation, but the variability between runs is sometimes noticeable on small datasets. For example, runs on the Titanic dataset can disagree on the importance ordering of the second- and third-best features.

Would you be interested in including a new argument to set the number of permutations in feature_importance? The function could output the average dropout loss over those permutations. Returning averages would be compatible with the existing output format and hence with the rest of the package, for example, plots. I can send a pull request in this direction.
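
To make the proposal concrete, here is a minimal base R sketch of averaging the loss over several permutations of one variable. The names `rmse`, `mean_dropout_loss` and the argument `B` are assumptions made for this example and do not describe the existing feature_importance interface.

```r
# Sketch only: average the loss after permutation over B repetitions.
# `rmse` and `mean_dropout_loss` are hypothetical helpers.
rmse <- function(observed, predicted) sqrt(mean((observed - predicted)^2))

mean_dropout_loss <- function(model, data, y, variable, B = 10,
                              predict_function = predict, loss = rmse) {
  losses <- replicate(B, {
    permuted <- data
    permuted[[variable]] <- sample(data[[variable]])
    loss(y, predict_function(model, permuted))
  })
  mean(losses)
}

set.seed(1313)
model <- lm(mpg ~ wt + hp, data = mtcars)
averaged <- sapply(c("wt", "hp"), function(v) {
  mean_dropout_loss(model, mtcars, mtcars$mpg, v, B = 20)
})
averaged
```

Because the output is still one number per variable, it keeps the shape of a single-permutation result.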

NaN in accumulated_dependency output

library("DALEX")

titanic <- na.omit(titanic)
set.seed(1313)
titanic_small <- titanic[sample(1:nrow(titanic), 500), c(1,2,3,6,7,9)]

model_titanic_glm <- glm(survived == "yes" ~ gender + age + fare + class + sibsp,
                         data = titanic_small, family = "binomial")

explain_titanic_glm <- explain(model_titanic_glm,
                               data = titanic_small[,-6],
                               y = titanic_small$survived == "yes",
                               label = "glm")

ad_n <- ingredients::accumulated_dependency(explain_titanic_glm, only_numerical = TRUE, N = 10)
ad_c <- ingredients::accumulated_dependency(explain_titanic_glm, only_numerical = FALSE, N = 10)

More often than not, ad_n and ad_c contain NaN values in the _yhat_ column.
This is probably caused by the low number of observations N. Something should be done about that.

error in ceteris_paribus.default with a response with several labels

When you have a response variable with several labels and pass a data.frame with several rows to ceteris_paribus, it returns an error like this:

Error in `$<-.data.frame`(`*tmp*`, "_ids_", value = c(1L, 1L, 1L, 1L,  : 
  replacement has 400 rows, data has 80

This can be corrected with the following change:
In the file ceteris.paribus.R, line 146 is:
new_observation_ext$`_ids_` <- rep(1:nrow(new_observation), each = nrow(new_observation))

it should be:
new_observation_ext$`_ids_` <- rep(1:nrow(new_observation), each = length(col_yhat))
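
The row counts in the error message are consistent with this reading. A quick arithmetic check in base R, assuming 20 observations and 4 predicted labels (an assumption that matches the 80 rows reported above), with one row per observation and label:

```r
# Assumed sizes: 20 observations, 4 predicted labels.
n_obs    <- 20
n_labels <- 4
n_rows   <- n_obs * n_labels                          # rows in the extended data

ids_wrong <- rep(seq_len(n_obs), each = n_obs)     # 400 values
ids_fixed <- rep(seq_len(n_obs), each = n_labels)  # 80 values

c(rows = n_rows, wrong = length(ids_wrong), fixed = length(ids_fixed))
```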

Toy example:

library("randomForest")
library("DALEX")
library("ingredients")

titanic <- na.omit(titanic)
# we predict embarked instead of survived for generating the bug
model_titanic_rf <- randomForest(embarked ~ gender + age + class + survived +
                                   fare + sibsp + parch,  data = titanic)
model_titanic_rf

explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic[,-4],
                              y = titanic$survived,
                              label = "Random Forest v7")

# select a few passengers
selected_passangers <- select_sample(titanic, n = 20)
cp_rf <- ceteris_paribus(explain_titanic_rf, selected_passangers)
cp_rf

plot for categorical variables

Currently it looks like this:

library("DALEX")
library("ingredients")
titanic <- na.omit(titanic)
model_titanic_glm <- glm(survived == "yes" ~ gender + age + fare,
                         data = titanic, family = "binomial")

explain_titanic_glm <- explain(model_titanic_glm,
                               data = titanic[,-9],
                               y = titanic$survived == "yes")
cp_rf <- ceteris_paribus(explain_titanic_glm, titanic[1,])
plot(cp_rf, only_numerical = FALSE)

(so a line plot without the original labels)

Would it be a good idea to make it

  • a bar plot for unordered factors,
  • a line plot for ordered factors,

and keep the labels on the x-axis?
