tidytrees

Regression and classification trees (e.g. from packages like partykit or rpart) are a powerful family of statistical learning algorithms.

However, each tree package has its own way of representing and storing trees, usually as a nested recursive list with attributes. This makes them hard to inspect and manipulate.

This package provides an interface to convert tree objects from various packages into a “tidy” data frame, with a row for each node showing its defining set of rules and its characteristics.

Installation

You can install the latest version of tidytrees from GitHub with:

# install.packages("devtools")
devtools::install_github("bakaburg1/tidytrees")

Simple use

tidytrees exposes the generic function tidy_tree, which has methods for various tree objects (see ?tidy_tree for the list of supported methods). The output is a tibble with one row per tree node. Each row reports the rule that defines the node, plus other information such as the node id, the number of observations from the model's data that fall in the node, and the depth of the node in the tree.

library(tidytrees)
library(partykit)
library(rpart)

# The function works with partykit...
model <- ctree(Sepal.Width ~ Species + Sepal.Length, data = iris)

tidy_tree(model)
#> # A tibble: 12 x 6
#>    rule                                         id n.obs terminal depth estimate
#>    <chr>                                     <int> <int> <lgl>    <dbl>    <dbl>
#>  1 Species in setosa                             2    50 FALSE        1     3.43
#>  2 Sepal.Length <= 5 & Species in setosa         3    28 FALSE        2     3.20
#>  3 Sepal.Length <= 4.9 & Species in setosa       4    20 TRUE         3     3.14
#>  4 Sepal.Length <= 5 & Sepal.Length > 4.9 &…     5     8 TRUE         3     3.36
#>  5 Sepal.Length > 5 & Species in setosa          6    22 FALSE        2     3.71
#>  6 Sepal.Length <= 5.3 & Sepal.Length > 5 &…     7    12 TRUE         3     3.62
#>  7 Sepal.Length > 5.3 & Species in setosa        8    10 TRUE         3     3.82
#>  8 Species in versicolor, virginica              9   100 FALSE        1     2.87
#>  9 Sepal.Length <= 6.3 & Species in versico…    10    58 FALSE        2     2.74
#> 10 Sepal.Length <= 5.5 & Species in versico…    11    12 TRUE         3     2.47
#> 11 Sepal.Length <= 6.3 & Sepal.Length > 5.5…    12    46 TRUE         3     2.81
#> 12 Sepal.Length > 6.3 & Species in versicol…    13    42 TRUE         2     3.05

# ... and with rpart trees (more models to come)
model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

tidy_tree(model)
#> # A tibble: 8 x 6
#>   rule                                          id n.obs depth terminal estimate
#>   <chr>                                      <dbl> <int> <dbl> <lgl>       <dbl>
#> 1 Species = versicolor,virginica                 2   100     1 FALSE        2.87
#> 2 Species = versicolor,virginica & Sepal.Le…     4    58     2 FALSE        2.74
#> 3 Species = versicolor,virginica & Sepal.Le…     8    12     3 TRUE         2.47
#> 4 Species = versicolor,virginica & Sepal.Le…     9    46     3 TRUE         2.81
#> 5 Species = versicolor,virginica & Sepal.Le…     5    42     2 TRUE         3.05
#> 6 Species = setosa                               3    50     1 FALSE        3.43
#> 7 Species = setosa & Sepal.Length < 5.05         6    28     2 TRUE         3.20
#> 8 Species = setosa & Sepal.Length >= 5.05        7    22     2 TRUE         3.71
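
Because the result is an ordinary tibble, standard data frame operations apply to it. As a small sanity check, the sketch below rebuilds the numeric columns of the rpart output above by hand and confirms that the terminal nodes partition the 150 rows of iris. The values are copied from the printed table, the rule column is left out because it is truncated there, and the names `nodes` and `leaves` are only illustrative.

```r
# Numeric columns copied by hand from the printed rpart output above;
# in real use this would simply be the result of tidy_tree(model).
nodes <- data.frame(
  id       = c(2, 4, 8, 9, 5, 3, 6, 7),
  n.obs    = c(100L, 58L, 12L, 46L, 42L, 50L, 28L, 22L),
  depth    = c(1, 2, 3, 3, 2, 1, 2, 2),
  terminal = c(FALSE, FALSE, TRUE, TRUE, TRUE, FALSE, TRUE, TRUE),
  estimate = c(2.87, 2.74, 2.47, 2.81, 3.05, 3.43, 3.20, 3.71)
)

# Keep only the leaves of the tree
leaves <- nodes[nodes$terminal, ]

# Every observation ends up in exactly one leaf, so the leaf sizes
# add up to the number of rows in the data used to fit the model
sum(leaves$n.obs) == nrow(iris)
```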

The rules can optionally be rendered in an R-compatible format, for easy use as data filters, or as a list of individual conditions.

library(tidytrees)
library(dplyr)
library(rpart)

model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

# Evaluation friendly rules

out <- tidy_tree(model, eval_ready = TRUE)

iris %>% filter(eval(str2expression(out$rule[3]))) %>% str
#> 'data.frame':    12 obs. of  5 variables:
#>  $ Sepal.Length: num  5.5 4.9 5.2 5 5.5 5.5 5.4 5.5 5.5 5 ...
#>  $ Sepal.Width : num  2.3 2.4 2.7 2 2.4 2.4 3 2.5 2.6 2.3 ...
#>  $ Petal.Length: num  4 3.3 3.9 3.5 3.8 3.7 4.5 4 4.4 3.3 ...
#>  $ Petal.Width : num  1.3 1 1.4 1 1.1 1 1.5 1.3 1.2 1 ...
#>  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 2 2 2 2 2 2 2 2 2 2 ...

# Rules as vectors
out <- tidy_tree(model, rule_as_text = FALSE)

out$rule[3]
#> [[1]]
#> [1] "Species = versicolor,virginica" "Sepal.Length < 6.35"           
#> [3] "Sepal.Length < 5.55"

# Both
out <- tidy_tree(model, rule_as_text = FALSE, eval_ready = TRUE)

out$rule[3]
#> [[1]]
#> [1] "Species %in% c(\"versicolor\", \"virginica\")"
#> [2] "Sepal.Length < 6.35"                          
#> [3] "Sepal.Length < 5.55"
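
With both options enabled, each rule is a character vector of conditions that are valid R expressions. A minimal base R sketch of how such a vector could be turned into a row filter follows. The conditions are copied from the output above, and the variable names `conditions`, `rule_text`, `keep` and `node_data` are only illustrative.

```r
# Conditions copied from the eval-ready output above
conditions <- c(
  'Species %in% c("versicolor", "virginica")',
  "Sepal.Length < 6.35",
  "Sepal.Length < 5.55"
)

# Join the conditions into one logical expression
rule_text <- paste(conditions, collapse = " & ")

# Parse the text and evaluate it with the columns of iris in scope
keep <- eval(str2lang(rule_text), envir = iris)

node_data <- iris[keep, ]
nrow(node_data)
```

The filtered data has 12 rows, which agrees with the n.obs value reported for this node in the tables above.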

Tree models tend to produce explicit, nested rules with redundant components, which is useful for retaining the full branching history. The package can simplify such rules to make them more human-friendly while keeping the minimal set of conditions needed to identify a partition. The simplified rules are ordered alphabetically so that conditions on the same variable are grouped together.

library(tidytrees)
library(dplyr)
library(rpart)

model <- rpart(Sepal.Length ~ Species + Sepal.Width, data = iris)

# Full rules

tidy_tree(model)$rule[5:9]
#> [1] "Species = versicolor,virginica & Species = versicolor"                                            
#> [2] "Species = versicolor,virginica & Species = versicolor & Sepal.Width < 2.75"                       
#> [3] "Species = versicolor,virginica & Species = versicolor & Sepal.Width >= 2.75"                      
#> [4] "Species = versicolor,virginica & Species = versicolor & Sepal.Width >= 2.75 & Sepal.Width < 3.05" 
#> [5] "Species = versicolor,virginica & Species = versicolor & Sepal.Width >= 2.75 & Sepal.Width >= 3.05"

# Simplified rules
tidy_tree(model, simplify_rules = TRUE)$rule[5:9]
#> [1] "Species = versicolor"                                           
#> [2] "Sepal.Width < 2.75 & Species = versicolor"                      
#> [3] "Sepal.Width >= 2.75 & Species = versicolor"                     
#> [4] "Sepal.Width < 3.05 & Sepal.Width >= 2.75 & Species = versicolor"
#> [5] "Sepal.Width >= 3.05 & Species = versicolor"

# It also works on a list of conditions
tidy_tree(model, rule_as_text = FALSE, simplify_rules = TRUE)$rule[5:9]
#> [[1]]
#> [1] "Species = versicolor"
#> 
#> [[2]]
#> [1] "Sepal.Width < 2.75"   "Species = versicolor"
#> 
#> [[3]]
#> [1] "Sepal.Width >= 2.75"  "Species = versicolor"
#> 
#> [[4]]
#> [1] "Sepal.Width < 3.05"   "Sepal.Width >= 2.75"  "Species = versicolor"
#> 
#> [[5]]
#> [1] "Sepal.Width >= 3.05"  "Species = versicolor"

# Can be applied to previously created rules

rules <- tidy_tree(model)$rule[5:9]

simplify_rules(rules)
#> [1] "Species = versicolor"                                           
#> [2] "Sepal.Width < 2.75 & Species = versicolor"                      
#> [3] "Sepal.Width >= 2.75 & Species = versicolor"                     
#> [4] "Sepal.Width < 3.05 & Sepal.Width >= 2.75 & Species = versicolor"
#> [5] "Sepal.Width >= 3.05 & Species = versicolor"
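
To illustrate the idea behind this kind of simplification, here is a rough base R sketch for numeric conditions only. It is not the tidytrees implementation: the helper name `simplify_numeric` is hypothetical, it assumes conditions written as "variable operator threshold" separated by single spaces, and it does not handle ties between strict and non-strict operators or categorical conditions. For each variable it keeps the tightest lower bound and the tightest upper bound.

```r
# Hypothetical helper, NOT the package's own code
simplify_numeric <- function(conds) {
  # Split each "variable operator threshold" condition into its parts
  parts <- regmatches(conds, regexec("^(\\S+) (<=|>=|<|>) (\\S+)$", conds))
  df <- data.frame(
    var = vapply(parts, `[`, "", 2),
    op  = vapply(parts, `[`, "", 3),
    val = as.numeric(vapply(parts, `[`, "", 4)),
    stringsAsFactors = FALSE
  )
  # Upper bounds use < or <=, lower bounds use > or >=
  df$dir <- ifelse(df$op %in% c("<", "<="), "upper", "lower")
  groups <- split(df, list(df$var, df$dir), drop = TRUE)
  kept <- vapply(groups, function(g) {
    # Tightest bound: smallest upper threshold, largest lower threshold
    i <- if (g$dir[1] == "upper") which.min(g$val) else which.max(g$val)
    paste(g$var[i], g$op[i], g$val[i])
  }, "")
  # Byte-wise (C locale) ordering, so the result does not depend on the locale
  sort(unname(kept), method = "radix")
}

# The looser lower bound (>= 2.75) is implied by the tighter one and is dropped
simplify_numeric(c("Sepal.Width >= 2.75", "Sepal.Width >= 3.05"))
#> [1] "Sepal.Width >= 3.05"
```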

Node predictions

The output can optionally include the predicted value for each node together with an estimation interval, whose coverage can be chosen by the user (default = 95%).

library(tidytrees)
library(dplyr)
library(rpart)

# Intervals for continuous...
model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

tidy_tree(model, add_interval = TRUE, interval_level = .89)
#> # A tibble: 8 x 8
#>   rule                       id n.obs depth terminal estimate conf.low conf.high
#>   <chr>                   <dbl> <int> <dbl> <lgl>       <dbl>    <dbl>     <dbl>
#> 1 Species = versicolor,v…     2   100     1 FALSE        2.87     2.83      2.91
#> 2 Species = versicolor,v…     4    58     2 FALSE        2.74     2.69      2.79
#> 3 Species = versicolor,v…     8    12     3 TRUE         2.47     2.38      2.55
#> 4 Species = versicolor,v…     9    46     3 TRUE         2.81     2.76      2.87
#> 5 Species = versicolor,v…     5    42     2 TRUE         3.05     3.00      3.10
#> 6 Species = setosa            3    50     1 FALSE        3.43     3.36      3.49
#> 7 Species = setosa & Sep…     6    28     2 TRUE         3.20     3.14      3.27
#> 8 Species = setosa & Sep…     7    22     2 TRUE         3.71     3.64      3.79

# ... and discrete outcomes
model <- rpart(Species ~ Sepal.Width + Sepal.Length, data = iris)

tidy_tree(model, add_interval = TRUE, interval_level = .89)
#> # A tibble: 24 x 9
#>    rule              id n.obs depth terminal estimate conf.low conf.high y.level
#>    <chr>          <dbl> <int> <dbl> <lgl>       <dbl>    <dbl>     <dbl> <chr>  
#>  1 Sepal.Length …     2    52     1 FALSE      0.865   0.765      0.934  setosa 
#>  2 Sepal.Length …     2    52     1 FALSE      0.115   0.0527     0.212  versic…
#>  3 Sepal.Length …     2    52     1 FALSE      0.0192  0.00109    0.0860 virgin…
#>  4 Sepal.Length …     4    45     2 TRUE       0.978   0.901      0.999  setosa 
#>  5 Sepal.Length …     4    45     2 TRUE       0.0222  0.00126    0.0988 versic…
#>  6 Sepal.Length …     4    45     2 TRUE       0       0          0.0624 virgin…
#>  7 Sepal.Length …     5     7     2 TRUE       0.143   0.00805    0.512  setosa 
#>  8 Sepal.Length …     5     7     2 TRUE       0.714   0.349      0.944  versic…
#>  9 Sepal.Length …     5     7     2 TRUE       0.143   0.00805    0.512  virgin…
#> 10 Sepal.Length …     3    98     1 FALSE      0.0510  0.0209     0.103  setosa 
#> # … with 14 more rows

The default intervals are based on the normal approximation for continuous outcomes and on binom.test() for discrete ones. The estimation function is pluggable, however, so users can provide their own.

library(tidytrees)
library(dplyr)
library(rpart)

# Quantile intervals for continuous outcomes
model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

tidy_tree(model, add_interval = TRUE, est_fun = function(values, add_interval, interval_level) {
    data.frame(
        estimate = median(values),
        conf.low = quantile(values, (1 - interval_level) / 2),
        conf.high = quantile(values, .5 + interval_level/2)
    )
})
#> # A tibble: 8 x 8
#>   rule                       id n.obs depth terminal estimate conf.low conf.high
#>   <chr>                   <dbl> <int> <dbl> <lgl>       <dbl>    <dbl>     <dbl>
#> 1 Species = versicolor,v…     2   100     1 FALSE        2.9      2.2       3.50
#> 2 Species = versicolor,v…     4    58     2 FALSE        2.75     2.2       3.4 
#> 3 Species = versicolor,v…     8    12     3 TRUE         2.45     2.08      2.92
#> 4 Species = versicolor,v…     9    46     3 TRUE         2.8      2.2       3.4 
#> 5 Species = versicolor,v…     5    42     2 TRUE         3        2.60      3.80
#> 6 Species = setosa            3    50     1 FALSE        3.4      2.92      4.18
#> 7 Species = setosa & Sep…     6    28     2 TRUE         3.2      2.70      3.6 
#> 8 Species = setosa & Sep…     7    22     2 TRUE         3.7      3.35      4.30

# Bayesian regularized credible intervals for discrete outcomes
model <- rpart(Species ~ Sepal.Width + Sepal.Length, data = iris)

tidy_tree(model, add_interval = TRUE, est_fun = function(values, add_interval, interval_level) {
    table(values) %>%
        lapply(function(cases) {
            qbeta(
                c(.5, (1 - interval_level) / 2, .5 + interval_level/2),
                cases + 1.1,
                length(values) - cases + 1.1
            ) %>% matrix(nrow = 1) %>% as.data.frame() %>% 
                setNames(c('estimate', 'cred.low', 'cred.high'))
        }) %>% bind_rows()
})
#> # A tibble: 24 x 8
#>    rule                      id n.obs depth terminal estimate cred.low cred.high
#>    <chr>                  <dbl> <int> <dbl> <lgl>       <dbl>    <dbl>     <dbl>
#>  1 Sepal.Length < 5.45        2    52     1 FALSE      0.855  0.745       0.931 
#>  2 Sepal.Length < 5.45        2    52     1 FALSE      0.126  0.0558      0.232 
#>  3 Sepal.Length < 5.45        2    52     1 FALSE      0.0332 0.00519     0.103 
#>  4 Sepal.Length < 5.45 &…     4    45     2 TRUE       0.962  0.882       0.994 
#>  5 Sepal.Length < 5.45 &…     4    45     2 TRUE       0.0382 0.00599     0.118 
#>  6 Sepal.Length < 5.45 &…     4    45     2 TRUE       0.0170 0.000803    0.0810
#>  7 Sepal.Length < 5.45 &…     5     7     2 TRUE       0.208  0.0353      0.531 
#>  8 Sepal.Length < 5.45 &…     5     7     2 TRUE       0.675  0.349       0.911 
#>  9 Sepal.Length < 5.45 &…     5     7     2 TRUE       0.208  0.0353      0.531 
#> 10 Sepal.Length >= 5.45       3    98     1 FALSE      0.0580 0.0231      0.115 
#> # … with 14 more rows
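
For reference, the two default interval types mentioned above can be sketched in base R for a single node. This is only an illustration under stated assumptions: the normal interval below uses the usual standard error of the mean, which may differ from the exact formula used inside the package, so the numbers need not match the tables above. The variable names are hypothetical.

```r
level <- 0.89

# Continuous outcome: Sepal.Width in the node "Species = setosa"
values <- iris$Sepal.Width[iris$Species == "setosa"]
se <- sd(values) / sqrt(length(values))   # standard error of the mean
z  <- qnorm((1 + level) / 2)              # two-sided normal quantile
normal_ci <- c(
  estimate  = mean(values),
  conf.low  = mean(values) - z * se,
  conf.high = mean(values) + z * se
)

# Discrete outcome: share of setosa in the node "Sepal.Length < 5.45"
in_node <- iris$Sepal.Length < 5.45
hits <- sum(iris$Species[in_node] == "setosa")
bt <- binom.test(hits, sum(in_node), conf.level = level)
binom_ci <- c(
  estimate  = unname(bt$estimate),
  conf.low  = bt$conf.int[1],
  conf.high = bt$conf.int[2]
)

normal_ci
binom_ci
```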

tidytrees's Issues

Add general accessor methods to common tree features

It would be useful to streamline access to common tree features, such as nodes, node/tree depth, rules, and node observations, through a common interface that is independent of the tree class.

These functions would be added to each class file.
