woodpecker

Purpose

A Python library for model structure interpretation.
Right now the library contains logic for DecisionTreeClassifier, DecisionTreeRegressor and RandomForestClassifier from scikit-learn. Future versions of the library will cover other types of algorithms, like RandomForestRegressor and XGBoost.

To become a better machine learning engineer, it is important to understand the model structure in depth and to have an intuition of what happens if we change the model inputs and how that is reflected in model performance. By model inputs we mean adding more data, adding new features and changing model hyperparameters.

This library was developed with two main ideas in mind:

  • to help us better understand the model structure and its results, and based on this to properly choose other hyperparameter values or another set of features for the next iteration
  • to justify/explain the predictions of ML models to both technical and non-technical people

How to install?

pip install git+https://github.com/tlapusan/woodpecker.git

Usage example

Training example

The well-known Titanic dataset was chosen to show the library's capabilities.
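
For context, the train dataframe used in the rest of this example could be obtained with a setup sketch like the one below (the CSV path and the preprocessing that produces the *_label columns are assumptions for illustration, not part of the library):

import pandas as pd

# Illustrative only: load a Titanic training set that has already been
# preprocessed so that Sex, Cabin and Embarked are label encoded
train = pd.read_csv("data/titanic_train.csv")  # the path is an assumption
random_state = 42  # fixed seed, reused when training the model below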

features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"

Let's see some descriptive statistics about the training set.

train[features].describe()

Train the model

from sklearn.tree import DecisionTreeClassifier

model = DecisionTreeClassifier(criterion="entropy", random_state=random_state, min_samples_split=20)
model.fit(train[features], train[target])

Start using the library

dts = DecisionTreeStructure(model, train, features, target)

Visualize feature importance

You don't have to write all the code needed to extract the feature importances, map them to feature names and sort them. You just call this simple utility function.

dts.show_features_importance()
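
For comparison, the boilerplate it replaces is roughly the following plain pandas/matplotlib sketch (not the library's actual implementation):

import pandas as pd
import matplotlib.pyplot as plt

# Map importances to feature names, sort them and plot them as horizontal bars
importances = pd.Series(model.feature_importances_, index=features).sort_values()
importances.plot(kind="barh")
plt.show()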

Visualize decision tree structure

Like in the case above, this is a utility function that wraps all the code needed to visualize the decision tree structure using graphviz.

dts.show_decision_tree_structure()
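
As a point of reference, the plain scikit-learn/graphviz equivalent looks roughly like this (the class names passed below are an assumption about the target encoding):

from sklearn.tree import export_graphviz
import graphviz

# Roughly the kind of code wrapped by show_decision_tree_structure()
dot_data = export_graphviz(model, out_file=None, feature_names=features,
                           class_names=["died", "survived"], filled=True, rounded=True)
graphviz.Source(dot_data)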

Leaves impurity distribution

Impurity is a metric which shows how confident your leaf predictions are.
In the case of entropy, impurity ranges between 0 and 1: 0 means that the leaf node is very confident about its predictions, 1 means the opposite.

The tree's performance is directly influenced by each leaf's performance, so it's very important to have a general overview of what the leaves' impurity looks like.

dts.show_leaf_impurity_distribution(bins=40, figsize=(20, 7))
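
For reference, the leaf impurities can also be read directly from the fitted scikit-learn tree, which is roughly the data this plot is built from (a sketch, not the library's internals):

import matplotlib.pyplot as plt

# In sklearn's tree_ structure, leaves are the nodes without children
tree = model.tree_
leaf_mask = tree.children_left == -1
plt.figure(figsize=(20, 7))
plt.hist(tree.impurity[leaf_mask], bins=40)
plt.show()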

Leaves sample distribution

Sample is a metric which shows how many examples from the training set reached that node.
For a leaf it is ideal to have an impurity very close to 0, but it is equally important to have a significant number of samples reaching that leaf. If the set of samples is very small, it could be a sign of overfitting for that leaf.

That's why it is important to look both at leaf impurity (previous plot) and at leaf samples to get a better understanding of the tree's performance.

dts.show_leaf_samples_distribution(bins=40, figsize=(20, 7))
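
The same information is available in scikit-learn's tree_ arrays if you want to double-check the plot (illustrative sketch):

import matplotlib.pyplot as plt

# For every node, n_node_samples holds how many training samples reached it;
# leaves are the nodes without children
leaf_mask = model.tree_.children_left == -1
plt.hist(model.tree_.n_node_samples[leaf_mask], bins=40)
plt.show()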

Individual leaves metrics

There may be cases when we want to investigate the behavior of individual leaves.
We could analyze leaves with very good, medium or very low performance.

import matplotlib.pyplot as plt

plt.figure(figsize=(40,30))
plt.subplot(3,1,1)
dts.show_leaf_impurity()

plt.subplot(3,1,2)
dts.show_leaf_samples()

plt.subplot(3,1,3)
dts.show_leaf_samples_by_class()

Get node samples

This function returns a dataframe with all the training samples reaching a node. After looking at the individual leaf metrics, we can see that there are some interesting leaves. For example, leaf 19 has impurity 0, a lot of samples, and all of its people survived (Survived=1). Getting the samples from such a leaf can help us discover patterns in the data or understand why a leaf has good/bad performance.

dts.get_node_samples(node_id=19)[features + [target]].describe()

We can see that the majority of these people had a high socio-economic status (Pclass = 1), most of them were young to middle aged, they bought expensive tickets (the mean Fare over the training set is 32) and they are all women.
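
For leaf nodes specifically, the same set of samples can be recovered with plain scikit-learn if you want to cross-check (an illustrative sketch, not necessarily how the library implements get_node_samples):

# model.apply() returns, for every row, the id of the leaf it falls into
leaf_ids = model.apply(train[features])
leaf_19_samples = train[leaf_ids == 19]
leaf_19_samples[features + [target]].describe()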

Visualize decision tree path prediction

There will be moments when we need to justify why our model predicted a specific value. Looking at the whole tree and tracking the prediction path is not time effective if the depth of the tree is large.

Let's look at the prediction path for the following sample:

Pclass 3.0
Age 28.0
Fare 15.5
Sex_label 0.0
Cabin_label -1.0
Embarked_label 1.0
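
As a plain scikit-learn cross-check (not the library's visualization), the ids of the nodes visited by this sample can be obtained with decision_path; wrapping the values into a one-row dataframe is just for illustration:

import pandas as pd

# One-row dataframe built from the sample values listed above
# (sample_df is local to this sketch)
sample_df = pd.DataFrame([{
    "Pclass": 3.0, "Age": 28.0, "Fare": 15.5,
    "Sex_label": 0.0, "Cabin_label": -1.0, "Embarked_label": 1.0,
}])
node_indicator = model.decision_path(sample_df[features])
print(node_indicator.indices)  # node ids on the root-to-leaf prediction path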

Visualize decision tree splits path prediction

This visualization shows the training data splits on which the model was built. It can also be used as a way to learn how the decision tree was built.

The sample is the same as above.

dts.show_decision_tree_splits_prediction(sample, bins=20)
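
For reference, a rough plain scikit-learn/matplotlib version of the same idea, a histogram of the training values of the split feature at every node on the path with the split threshold marked, could look like the sketch below; it reuses sample_df from the previous sketch and is only an approximation of what the library draws:

import matplotlib.pyplot as plt

tree = model.tree_
train_paths = model.decision_path(train[features])         # which nodes each training row visits
path_node_ids = model.decision_path(sample_df[features]).indices

for node_id in path_node_ids:
    feature_idx = tree.feature[node_id]
    if feature_idx < 0:                                     # leaf node, no split to show
        continue
    reached = train_paths[:, node_id].toarray().ravel().astype(bool)
    plt.hist(train.loc[reached, features[feature_idx]].dropna(), bins=20)
    plt.axvline(tree.threshold[node_id], color="red")       # split threshold at this node
    plt.title("node %d splits on %s" % (node_id, features[feature_idx]))
    plt.show()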

For visualizations of other algorithms, you can take a look inside the notebooks folder.

Release History

  • 0.1
    • model structure investigation for DecisionTreeClassifier
  • 0.2
    • add visualisation for correct/wrong leaves predictions
    • add setup.py file

Meta

Tudor Lapusan
twitter : @tlapusan
email : [email protected]

Library dependencies

  • jupyter
  • matplotlib
  • scikit-learn
  • pandas

License

This project is licensed under the terms of the MIT license, see LICENSE.
