pyGAM

Generalized Additive Models in Python.

Tutorial

pyGAM: Getting started with Generalized Additive Models in Python

Installation

pip install pygam

scikit-sparse

To speed up optimization on large models with constraints, it helps to have scikit-sparse installed because it contains a slightly faster, sparse version of Cholesky factorization. The import from scikit-sparse references nose, so you'll need that too.

The easiest way is to use Conda:
conda install -c conda-forge scikit-sparse


About

Generalized Additive Models (GAMs) are smooth semi-parametric models of the form:

$$g(\mathbb{E}[y \mid X]) = \beta_0 + f_1(X_1) + f_2(X_2) + \cdots + f_p(X_p)$$

where X.T = [X_1, X_2, ..., X_p] are independent variables, y is the dependent variable, and g() is the link function that relates our predictor variables to the expected value of the dependent variable.

The feature functions f_i() are built using penalized B-splines, which let us model non-linear relationships automatically instead of manually trying out many different transformations on each variable.
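To make "B-splines" concrete, here is a minimal pure-Python sketch of a cubic B-spline basis built with the Cox-de Boor recursion. The clamped, equally spaced knot layout and the helper names `bspline_basis`, `clamped_knots` and `design_row` are illustrative assumptions, not pyGAM internals. Each f_i is a weighted sum of these basis functions, and the smoothing penalty acts on those weights.

```python
def bspline_basis(x, knots, j, degree):
    """Value of the j-th B-spline of the given degree at x (Cox-de Boor recursion)."""
    if degree == 0:
        # Half-open intervals; the last non-empty interval also covers x == right edge.
        if knots[j] <= x < knots[j + 1]:
            return 1.0
        if x == knots[-1] and knots[j] < knots[j + 1] == knots[-1]:
            return 1.0
        return 0.0
    left = 0.0
    denom = knots[j + degree] - knots[j]
    if denom > 0:
        left = (x - knots[j]) / denom * bspline_basis(x, knots, j, degree - 1)
    right = 0.0
    denom = knots[j + degree + 1] - knots[j + 1]
    if denom > 0:
        right = (knots[j + degree + 1] - x) / denom * bspline_basis(x, knots, j + 1, degree - 1)
    return left + right


def clamped_knots(lo, hi, n_splines, degree):
    """Equally spaced knots on [lo, hi], each end knot repeated degree + 1 times."""
    n_intervals = n_splines - degree
    inner = [lo + (hi - lo) * i / n_intervals for i in range(n_intervals)] + [hi]
    return [lo] * degree + inner + [hi] * degree


def design_row(x, n_splines=10, degree=3, lo=0.0, hi=1.0):
    """One row of the design matrix: every basis function evaluated at x."""
    knots = clamped_knots(lo, hi, n_splines, degree)
    return [bspline_basis(x, knots, j, degree) for j in range(n_splines)]


# A feature function is then a weighted sum of the basis, e.g. with hypothetical weights w:
w = [0.0, 0.5, 1.2, 1.0, 0.4, 0.1, 0.3, 0.9, 1.5, 2.0]
f_at_03 = sum(wj * bj for wj, bj in zip(w, design_row(0.3)))
```

On [lo, hi] the basis values are non-negative and sum to 1, so the weights act like local control points for the curve.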

GAMs extend generalized linear models by allowing non-linear functions of the features while keeping the model additive. Because of this additivity, it is easy to examine the effect of each X_i on y individually while holding all other predictors constant.

The result is a very flexible model, where it is easy to incorporate prior knowledge and control overfitting.
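A small sketch of why additivity makes effects easy to read: on the link scale, moving one feature shifts the prediction by the same amount whatever values the other features take. The intercept and the three feature functions below are made-up stand-ins for fitted terms, not pyGAM output.

```python
import math

# Hypothetical fitted terms standing in for beta_0 and f_1, f_2, f_3.
BETA0 = 0.5
TERMS = [
    lambda x: math.sin(x),    # f_1
    lambda x: 0.1 * x ** 2,   # f_2
    lambda x: -0.3 * x,       # f_3
]


def linear_predictor(xs):
    """eta = beta_0 + f_1(x_1) + ... + f_p(x_p), the right-hand side of the model."""
    return BETA0 + sum(f(x) for f, x in zip(TERMS, xs))


def effect_of_change(xs, i, new_value):
    """Change in eta when only feature i moves from xs[i] to new_value."""
    moved = list(xs)
    moved[i] = new_value
    return linear_predictor(moved) - linear_predictor(xs)


# Moving x_1 from 0 to 1 adds sin(1) - sin(0) to eta, regardless of x_2 and x_3.
a = effect_of_change([0.0, 2.0, -1.0], 0, 1.0)
b = effect_of_change([0.0, -5.0, 10.0], 0, 1.0)
```

This holds on the link scale only; after applying a non-identity inverse link, the change in E[y] does depend on the other features.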

Regression

For regression problems, we can use a linear GAM, which models:

$$\mathbb{E}[y \mid X] = \beta_0 + f_1(X_1) + f_2(X_2) + \cdots + f_p(X_p)$$

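# wage dataset (X: year, age, education; y: wage; assumed to be loaded already)
import matplotlib.pyplot as plt
from pygam import LinearGAM
from pygam.utils import generate_X_grid

# gridsearch tunes the smoothing strength (lam) of the model
gam = LinearGAM(n_splines=10).gridsearch(X, y)
XX = generate_X_grid(gam)

fig, axs = plt.subplots(1, 3)
titles = ['year', 'age', 'education']

for i, ax in enumerate(axs):
    # feature=0 is the intercept, so column i of XX is feature i+1
    pdep, confi = gam.partial_dependence(XX, feature=i+1, width=.95)

    ax.plot(XX[:, i], pdep)
    # confi holds one (n, 2) array of lower/upper bounds per requested feature
    ax.plot(XX[:, i], confi[0], c='r', ls='--')
    ax.set_title(titles[i])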

Even though we allowed n_splines=10 per numerical feature, our smoothing penalty reduces the model to about 13.5 effective degrees of freedom:

gam.summary()

LinearGAM                                                                                                 
=============================================== ==========================================================
Distribution:                        NormalDist Effective DoF:                                      13.532
Link Function:                     IdentityLink Log Likelihood:                                -24119.2334
Number of Samples:                         3000 AIC:                                            48267.5307
                                                AICc:                                            48267.682
                                                GCV:                                             1247.0706
                                                Scale:                                           1236.9495
                                                Pseudo R-Squared:                                   0.2926
==========================================================================================================
Feature Function   Data Type      Num Splines   Spline Order  Linear Fit  Lambda     P > x      Sig. Code
================== ============== ============= ============= =========== ========== ========== ==========
feature 1          numerical      10            3             False       15.8489    1.63e-03   **        
feature 2          numerical      10            3             False       15.8489    1.50e-11   ***       
feature 3          categorical    5             0             False       15.8489    1.25e-14   ***       
intercept                                                                            1.11e-16   ***       
==========================================================================================================
Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
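Effective degrees of freedom are commonly defined as the trace of the influence (hat) matrix. For penalized least squares with design matrix X, penalty matrix P and smoothing weight λ, that is tr((XᵀX + λP)⁻¹XᵀX). The pure-Python sketch below assumes a small piecewise-linear ("tent") basis and a second-difference penalty, both chosen only for illustration, and shows the count shrinking from the number of basis functions toward 2 (the straight lines the penalty leaves alone) as λ grows. It is not pyGAM's code, which also handles weights, several terms and an intercept.

```python
def tent_design(xs, n_basis):
    """Degree-1 B-spline ('tent') basis on [0, 1] with equally spaced knots."""
    h = 1.0 / (n_basis - 1)
    return [[max(0.0, 1.0 - abs(x - j * h) / h) for j in range(n_basis)] for x in xs]


def second_diff_penalty(p):
    """P = D^T D, where D takes second differences of the coefficient vector."""
    D = [[0.0] * p for _ in range(p - 2)]
    for r in range(p - 2):
        D[r][r], D[r][r + 1], D[r][r + 2] = 1.0, -2.0, 1.0
    return [[sum(D[k][i] * D[k][j] for k in range(p - 2)) for j in range(p)] for i in range(p)]


def gram(X):
    """X^T X."""
    p = len(X[0])
    return [[sum(row[i] * row[j] for row in X) for j in range(p)] for i in range(p)]


def solve(M, B):
    """Return M^{-1} B by Gauss-Jordan elimination with partial pivoting."""
    n = len(M)
    A = [list(M[i]) + list(B[i]) for i in range(n)]  # augmented matrix [M | B]
    for c in range(n):
        piv = max(range(c, n), key=lambda r: abs(A[r][c]))
        A[c], A[piv] = A[piv], A[c]
        pv = A[c][c]
        A[c] = [v / pv for v in A[c]]
        for r in range(n):
            if r != c and A[r][c] != 0.0:
                factor = A[r][c]
                A[r] = [vr - factor * vc for vr, vc in zip(A[r], A[c])]
    return [row[n:] for row in A]


def effective_dof(X, lam):
    """tr((X^T X + lam * P)^{-1} X^T X) for a second-difference penalty P."""
    G = gram(X)
    p = len(G)
    P = second_diff_penalty(p)
    M = [[G[i][j] + lam * P[i][j] for j in range(p)] for i in range(p)]
    Z = solve(M, G)
    return sum(Z[i][i] for i in range(p))


xs = [i / 19 for i in range(20)]
X = tent_design(xs, n_basis=6)
for lam in (0.0, 1.0, 100.0, 1e6):
    print(lam, round(effective_dof(X, lam), 3))
```

With λ = 0 the trace equals the number of basis functions; as λ grows, only the directions the penalty cannot see (here, straight lines) keep a full degree of freedom each.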

With LinearGAMs, we can also check the prediction intervals:

# mcycle dataset (X, y assumed to be loaded already; X has a single feature)
import matplotlib.pyplot as plt
from pygam import LinearGAM
from pygam.utils import generate_X_grid

gam = LinearGAM().gridsearch(X, y)
XX = generate_X_grid(gam)

# fitted mean (red) and 95% prediction interval (blue) over a grid of X values
plt.plot(XX, gam.predict(XX), 'r--')
plt.plot(XX, gam.prediction_intervals(XX, width=.95), color='b', ls='--')

plt.scatter(X, y, facecolor='gray', edgecolors='none')
plt.title('95% prediction interval')
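For a Gaussian response, a prediction interval must cover both the uncertainty in the fitted mean and the noise around it. A common normal-approximation sketch combines the two variances; the function name and the numbers below are hypothetical, and this is not necessarily how pyGAM computes `prediction_intervals`.

```python
from statistics import NormalDist


def gaussian_prediction_interval(mean, mean_var, scale, width=0.95):
    """Approximate interval for a new observation: mean +/- z * sqrt(Var(mean) + scale).

    mean     : fitted value at a point (identity link)
    mean_var : variance of the fitted mean at that point
    scale    : estimated noise variance (the 'Scale' row of the summary)
    """
    z = NormalDist().inv_cdf(0.5 + width / 2)  # two-sided quantile, about 1.96 for 95%
    half = z * (mean_var + scale) ** 0.5
    return mean - half, mean + half


# Hypothetical values: fitted mean 10.0, Var(mean) = 0.25, noise variance 4.0.
lo, hi = gaussian_prediction_interval(10.0, 0.25, 4.0)
```

Dropping `scale` from the square root would give a confidence interval for the mean instead, which is much narrower.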

We can also simulate from the posterior:

# continuing last example with the mcycle dataset
# quantity='y' draws simulated responses (coefficient uncertainty plus noise) at the points in XX
for response in gam.sample(X, y, quantity='y', n_draws=50, sample_at_X=XX):
    plt.scatter(XX, response, alpha=.03, color='k')
plt.plot(XX, gam.predict(XX), 'r--')
plt.plot(XX, gam.prediction_intervals(XX, width=.95), color='b', ls='--')
plt.title('draw samples from the posterior of the coefficients')
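Posterior draws of the coefficients are commonly based on a Gaussian approximation, β ~ N(β̂, Cov(β̂)), sampled through a Cholesky factor of the covariance. The stdlib-only sketch below assumes that approximation; `cholesky` and `draw_coefficients` are illustrative helpers, and the mean and covariance are made up rather than taken from a fitted model.

```python
import random


def cholesky(C):
    """Lower-triangular L with L L^T = C, for symmetric positive definite C."""
    n = len(C)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = (C[i][i] - s) ** 0.5
            else:
                L[i][j] = (C[i][j] - s) / L[j][j]
    return L


def draw_coefficients(mean, cov, n_draws, rng):
    """n_draws samples from N(mean, cov) as beta = mean + L z, with z ~ N(0, I)."""
    L = cholesky(cov)
    n = len(mean)
    draws = []
    for _ in range(n_draws):
        z = [rng.gauss(0.0, 1.0) for _ in range(n)]
        draws.append([mean[i] + sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(n)])
    return draws


rng = random.Random(0)
beta_hat = [1.0, -2.0]               # hypothetical coefficient estimates
cov_hat = [[1.0, 0.5], [0.5, 2.0]]   # hypothetical coefficient covariance
samples = draw_coefficients(beta_hat, cov_hat, 20000, rng)
```

Each draw defines one curve through the basis; adding simulated noise on top of each curve gives response draws like the scatter in the plot above.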

Classification

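For binary classification problems, we can use a logistic GAM, which models:

$$\log\left(\frac{P(y=1 \mid X)}{P(y=0 \mid X)}\right) = \beta_0 + f_1(X_1) + f_2(X_2) + \cdots + f_p(X_p)$$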

# credit default dataset (X: student, balance, income; y: default; assumed to be loaded already)
import matplotlib.pyplot as plt
from pygam import LogisticGAM
from pygam.utils import generate_X_grid

gam = LogisticGAM().gridsearch(X, y)
XX = generate_X_grid(gam)

fig, axs = plt.subplots(1, 3)
titles = ['student', 'balance', 'income']

for i, ax in enumerate(axs):
    # feature=0 is the intercept, so column i of XX is feature i+1
    pdep, confi = gam.partial_dependence(XX, feature=i+1, width=.95)

    ax.plot(XX[:, i], pdep)
    ax.plot(XX[:, i], confi[0], c='r', ls='--')
    ax.set_title(titles[i])

We can then check the accuracy:

gam.accuracy(X, y)

0.97389999999999999
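Accuracy for a logistic model is the fraction of samples whose thresholded probability matches the label. The sketch below assumes a 0.5 threshold and maps linear predictors through the inverse logit; the predictor values and labels are made up.

```python
import math


def inv_logit(eta):
    """Inverse of the logit link: maps the additive predictor to P(y = 1 | X)."""
    return 1.0 / (1.0 + math.exp(-eta))


def accuracy(etas, labels, threshold=0.5):
    """Share of samples where (P(y = 1) > threshold) agrees with the 0/1 label."""
    hits = sum((inv_logit(e) > threshold) == bool(y) for e, y in zip(etas, labels))
    return hits / len(labels)


# Hypothetical linear predictors and labels: only the last sample is misclassified.
etas = [-3.0, -0.2, 0.4, 2.5, 1.1]
labels = [0, 0, 1, 1, 0]
```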

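Since the scale of the Binomial distribution is known, our gridsearch minimizes the Unbiased Risk Estimator (UBRE) objective rather than the Generalized Cross-Validation (GCV) score it uses for distributions with unknown scale, such as the Normal in the LinearGAM example; the summary reports UBRE accordingly: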

gam.summary()

LogisticGAM                                                                                               
=============================================== ==========================================================
Distribution:                      BinomialDist Effective DoF:                                      4.3643
Link Function:                        LogitLink Log Likelihood:                                  -788.7121
Number of Samples:                        10000 AIC:                                             1586.1527
                                                AICc:                                            1586.1595
                                                UBRE:                                                2.159
                                                Scale:                                                 1.0
                                                Pseudo R-Squared:                                   0.4599
==========================================================================================================
Feature Function   Data Type      Num Splines   Spline Order  Linear Fit  Lambda     P > x      Sig. Code
================== ============== ============= ============= =========== ========== ========== ==========
feature 1          categorical    2             0             False       1000.0     4.41e-03   **        
feature 2          numerical      25            3             False       1000.0     0.00e+00   ***       
feature 3          numerical      25            3             False       1000.0     2.35e-02   *         
intercept                                                                            0.00e+00   ***       
==========================================================================================================
Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
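GCV and UBRE have closed forms in terms of the deviance D, the sample size n, the effective degrees of freedom τ, the scale σ² and an inflation factor γ (Wood, 2006). The sketch below follows those textbook forms. The default γ = 1.4 is an assumption: plugging the LinearGAM (wage) summary into `gcv` with it, and recovering D from Scale × (n − τ), reproduces the printed GCV of 1247.0706. The `ubre` function is shown for comparison only and has not been checked against the UBRE value printed above.

```python
def gcv(deviance, n, edof, gamma=1.4):
    """Generalized cross-validation score, used when the scale is unknown (e.g. Normal)."""
    return n * deviance / (n - gamma * edof) ** 2


def ubre(deviance, n, edof, scale, gamma=1.4):
    """Unbiased Risk Estimator, used when the scale is known (e.g. Binomial, Poisson)."""
    return deviance / n - scale + 2.0 * gamma * edof * scale / n


# Values from the LinearGAM (wage) summary.
n, edof, scale = 3000, 13.532, 1236.9495
deviance = scale * (n - edof)  # for a Normal model, Scale = D / (n - edof)
print(round(gcv(deviance, n, edof), 2))  # prints 1247.07
```

Both criteria trade fit (through D) against complexity (through τ), which is why gridsearch can use them to pick λ without a held-out set.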

Poisson and Histogram Smoothing

We can perform histogram smoothing intuitively by modeling the count in each bin as Poisson-distributed, using PoissonGAM.

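# old faithful dataset (X: histogram bin locations, y: count in each bin; assumed to be loaded already)
import matplotlib.pyplot as plt
from pygam import PoissonGAM

gam = PoissonGAM().gridsearch(X, y)

plt.plot(X, gam.predict(X), color='r')
# lam is the smoothing strength selected by gridsearch
plt.title('Lam: {0:.2f}'.format(gam.lam))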
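Before fitting, the raw observations must be turned into (bin location, count) pairs, which then serve as X and y for the Poisson model. A minimal stdlib sketch, where `histogram_counts` is a hypothetical helper, equal-width binning is an assumption, and the waiting times are toy values rather than the real Old Faithful data:

```python
def histogram_counts(data, n_bins):
    """Equal-width histogram: returns (bin_centers, counts) for use as X and y."""
    lo, hi = min(data), max(data)
    width = (hi - lo) / n_bins
    counts = [0] * n_bins
    for v in data:
        idx = int((v - lo) / width)
        counts[min(idx, n_bins - 1)] += 1  # the maximum value goes into the last bin
    centers = [lo + (i + 0.5) * width for i in range(n_bins)]
    return centers, counts


# Hypothetical waiting times in minutes.
waits = [54, 55, 57, 60, 62, 74, 76, 78, 79, 80, 81, 82, 83, 85, 88]
centers, counts = histogram_counts(waits, n_bins=5)
```

With a log link the fitted rate stays positive, which suits counts.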

Custom Models

It's also easy to build custom models by using the base GAM class and specifying the distribution and the link function.

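# cherry tree dataset (X: girth and height, y: volume; assumed to be loaded already)
import matplotlib.pyplot as plt
from pygam import GAM

# Gamma-distributed response with a log link
gam = GAM(distribution='gamma', link='log', n_splines=4)
gam.gridsearch(X, y)

plt.scatter(y, gam.predict(X))
plt.xlabel('true volume')
plt.ylabel('predicted volume')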

We can check the quality of the fit by looking at the Pseudo R-Squared:

gam.summary()

GAM                                                                                                       
=============================================== ==========================================================
Distribution:                         GammaDist Effective DoF:                                      4.1544
Link Function:                          LogLink Log Likelihood:                                   -66.9372
Number of Samples:                           31 AIC:                                              144.1834
                                                AICc:                                             146.7369
                                                GCV:                                                0.0095
                                                Scale:                                              0.0073
                                                Pseudo R-Squared:                                   0.9767
==========================================================================================================
Feature Function   Data Type      Num Splines   Spline Order  Linear Fit  Lambda     P > x      Sig. Code
================== ============== ============= ============= =========== ========== ========== ==========
feature 1          numerical      4             3             False       0.0158     3.42e-12   ***       
feature 2          numerical      4             3             False       0.0158     1.29e-09   ***       
intercept                                                                            7.60e-13   ***       
==========================================================================================================
Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
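One common pseudo R-squared for GLM-type models is the explained deviance, 1 − D(model)/D(null), where the null model predicts the mean of y everywhere. The sketch below uses the Gamma deviance; whether pyGAM's printed value uses exactly this definition is an assumption, and the volumes and fitted values are made up.

```python
import math


def gamma_deviance(y, mu):
    """Gamma deviance: 2 * sum(-log(y/mu) + (y - mu)/mu), for positive y and mu."""
    return 2.0 * sum(-math.log(yi / mi) + (yi - mi) / mi for yi, mi in zip(y, mu))


def explained_deviance(y, mu):
    """Pseudo R-squared as 1 - D(model) / D(null), with the null model mu = mean(y)."""
    ybar = sum(y) / len(y)
    null_dev = gamma_deviance(y, [ybar] * len(y))
    return 1.0 - gamma_deviance(y, mu) / null_dev


# Hypothetical volumes and fitted values.
y = [10.3, 15.6, 18.2, 22.6, 33.8]
mu = [10.9, 15.1, 18.8, 22.0, 33.1]
```

A perfect fit scores 1 and the intercept-only model scores 0.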

Penalties / Constraints

With GAMs we can encode prior knowledge and control overfitting by using penalties and constraints.

Available penalties:

  • second derivative smoothing (default on numerical features)
  • L2 smoothing (default on categorical features)
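Both default penalties are quadratic forms in a term's coefficients w. In the P-spline approach of Eilers & Marx, second-derivative smoothing is approximated by summing squared second differences of neighbouring coefficients; L2 smoothing sums squared coefficients. A sketch with hypothetical coefficient vectors, not pyGAM's penalty code:

```python
def second_difference_penalty(w):
    """sum_k (w[k] - 2*w[k+1] + w[k+2])**2: zero when the coefficients lie on a line."""
    return sum((w[k] - 2 * w[k + 1] + w[k + 2]) ** 2 for k in range(len(w) - 2))


def l2_penalty(w):
    """sum_k w[k]**2: shrinks every coefficient (e.g. each category level) toward zero."""
    return sum(v * v for v in w)


linear = [0.0, 1.0, 2.0, 3.0, 4.0]   # straight line: nothing to smooth away
wiggly = [0.0, 2.0, -1.0, 3.0, 0.0]  # sharp bends: heavily penalized
```

The fit minimizes misfit + λ × penalty, so a larger λ pushes a numerical term toward a straight line and a categorical term toward zero.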

Available constraints:

  • monotonic increasing/decreasing smoothing
  • convex/concave smoothing
  • periodic smoothing [soon...]
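Shape constraints on a spline term can be expressed through its coefficients: with a B-spline basis, non-decreasing coefficients give a non-decreasing function, and (with equally spaced knots) non-positive second differences give a concave one. The sketch below only measures violations, squared, which an optimizer could add as a heavily weighted penalty; the function names are hypothetical and this is an illustration of the idea, not pyGAM's implementation.

```python
def monotonic_inc_violation(w):
    """Sum of squared decreases between neighbouring coefficients (0 if non-decreasing)."""
    return sum(min(0.0, w[k + 1] - w[k]) ** 2 for k in range(len(w) - 1))


def concave_violation(w):
    """Sum of squared positive second differences (0 if the coefficients are concave)."""
    return sum(max(0.0, w[k] - 2 * w[k + 1] + w[k + 2]) ** 2 for k in range(len(w) - 2))
```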

We can inject our intuition into our model by using monotonic and concave constraints:

# hepatitis dataset (X, y assumed to be loaded already; X has a single feature)
import matplotlib.pyplot as plt
from pygam import LinearGAM

# same data, two different shape constraints on the feature function
gam1 = LinearGAM(constraints='monotonic_inc').fit(X, y)
gam2 = LinearGAM(constraints='concave').fit(X, y)

fig, ax = plt.subplots(1, 2)
ax[0].plot(X, y, label='data')
ax[0].plot(X, gam1.predict(X), label='monotonic fit')
ax[0].legend()

ax[1].plot(X, y, label='data')
ax[1].plot(X, gam2.predict(X), label='concave fit')
ax[1].legend()

API

pyGAM is intuitive, modular, and follows the familiar scikit-learn-style fit/predict API:

from pygam import LogisticGAM

gam = LogisticGAM()
gam.fit(X, y)

Since GAMs are additive, it is also super easy to visualize each individual feature function, f_i(X_i). These feature functions describe the effect of each X_i on y individually while marginalizing out all other predictors:

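import matplotlib.pyplot as plt

# one column of partial dependence per feature, evaluated at the rows of X
pdeps = gam.partial_dependence(X)
plt.plot(pdeps)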

Current Features

Models

pyGAM comes with many models out-of-the-box:

  • GAM (base class for constructing custom models)
  • LinearGAM
  • LogisticGAM
  • GammaGAM
  • PoissonGAM
  • InvGaussGAM

You can mix and match distributions with link functions to create custom models!

from pygam import GAM

gam = GAM(distribution='gamma', link='inverse')

Distributions

  • Normal
  • Binomial
  • Gamma
  • Poisson
  • Inverse Gaussian

Link Functions

Link functions map the distribution mean to the linear predictor. These are the canonical link functions for the distributions above, listed in the same order:

  • Identity
  • Logit
  • Inverse
  • Log
  • Inverse-squared
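Each link g maps the mean μ to the linear predictor η, and its inverse maps η back to μ. Below is a minimal table of the five canonical links as (g, g⁻¹) pairs. The dictionary keys and structure are illustrative assumptions, not pyGAM's Link classes; pairing `'gamma'` with `'log'` instead of the canonical `'inverse'` is the kind of mix-and-match the custom-model example does.

```python
import math

# (link g: mu -> eta, inverse link g^-1: eta -> mu) for the canonical links above.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "logit": (lambda mu: math.log(mu / (1.0 - mu)), lambda eta: 1.0 / (1.0 + math.exp(-eta))),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
    "log": (lambda mu: math.log(mu), lambda eta: math.exp(eta)),
    "inverse_squared": (lambda mu: 1.0 / mu ** 2, lambda eta: 1.0 / math.sqrt(eta)),
}

# distribution -> canonical link, in the order of the two lists above
CANONICAL = {
    "normal": "identity",
    "binomial": "logit",
    "gamma": "inverse",
    "poisson": "log",
    "inv_gauss": "inverse_squared",
}

# e.g. the mean of a gamma/log model whose linear predictor is 1.2
mu = LINKS["log"][1](1.2)
```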

Callbacks

Callbacks are performed during each optimization iteration. It's also easy to write your own.

  • deviance - model deviance
  • diffs - norm of the change in the coefficients between iterations
  • accuracy - model accuracy for LogisticGAM
  • coef - coefficient logging

You can check a callback by inspecting:

plt.plot(gam.logs_['deviance'])
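Conceptually, a callback is a function the optimizer calls once per iteration, with each return value appended to a log like `gam.logs_`. The sketch below wires two such callbacks into a toy Newton (IRLS) loop for an intercept-only Poisson model with a log link. The loop, the callback signature and the `logs` dictionary are hypothetical stand-ins that mirror the `deviance` and `diffs` entries listed above, not pyGAM's actual callback interface.

```python
import math


def poisson_deviance(y, mu):
    """2 * sum(y*log(y/mu) - (y - mu)), taking y*log(y/mu) as 0 when y == 0."""
    return 2.0 * sum((yi * math.log(yi / mu) if yi > 0 else 0.0) - (yi - mu) for yi in y)


def deviance_cb(state):
    """Model deviance at the current coefficient."""
    return poisson_deviance(state["y"], math.exp(state["coef"]))


def diffs_cb(state):
    """Size of the coefficient change in this iteration."""
    return abs(state["coef"] - state["prev_coef"])


def fit_intercept(y, callbacks, max_iter=50, tol=1e-10):
    """Newton updates for a Poisson intercept b with mu = exp(b), logging every callback."""
    logs = {name: [] for name in callbacks}
    ybar = sum(y) / len(y)
    coef = 0.0
    for _ in range(max_iter):
        mu = math.exp(coef)
        new_coef = coef + (ybar - mu) / mu  # Newton step on the Poisson log-likelihood
        state = {"y": y, "coef": new_coef, "prev_coef": coef}
        for name, cb in callbacks.items():
            logs[name].append(cb(state))
        converged = abs(new_coef - coef) < tol
        coef = new_coef
        if converged:
            break
    return coef, logs


y = [2, 3, 4, 1, 5]
coef, logs = fit_intercept(y, {"deviance": deviance_cb, "diffs": diffs_cb})
```

Plotting `logs["deviance"]` then shows the same kind of convergence trace as the `gam.logs_` plot above.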

Linear Extrapolation

Citing pyGAM

Please consider citing pyGAM if it has helped you in your research or work:

Daniel Servén, & Charlie Brummitt. (2018, March 27). pyGAM: Generalized Additive Models in Python. Zenodo. DOI: 10.5281/zenodo.1208723

BibTex:

@misc{daniel_serven_2018_1208723,
  author       = {Daniel Servén and
                  Charlie Brummitt},
  title        = {pyGAM: Generalized Additive Models in Python},
  month        = mar,
  year         = 2018,
  doi          = {10.5281/zenodo.1208723},
  url          = {https://doi.org/10.5281/zenodo.1208723}
}

References

  1. Simon N. Wood, 2006
    Generalized Additive Models: an introduction with R

  2. Hastie, Tibshirani, Friedman
    The Elements of Statistical Learning
    http://statweb.stanford.edu/~tibs/ElemStatLearn/printings/ESLII_print10.pdf

  3. James, Witten, Hastie and Tibshirani
    An Introduction to Statistical Learning
    http://www-bcf.usc.edu/~gareth/ISL/ISLR%20Sixth%20Printing.pdf

  4. Paul Eilers & Brian Marx, 1996
    Flexible Smoothing with B-splines and Penalties
    http://www.stat.washington.edu/courses/stat527/s13/readings/EilersMarx_StatSci_1996.pdf

  5. Kim Larsen, 2015
    GAM: The Predictive Modeling Silver Bullet
    http://multithreaded.stitchfix.com/assets/files/gam.pdf

  6. Deva Ramanan, 2008
    UCI Machine Learning: Notes on IRLS
    http://www.ics.uci.edu/~dramanan/teaching/ics273a_winter08/homework/irls_notes.pdf

  7. Paul Eilers & Brian Marx, 2015
    International Biometric Society: A Crash Course on P-splines
    http://www.ibschannel2015.nl/project/userfiles/Crash_course_handout.pdf

  8. Keiding, Niels, 1991
    Age-specific incidence and prevalence: a statistical perspective

