Bex

A Benchmark For Counterfactual Explainers


Explainability methods have been widely used to provide insight into the decisions made by statistical models, thus facilitating their adoption in various domains within the industry. Counterfactual explanation methods aim to improve our understanding of a model by perturbing samples in a way that would alter its response in an unexpected manner. This information helps both end users and machine learning practitioners understand and improve their models. Given the value provided by counterfactual explanations, there is growing interest in the research community in investigating and proposing new methods. However, we identify two issues that could hinder progress in this field. (1) Existing metrics do not accurately reflect the value of an explainability method for its users. (2) Comparisons between methods are usually performed on datasets like CelebA, where images are annotated with attributes that do not fully describe them and with subjective attributes such as "Attractive". In this work, we address these problems by proposing an evaluation method with a principled metric to evaluate and compare different counterfactual explanation methods. The evaluation is based on a synthetic dataset where images are fully described by their annotated attributes. As a result, we are able to perform a fair comparison of multiple explainability methods from the recent literature, obtaining insights about their performance.

[paper]

Description

Code repository for the Bex explainability benchmark. The models and datasets that comprise the benchmark can be found here; they are downloaded automatically when the benchmark is run.

The dataset used for the benchmark is a modified version of Synbols that contains black-and-white characters defined by various attributes (e.g., font, rotation, scale, etc.).

[Figure: sample images from the dataset]

Installation

The recommended way to install Bex is via PyPI:

pip install bex

Usage

For more information about usage, check out the Documentation.

We provide a set of counterfactual explainers already implemented in the benchmark (a side-by-side comparison sketch follows the list):

  1. Beyond Trivial Counterfactual Explanations with Diverse Valuable Explanations (DiVE)
  2. xGEMs: Generating Examplars to Explain Black-Box Models (xGEM)
  3. Latent-CF: A Simple Baseline for Reverse Counterfactual Explanations (Latent-CF)
  4. Explaining in Style: Training a GAN to explain a classifier in StyleSpace (StylEx)
  5. Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations (DiCE)
  6. Inverse Classification for Comparison-based Interpretability in Machine Learning (Growing Spheres)
  7. An oracle with access to the correlated and causal attributes (IS)
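
To compare these explainers side by side, here is a minimal sketch that evaluates each predefined explainer under the default setting; the run keys are the ones accepted by Benchmark.run() (as in the usage example below), and the sketch assumes each key is available in your installation:

import bex

# Evaluate every predefined explainer under the default setting
# (n_corr=10, corr_level=0.95). "IS" is the oracle.
bn = bex.Benchmark()

for explainer in ["dive", "xgem", "lcf", "stylex", "dice", "gs", "IS"]:
    bn.run(explainer)

# One dataframe holding the results of all .run() calls
print(bn.summarize())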

The benchmark includes different settings, obtained by modifying the number of correlated attributes n_corr and their level of correlation corr_level. Four settings are currently available (a sketch sweeping all of them follows the list):

  • n_corr=10, corr_level=0.95 (default)
  • n_corr=6, corr_level=0.95
  • n_corr=10, corr_level=0.5
  • n_corr=6, corr_level=0.5
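
The sweep below is a minimal sketch of how the four settings could be compared with a single explainer; it creates a fresh Benchmark per setting and assumes that summarize() returns a pandas dataframe, as in the examples that follow:

import pandas as pd

import bex

# Run one explainer under each of the four available settings
results = []
for n_corr, corr_level in [(10, 0.95), (6, 0.95), (10, 0.5), (6, 0.5)]:
    bn = bex.Benchmark(n_corr=n_corr, corr_level=corr_level)
    bn.run("dive")
    results.append(bn.summarize())

# Stack the per-setting results for comparison
print(pd.concat(results))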

Evaluating one of the predefined explainers

import bex
bn = bex.Benchmark(n_corr=6, corr_level=0.95) # downloads necessary files
bn.run("stylex") # or any of: "dive", "xgem", "lcf", "dice", "gs", "IS" (Oracle)
bn.run("IS", output_path="output/is")
print(bn.summarize()) # get the performance of each explainer as a pandas dataframe

You can reproduce the experiments in the paper by running python run_benchmark.py

Evaluate a custom explainer

You can evaluate your own explainer like so:

import random
from typing import Callable

import torch

import bex

class DummyExplainer(bex.explainers.ExplainerBase):

    def __init__(self, num_explanations):
        super().__init__()
        self.num_explanations = num_explanations

    # This function describes the behaviour of the custom explainer for a given batch
    def explain_batch(self, latents: torch.Tensor, logits: torch.Tensor,
                      images: torch.Tensor, classifier: torch.nn.Module,
                      generator: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:

        b, c = latents.size()
        # we will produce self.num_explanations counterfactuals per sample
        z = latents[:, None, :].repeat(1, self.num_explanations, 1)
        z_perturbed = z + random.random() # create counterfactuals z'

        return z_perturbed.view(b, self.num_explanations, c)

bn = bex.Benchmark()
bn.run(DummyExplainer, num_explanations=10)
print(bn.summarize()) # get the explainer's performance
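
The dummy explainer above ignores the classifier and generator arguments. As a rough illustration of how they might be used, here is a hypothetical gradient-based explainer (not one of the benchmark's methods): it nudges each latent code so that the classifier's confidence in the originally predicted class decreases. It assumes the generator and classifier are differentiable and that gradients can be computed inside explain_batch; the name GradientExplainer and the steps and lr parameters are invented for this sketch.

import torch

import bex

# Illustrative only: shows one way the classifier and generator arguments could be used.
class GradientExplainer(bex.explainers.ExplainerBase):

    def __init__(self, num_explanations, steps=10, lr=0.1):
        super().__init__()
        self.num_explanations = num_explanations
        self.steps = steps
        self.lr = lr

    def explain_batch(self, latents, logits, images, classifier, generator):
        b, c = latents.size()

        # class originally predicted for each sample, repeated per explanation
        target = logits.argmax(dim=1)[:, None].repeat(1, self.num_explanations).view(-1)

        # one perturbation trajectory per requested explanation, with a small
        # random offset so the trajectories do not collapse to the same point
        z = latents[:, None, :].repeat(1, self.num_explanations, 1).view(-1, c)
        z = (z + 0.01 * torch.randn_like(z)).detach().requires_grad_(True)

        with torch.enable_grad():
            for _ in range(self.steps):
                loss = torch.nn.functional.cross_entropy(classifier(generator(z)), target)
                grad, = torch.autograd.grad(loss, z)
                # gradient ascent on the loss: move away from the original prediction
                z = (z + self.lr * grad).detach().requires_grad_(True)

        return z.detach().view(b, self.num_explanations, c)

bn = bex.Benchmark()
bn.run(GradientExplainer, num_explanations=4)
print(bn.summarize())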

Logging

We provide a basic logger to log results and image samples. It is activated by default, but you can deactivate it by setting logger=None:

bn = bex.Benchmark(n_corr=6, corr_level=0.95, logger=None)
bn.run("stylex") # nothing will be logged

# a pandas dataframe holding the results of all .run() calls can always be obtained by calling
bn.summarize()

It is also possible to use a custom logger. For example, here is a custom logger built on Weights & Biases:

import matplotlib.pyplot as plt
import numpy as np
import wandb

from bex.loggers import BasicLogger

class WandbLogger(BasicLogger):

    '''
    Args:
        attributes (``Dict``): dictionary containing the run config
        path (``string``): output path for the logger
        n_batches (``int``, optional): max number of image batches to log

    '''

    def __init__(self, attributes, path, n_batches=10):

        super().__init__(attributes, path, n_batches)

        wandb.init(project="Synbols-benchmark", dir=self.path, config=self.attributes, reinit=True)


    # accumulate metrics for this step
    def accumulate(self, data, images):

        super().accumulate(data, images)

        wandb.log({f"{k}" :v for k, v in data.items()}, commit=True)


    # log average value of all the metrics across steps
    def log(self):

        self.metrics = {f"{k}_avg": np.mean(v) for k, v in self.metrics.items()}
        wandb.log(self.metrics)

        # create matplotlib figure with the counterfactuals generated
        fig = self.create_cf_figure()

        wandb.log({"Counterfactuals": fig})

        plt.close(fig)

bn = bex.Benchmark(n_corr=6, corr_level=0.95, logger=WandbLogger)
bn.run("IS") # results will be logged to weights and biases

print(bn.summarize()) # results stored in memory

Citation

If you find this work useful, please consider citing the following paper:

Diego Velazquez, Pau Rodriguez, Alexandre Lacoste, Issam H. Laradji, Xavier Roca, and Jordi Gonzàlez. 2023. Explaining Visual Counterfactual Explainers. Transactions on Machine Learning Research. ISSN: 2835-8856.

@article{velazquez2023explaining,
title={Explaining Visual Counterfactual Explainers},
author={Diego Velazquez and Pau Rodriguez and Alexandre Lacoste and Issam H. Laradji and Xavier Roca and Jordi Gonz{\`a}lez},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2023},
url={https://openreview.net/forum?id=RYeRNwRjNE},
note={Reproducibility Certification}
}

Contact

For bug reports or feature requests, please create an issue.
