
Evaluation of post-hoc interpretability methods in time-series classification


This repository is the implementation code for : Evaluation of post-hoc interpretability methods in time-series classification.

If you find this code or idea useful, please consider citing our work:

Turbé, H., Bjelogrlic, M., Lovis, C. et al. 
Evaluation of post-hoc interpretability methods in time-series classification. 
Nat Mach Intell 5, 250–260 (2023). https://doi.org/10.1038/s42256-023-00620-w


Overview

This repository presents the framework used in the paper "Evaluation of post-hoc interpretability methods in time-series classification". The framework allows evaluating and ranking the performance of interpretability methods for time-series classification via the AUCSE metric. As depicted in Fig. 1, different interpretability methods can produce tangibly different results for the same model. It therefore becomes key to understand which interpretability method better reflects what the model has learned and how the model uses the input features. The paper also introduces a new synthetic dataset with tunable complexity that can be used to assess the performance of interpretability methods and to reproduce time-series classification tasks of arbitrary complexity.

Fig. 1: Relevance produced by four post-hoc interpretability methods on a time-series classification task, where a Transformer neural network must identify a patient's pathology from ECG data. Depicted in black are two signals, V1 and V2, while the contour map represents the relevance produced by the interpretability method. Red indicates positive relevance, marking portions of the time series that the interpretability method deemed important for the neural network's prediction; blue indicates negative relevance, marking portions that went against the prediction.

Data

Data associated with this repository, as well as the results presented in the article, can be found on Zenodo and are organised in three zip files. The folders should be unzipped and copied into data/.

data
├── datasets
├── trained_models
└── results_paper

Datasets

Three datasets were used in the article: an ECG dataset, the FordA dataset and the synthetic dataset introduced in the paper.

Each dataset is provided in Parquet format compatible with Petastorm, with each folder organised as follows:

Datasets
    ├── dataParquet
    ├── config__***.json
    ├── train.npy
    ├── val.npy
    └── test.npy

  • dataParquet: folder with the data formatted with Petastorm

  • config__***.json: configuration file used to process the data

  • train.npy: numpy array with the names of the samples used for the training set

  • val.npy: numpy array with the names of the samples used for the validation set

  • test.npy: numpy array with the names of the samples used for the test set
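
The split files and the Petastorm folder can be inspected directly. Below is a minimal sketch, not the repository's own loading code; the folder name data/datasets/synthetic is an assumption based on the dataset names used later in this README, and the available columns are printed rather than assumed.

    import os

    import numpy as np
    from petastorm import make_reader

    dataset_dir = "data/datasets/synthetic"  # assumed folder name after unzipping into data/

    # Names of the samples belonging to each split.
    train_names = np.load(os.path.join(dataset_dir, "train.npy"), allow_pickle=True)
    val_names = np.load(os.path.join(dataset_dir, "val.npy"), allow_pickle=True)
    test_names = np.load(os.path.join(dataset_dir, "test.npy"), allow_pickle=True)
    print(len(train_names), len(val_names), len(test_names))

    # Petastorm readers take a URL; local folders use the file:// scheme.
    url = "file://" + os.path.abspath(os.path.join(dataset_dir, "dataParquet"))
    with make_reader(url) as reader:
        sample = next(iter(reader))  # one row of the dataset as a namedtuple
        print(sample._fields)        # list the available columns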

Trained Models

Trained models are provided in the trained_models folder. A Transformer, a bi-LSTM and a CNN model were trained on three different datasets, respectively called ecg, fordA and synthetic. Each folder contains the following files:

Simulation
    ├── results
    ├── classes_encoder.npy
    ├── config__***.yaml
    ├── stats_scaler.csv
    └── best_model.ckpt

  • results: folder with the classification results of the simulation

  • classes_encoder.npy: classes used for the encoder

  • config__***.yaml: config file with the model's hyperparameters

  • stats_scaler.csv: mean and median of the samples included in the training set (used to normalise the data)

  • best_model.ckpt: saved model
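
A minimal sketch for inspecting one of these folders is shown below, assuming the archives were unzipped into data/ as described above. The folder name ecg_cnn comes from the examples further down; the "state_dict" key reflects the usual PyTorch Lightning checkpoint layout and is an assumption, not something confirmed by this README.

    import numpy as np
    import pandas as pd
    import torch

    model_dir = "data/trained_models/ecg_cnn"  # folder name taken from the ECG example below

    # Classes used for the encoder.
    classes = np.load(f"{model_dir}/classes_encoder.npy", allow_pickle=True)
    print("Classes:", classes)

    # Statistics used to normalise the data.
    stats = pd.read_csv(f"{model_dir}/stats_scaler.csv")
    print(stats)

    # A Lightning checkpoint is a plain dict; the weights usually sit under "state_dict".
    ckpt = torch.load(f"{model_dir}/best_model.ckpt", map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)
    print(list(state_dict.keys())[:5])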

Paper's results

The results obtained in the paper can be found in the results_paper folder. This folder includes the relevance obtained across the different models trained as well as the evaluation metrics. The results are organised as follows:

results_paper
    ├── model_interpretability
    │   └── simulation
    │        ├── interpretability_raw
    │        └── interpretability_results
    └── summary_results

  • interpretability_raw: includes the relevance computed with the different interpretability methods for each sample included in the analysis

  • interpretability_results: includes a summary of the different metrics for each interpretability method, as well as a summary across the different methods in metrics_methods.csv
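
As a quick way to look at these results, the per-method summaries can be loaded with pandas. The sketch below only assumes the folder layout shown above and that the results_paper archive has been unzipped into data/.

    import glob

    import pandas as pd

    # One metrics_methods.csv per simulation, following the layout shown above.
    pattern = "data/results_paper/model_interpretability/*/interpretability_results/metrics_methods.csv"
    for path in glob.glob(pattern):
        metrics = pd.read_csv(path)  # evaluation metrics summarised per interpretability method
        print(path)
        print(metrics.head())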

Usage

The setup was tested with Python 3.8.16 on Linux. Building the container described below and installing the Python requirements can take up to 10 minutes the first time.

Devcontainer

A configuration for a development container is provided in the .devcontainer folder. The current configuration assumes a CUDA-compatible GPU is available. If no GPU is available, the following lines in .devcontainer/devcontainer.json should be commented out:

    "hostRequirements": {
        "gpu": true
    },
    "features": {
        "ghcr.io/devcontainers/features/nvidia-cuda:1": {}
    },
    "runArgs": [
        "--gpus",
        "all"
    ],

Requirements

The Python packages are managed with PDM. If the project is not run using the provided .devcontainer, PDM should first be installed (see instructions) before installing the packages with the following command:

pdm install

Synthetic dataset creation

The synthetic dataset presented in the article can be recreated using the script src/operations/__exec_generate_data.py. The command takes as input a .yaml file configuring the parameters used to generate the dataset. An example can be found in environment/config_generate_synthetic.yaml.

e.g. from within src/operations:

python3 __exec_generate_data.py --config_file ../../environment/config_generate_synthetic.yaml

A notebook to visualise the samples from the synthetic dataset is included in /src/notebooks/visualise_sample.ipynb.
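
For a quick look outside the notebook, a generated sample can also be plotted straight from the Parquet folder. This is a hedged sketch, not the notebook's code: the output path and the column names "signal" and "target" are assumptions; check sample._fields for the actual columns.

    import os

    import matplotlib.pyplot as plt
    from petastorm import make_reader

    # Assumed output location of the generation script; adjust to the path set in the .yaml config.
    url = "file://" + os.path.abspath("data/datasets/synthetic/dataParquet")
    with make_reader(url) as reader:
        sample = next(iter(reader))

    signal = getattr(sample, "signal", None)  # expected shape: (timesteps, features)
    if signal is not None:
        plt.plot(signal)
        plt.title(f"Synthetic sample, target={getattr(sample, 'target', 'n/a')}")
        plt.xlabel("Time step")
        plt.show()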

Post-processing

Relevance and metrics presented in the paper can be computed using the following command from within the src/operations folder:

python3 __exec_postprocessing.py --results_path <path_to_trained_model> --sample_file <sample_file> --method_relevance <methods>

  • results_path: path to the folder with the trained model

  • sample_file: name of a file stored in src/assets/sample_post. This file contains the names of the samples on which the relevance and the metrics are computed.

  • method_relevance: list of interpretability methods to evaluate. Can be one (or a list) of [shapleyvalue, integrated_gradients, deeplift, gradshap, saliency, kernelshap]

Interpretability evaluation metrics are saved in results/name_simulation.
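
The method names listed above match standard attribution algorithms as implemented, for example, in Captum; the sketch below is a generic illustration of how one of them produces a relevance map, not the repository's postprocessing code. The toy model and the input shape (batch, timesteps, channels) are placeholders.

    import torch
    from captum.attr import IntegratedGradients

    # Toy stand-in for a trained time-series classifier: 4 series of
    # 100 time steps with 2 channels, mapped to 3 output classes.
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(100 * 2, 3))
    inputs = torch.randn(4, 100, 2)

    ig = IntegratedGradients(model)
    relevance = ig.attribute(inputs, target=0)  # relevance per time step and channel
    print(relevance.shape)                      # same shape as the inputs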

Example for synthetic dataset

From within src/operations:

python3 __exec_postprocessing.py --model_path ../../data/trained_models/synthetic_cnn --sample_file sample_synthetic.npy

Example for ECG dataset

From within src/operations:

python3 __exec_postprocessing.py --model_path ../../data/trained_models/ecg_cnn --sample_file sample_ecg.npy

Example for FordA dataset

From within src/operations:

python3 __exec_postprocessing.py --model_path ../../data/trained_models/forda_cnn --sample_file sample_forda.npy
