sachinruk / torchlife

Survival analysis using Deep Learning.

Home Page: https://sachinruk.github.io/torchlife

License: Apache License 2.0

Jupyter Notebook 91.45% Python 8.44% Makefile 0.11%
survival-analysis pytorch fastai nbdev

TorchLife

Survival Analysis using pytorch

This library takes a deep learning approach to Survival Analysis.

What is Survival Analysis

Many classification problems are really survival analysis problems that have not been treated as such. For example, suppose you take X-ray data from a cancer patient. Over time, patients will eventually die from cancer (let's ignore the case where people die from other diseases). The usual approach is to frame this as classification: given the X-ray (x), will the patient die within the next 30 days or not (y)?

Survival analysis instead asks: given the input (x) and a time (t), what is the probability that the patient survives longer than t? Consider a patient in the training dataset who is still alive. In the classification case this would be treated as y = 0. In survival analysis we call it a censored observation, since the patient will die at some point in the future, after the experiment has ended.

The same framing applies to other scenarios, such as churn prediction.
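To make the difference concrete, here is a minimal sketch of how a 30-day classification target treats censored patients. The records, the `HORIZON` constant and both helper functions are hypothetical and only illustrate the paragraph above; they are not part of torchlife.

```python
# Hypothetical follow-up records: (days observed, event flag).
# e = 1 means death was observed; e = 0 means the patient was still alive
# at last contact (a censored observation).
records = [(12, 1), (45, 1), (20, 0), (60, 0)]
HORIZON = 30  # days

def classification_label(t, e, horizon=HORIZON):
    """Naive binary target: 1 if death was observed within `horizon` days."""
    return int(e == 1 and t <= horizon)

def label_is_known(t, e, horizon=HORIZON):
    """False when the patient was censored before the horizon was reached."""
    return e == 1 or t >= horizon

labels = [classification_label(t, e) for t, e in records]
known = [label_is_known(t, e) for t, e in records]

for (t, e), y, k in zip(records, labels, known):
    print(f"t={t:>2} e={e} -> y={y} (label known: {k})")
```

The third patient was last seen alive on day 20, so the classifier records y = 0 even though the 30-day outcome is unknown. Keeping the pair (t, e) instead of a single label preserves that information.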

A proper dive into theory can be seen here.

What's with the name?

Well, if you torch a life... you probability wouldn't survive. 😬

How to use this library

There are 3 models in here that can be used.

  1. [Kaplan Meier Model]
  2. [Proportional Hazard Models]
  3. [Accelerated Failure Time Models]

All 3 models require a pandas dataframe as input, with the columns "t", the time elapsed, and "e", a binary variable indicating whether a death (1) or a live (0) instance was observed. They all support fit and plot_survival_function.
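If your raw data stores start and end dates rather than durations, the "t" and "e" columns can be derived with pandas. The sketch below assumes hypothetical column names (`start`, `end`, `age`), a hypothetical extraction date, and time measured in days; adapt these to your own data.

```python
import pandas as pd

# Hypothetical raw records: `end` is missing when the event has not
# happened by the time the data was extracted.
raw = pd.DataFrame({
    "start": pd.to_datetime(["2020-01-01", "2020-01-10", "2020-02-01"]),
    "end": pd.to_datetime(["2020-01-21", None, "2020-03-02"]),
    "age": [27, 18, 19],
})
extract_date = pd.Timestamp("2020-03-31")

df = pd.DataFrame({
    # Time elapsed in days; censored rows are measured up to the extraction date.
    "t": (raw["end"].fillna(extract_date) - raw["start"]).dt.days,
    # 1 if the event was observed, 0 if the row is censored.
    "e": raw["end"].notna().astype(int),
    # Any remaining columns are treated as features.
    "age": raw["age"],
})
print(df)
```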

Kaplan Meier Model

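This is the simplest model: it estimates the survival function from the observed times and events alone, without using any other features.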
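For intuition, the Kaplan-Meier estimate can be written in a few lines of NumPy. This is a standalone sketch of the textbook estimator, not torchlife's implementation; the function name `kaplan_meier` and the toy data are made up for illustration.

```python
import numpy as np

def kaplan_meier(t, e):
    """Return the distinct event times and the Kaplan-Meier estimate at each."""
    t = np.asarray(t, dtype=float)
    e = np.asarray(e, dtype=int)
    event_times = np.unique(t[e == 1])
    survival = []
    s = 1.0
    for time in event_times:
        at_risk = np.sum(t >= time)              # still under observation at `time`
        deaths = np.sum((t == time) & (e == 1))  # events observed exactly at `time`
        s *= 1.0 - deaths / at_risk              # multiply by the conditional survival
        survival.append(s)
    return event_times, np.array(survival)

# Toy data: five subjects, two of them censored (e = 0).
times, surv = kaplan_meier(t=[2, 3, 3, 5, 8], e=[1, 1, 0, 1, 0])
print(times, surv)
```

Censored subjects never count as deaths, but they stay in the at-risk set until their last observed time, which is how the estimator uses partial information.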

Proportional Hazard Model

This model attempts to model the instantaneous hazard of an instance given time. It does this by binning time and finding the cumulative hazard up to a given point in time. Its extension, the Cox model, takes into account other variables that are not time dependent, such that the above-mentioned hazard grows or shrinks in proportion to the risk associated with the non-temporal features.

from torchlife.model import ModelHazard

model = ModelHazard('cox')
model.fit(df)
inst_hazard, surv_probability = model.predict(df)
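The relationship between a binned hazard, the cumulative hazard and the survival probability can be sketched as follows. The bin edges, hazard values and the `survival_probability` helper are hypothetical and only illustrate the idea; they say nothing about how torchlife parameterises the model internally. In a Cox-style model, `risk_score` would be computed from the non-temporal features.

```python
import numpy as np

# Hypothetical time bins and a constant baseline hazard within each bin.
bin_edges = np.array([0.0, 10.0, 20.0, 30.0])
baseline_hazard = np.array([0.01, 0.02, 0.04])

def survival_probability(t, risk_score=0.0):
    """S(t | x) = exp(-exp(risk_score) * H0(t)) for a piecewise-constant hazard."""
    widths = np.diff(bin_edges)
    # Time spent inside each bin up to t (0 for bins that start after t).
    exposure = np.clip(t - bin_edges[:-1], 0.0, widths)
    cumulative_hazard = np.sum(baseline_hazard * exposure)
    # The feature-based risk scales the baseline hazard proportionally.
    return float(np.exp(-np.exp(risk_score) * cumulative_hazard))

print(survival_probability(15.0))                          # baseline subject
print(survival_probability(15.0, risk_score=np.log(2.0)))  # hazard doubled
```

Doubling the hazard doubles the cumulative hazard at every time point, so the second subject's survival curve always lies below the baseline curve.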

Accelerated Failure Time Models

This model attempts to model time (the mode, not the average) as a function of the non-temporal features, x.

You are free to choose the distribution of the error; however, the Gumbel distribution is a popular option in the SA literature. The distribution needs to be over the whole real line and not just the positive reals, since we are modelling log time.

from torchlife.model import ModelAFT

model = ModelAFT('Gumbel')
model.fit(df)
surv_prob = model.predict(df)
mode_time = model.predict_time(df)
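As a rough illustration of the idea, the sketch below writes log time as a location `mu` plus a scaled Gumbel error and checks the resulting survival function against simulated data. It assumes the minimum-type Gumbel distribution, which is the usual choice in survival analysis; the source does not say which variant torchlife uses. The function names and the values of `mu` and `sigma` are hypothetical, and in a real model `mu` would be a function of the features x.

```python
import numpy as np

def aft_survival(t, mu, sigma=1.0):
    """P(T > t) when log T = mu + sigma * eps and eps is standard Gumbel (minimum)."""
    z = (np.log(t) - mu) / sigma
    return np.exp(-np.exp(z))

def aft_mode_time(mu):
    """exp(mu): the density of log T peaks at mu, so this is the mode on the log scale."""
    return np.exp(mu)

# Simulation check. NumPy samples the maximum-type Gumbel, so negate it.
rng = np.random.default_rng(0)
mu, sigma = 3.0, 0.5
eps = -rng.gumbel(size=100_000)
T = np.exp(mu + sigma * eps)

empirical = np.mean(T > 20.0)
analytic = aft_survival(20.0, mu, sigma)
print(empirical, analytic)
```

Because the error lives on the whole real line, exponentiating always gives a positive time, which is why the distribution is placed on log time rather than on time itself.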

Kudos

Special thanks to the following libraries and resources.

Install

pip install torchlife

How to use

We need a dataframe that has a column named 't' indicating time, and 'e' indicating a death event.

import pandas as pd
import numpy as np
url = "https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/datasets/rossi.csv"
df = pd.read_csv(url)
df.rename(columns={'week':'t', 'arrest':'e'}, inplace=True)
df.head()
    t  e  fin  age  race  wexp  mar  paro  prio
0  20  1    0   27     1     0    0     1     3
1  17  1    0   18     1     0    0     1     8
2  25  1    0   19     0     1    0     1    13
3  52  0    1   23     1     1    1     1     1
4  52  0    0   19     0     1    0     1     3

from torchlife.model import ModelHazard

model = ModelHazard('cox', lr=0.5)
model.fit(df)
λ, S = model.predict(df)
GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name | Type               | Params
--------------------------------------------
0 | base | ProportionalHazard | 12    
--------------------------------------------
12        Trainable params
0         Non-trainable params
12        Total params
Epoch 0:  75%|███████▌  | 3/4 [00:00<00:00, 25.43it/s, loss=nan, v_num=49]
Epoch 0: 100%|██████████| 4/4 [00:00<00:00, 16.39it/s, loss=nan, v_num=49]
Epoch 1:  75%|███████▌  | 3/4 [00:00<00:00, 23.74it/s, loss=nan, v_num=49]
Epoch 1: 100%|██████████| 4/4 [00:00<00:00, 15.25it/s, loss=nan, v_num=49]
Epoch 2:  75%|███████▌  | 3/4 [00:00<00:00, 22.98it/s, loss=nan, v_num=49]
Epoch 2: 100%|██████████| 4/4 [00:00<00:00, 15.53it/s, loss=nan, v_num=49]
Epoch 3:  75%|███████▌  | 3/4 [00:00<00:00, 20.23it/s, loss=nan, v_num=49]
Validating: 0it [00:00, ?it/s]
Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 13.30it/s, loss=nan, v_num=49]
Saving latest checkpoint...
Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 12.92it/s, loss=nan, v_num=49]

Let's plot the survival function for the third row of the dataframe (index 2):

x = df.drop(['t', 'e'], axis=1).iloc[2]
t = np.arange(df['t'].max())
model.plot_survival_function(t, x)

[Figure: output of model.plot_survival_function(t, x)]

torchlife's Issues

Related paper

Hi,
Thanks for sharing your code. I have one question: do you have a paper that describes this deep learning survival analysis method you implemented?

Thank you!
