
reeses

reeses is a scikit-learn plugin for piecewise models with learned partitions.

Learned partitions (or groups) can be assigned through tree-based or clustering models; the only requirement is that the estimator can be fit on the data and exposes a method that assigns the id of the learned partition to new data. Prediction estimators can be optimized within each node / leaf through grid search if necessary, as sketched below.
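
For example, since a prediction estimator only needs the standard fit / predict interface, the per-leaf model can be wrapped in a GridSearchCV so that each leaf tunes its own hyperparameters. This is a minimal sketch under that assumption, not an example taken from the reeses docs:

# a sketch assuming any fit/predict estimator works as prediction_estimator;
# here each leaf tunes its own Ridge alpha via grid search
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

from reeses import PiecewiseRegressor

X, y = make_regression(n_samples=1000, n_features=5, noise=0.1, random_state=0)

leaf_model = GridSearchCV(Ridge(), {'alpha': [0.1, 1.0, 10.0]})
tree = DecisionTreeRegressor(min_samples_leaf=100)

model = PiecewiseRegressor(assignment_estimator=tree, prediction_estimator=leaf_model)
model.fit(X, y)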

reeses tries to follow the utility patterns in scikit-learn, e.g.

  1. sample weight transformation for bootstrapping
  2. joblib for parallelization

quick start

pip install reeses

regression

from sklearn.datasets import load_boston
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression

from reeses import PiecewiseRegressor

data = load_boston()

X = data['data']
y = data['target']

tree = DecisionTreeRegressor(min_samples_leaf=40)
ols = LinearRegression()

model = PiecewiseRegressor(assignment_estimator=tree, prediction_estimator=ols)
model.fit(X, y)
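
The fitted model can then be used like any scikit-learn estimator: each row is routed to a leaf by the tree and scored by that leaf's linear model.

y_pred = model.predict(X)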

classification

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression

from reeses import PiecewiseClassifier

data = load_iris()

X = data['data']
y = data['target']

tree = DecisionTreeClassifier(min_samples_leaf=40)
logit = LogisticRegression()

model = PiecewiseClassifier(assignment_estimator=tree, prediction_estimator=logit)
model.fit(X, y)

ensembling & bootstrapping

reeses will introspect ensemble assignment estimators and, for each member of the ensemble, fit the prediction estimators associated with that member on the same bootstrapped sample the member was fit on.

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

from reeses import PiecewiseClassifier

data = load_iris()

X = data['data']
y = data['target']

forest = RandomForestClassifier(n_estimators=10, min_samples_leaf=40)
logit = LogisticRegression()

model = PiecewiseClassifier(assignment_estimator=forest, prediction_estimator=logit)
model.fit(X, y)

clustering assignment

reeses supports arbitrary assignment estimators. scikit-learn clustering estimators use the predict method to make assignments. Before 0.0.3, reeses defaults to the apply method but can be configured to use any assignment method through the assignment_method attribute. From 0.0.3 on, reeses defaults to apply if the assignment estimator has an apply method, falls back to predict if it has a predict method, and raises otherwise.
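
For example, KMeans can partition the feature space and a linear model can be fit within each cluster. This is a sketch rather than an example from the reeses docs; on >=0.0.3 the predict method should be selected automatically because KMeans has no apply method:

# a sketch: KMeans assigns partitions via predict (it has no apply method)
from sklearn.datasets import load_boston
from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression

from reeses import PiecewiseRegressor

data = load_boston()

X = data['data']
y = data['target']

kmeans = KMeans(n_clusters=5)
ols = LinearRegression()

model = PiecewiseRegressor(assignment_estimator=kmeans, prediction_estimator=ols)
model.fit(X, y)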

Example

Consider y = sign(x) * (1 + abs(x)): a function that is linear on each side of the origin with a shock (jump) at the origin. This function is a poor candidate for linear methods (the function is nonlinear overall) and for tree-based methods (which cannot extrapolate outside the observed bounds). Combining the two produces an effective estimator that can extrapolate.

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm

from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import r2_score, mean_absolute_error

from reeses import pieces

create data

observations = 10000
x_train = np.random.uniform(-1, 1, observations)
x_test = np.random.uniform(-2, 2, observations)
y_train = np.sign(x_train) * (1 + np.abs(x_train)) + norm.rvs(size=observations, loc=0, scale=.05)
y_test = np.sign(x_test) * (1 + np.abs(x_test)) + norm.rvs(size=observations, loc=0, scale=.05)

# reshape arrays to 2D
x_train = x_train[:, None]
x_test = x_test[:, None]

training data

[figure: training data]

create models

models = {
    'ols': LinearRegression(),
    'tree': DecisionTreeRegressor(max_depth=10),
    'piecewise_tree': pieces.PiecewiseRegressor(
        assignment_estimator=DecisionTreeRegressor(max_depth=1, criterion='mse'),
        prediction_estimator=LinearRegression(),
    ),
}

grid_params = {
    'tree': {
        'max_depth': range(1, 20)
    },
    'piecewise_tree': {
        'assignment_estimator__max_depth': range(1, 5)
    }
}

fit models and make predictions on test data

predictions = {}
for name, estimator in models.items():
    if name in grid_params:
        search = GridSearchCV(estimator, grid_params[name], n_jobs=-1)
        predictions[name] = search.fit(x_train, y_train).best_estimator_.predict(x_test)
    else:
        predictions[name] = estimator.fit(x_train, y_train).predict(x_test)
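
The r2_score and mean_absolute_error imports can be used to compare the models on the held-out range; a quick sketch (not part of the original write-up):

for name, prediction in predictions.items():
    print(name,
          'R2:', round(r2_score(y_test, prediction), 3),
          'MAE:', round(mean_absolute_error(y_test, prediction), 3))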

Test Data

[figure: test data]

Results!

OLS

[figure: OLS predictions]

Decision Tree

[figure: decision tree predictions]

Piecewise Estimator

[figure: piecewise decision tree predictions]
