avani17101 / cd

Code for the paper "Concept Distillation: Leveraging Human-Centered Explanations for Model Improvement", NeurIPS 2023

License: Other

Topics: concepts, debiasing, faces, graphics, iid, interpretable-machine-learning, mnist, neurips-2023, responsible-ai, vision


Concept Distillation

Code for our paper "Concept Distillation: Leveraging Human-Centered Explanations for Model Improvement", accepted at NeurIPS 2023.

Abstract

Humans use abstract concepts for understanding instead of hard features. Recent interpretability research has focused on human-centered concept explanations of neural networks. Concept Activation Vectors (CAVs) estimate a model's sensitivity and possible biases to a given concept. In this paper, we extend CAVs from post-hoc analysis to ante-hoc training in order to reduce model bias through fine-tuning using an additional Concept Loss. Previously, concepts were defined on the final layer of the network; we generalize them to intermediate layers using class prototypes. This facilitates class learning in the last convolution layer, which is known to be most informative. We also introduce Concept Distillation to create richer concepts using a pre-trained knowledgeable model as the teacher. Our method can sensitize or desensitize a model towards concepts. We show applications of concept-sensitive training to debias several classification problems. We also use concepts to induce prior knowledge into IID, a reconstruction problem. Concept-sensitive training can improve model interpretability, reduce biases, and induce prior knowledge.

Proposed Use of Teacher

Website | NeurIPS Link

Method

Pipeline

Our framework comprises a concept teacher and a student classifier and has the following four steps:

  1. Mapping teacher space to student space for concepts $C$ and $C'$ by training an autoencoder $E_M$ and $D_M$ (dotted purple lines).
  2. CAV $v^l_c$ learning in mapped teacher space via a linear classifier LC (dashed blue lines).
  3. Training the student model with Concept Distillation (solid orange lines): we use $v^l_c$ and the class-prototype loss $L_p$ to define our concept distillation loss $L_c$, and use it with the original training loss $L_o$ to (de)sensitize the model to concept $C$.
  4. Testing, where the trained model is applied (dotted red lines).
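As an illustration of step 2, a CAV can be obtained as the weight vector of a linear classifier trained to separate concept activations from random activations at layer $l$. The sketch below is a minimal NumPy reading of that step, not the repository's implementation; the function name and hyperparameters are our own.

```python
import numpy as np

def learn_cav(concept_acts, random_acts, lr=0.1, epochs=200):
    """Fit a logistic-regression boundary between concept and random
    activations; the unit-normalised weight vector serves as the CAV v^l_c."""
    X = np.vstack([concept_acts, random_acts])
    y = np.concatenate([np.ones(len(concept_acts)), np.zeros(len(random_acts))])
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # sigmoid predictions
        grad_w = X.T @ (p - y) / len(y)         # logistic-loss gradient
        grad_b = np.mean(p - y)
        w -= lr * grad_w
        b -= lr * grad_b
    return w / np.linalg.norm(w)                # unit-length CAV
```

The sign convention matters: with labels as above, the CAV points from the random set toward the concept set.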

Inputs

  • As inputs, we need user-provided concept sets; train, eval, and test datasets; and the model to be trained.
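Step 3 also relies on class prototypes to extend concepts to intermediate layers. Below is a minimal sketch of one plausible prototype loss $L_p$, assuming prototypes are per-class mean activations — this is our reading of the idea, not the repository's exact formulation.

```python
import numpy as np

def prototype_loss(acts, labels):
    """Schematic class-prototype loss L_p: pull each sample's layer-l
    activation toward the mean activation (prototype) of its class."""
    protos = {c: acts[labels == c].mean(axis=0) for c in np.unique(labels)}
    # Mean squared distance of each activation to its class prototype.
    return sum(np.sum((a - protos[c]) ** 2) for a, c in zip(acts, labels)) / len(acts)
```

When every activation already sits at its class mean, the loss is zero, so this term only acts on within-class spread.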

Instructions to run the Code

  • Install requirements via
    conda env create -f environment.yml
    or
    pip install -r requirements.txt
  • For replicating our results on MNIST, we provide data creation scripts in src/mnist.

  • We have concept sets in concepts folder.

  • Do the four mentioned steps for Concept Distillation. Check scripts/run_all_components.sh for details.

  • We provide an additional script, CD_modular.py, that you can customize to your dataset, model, and task (with user-given concept sets).
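For intuition on the (de)sensitization performed in step 3: a concept loss can be built from the cosine between the class-score gradient at layer $l$ and the CAV. This is a schematic NumPy sketch under our own assumptions (the gradient is supplied precomputed, and `lam` is a hypothetical weighting knob), not the repository's actual loss.

```python
import numpy as np

def concept_loss(grad_acts, cav, desensitize=True):
    """Schematic concept loss L_c: concept sensitivity is the cosine
    between d(class score)/d(layer-l activations) and the CAV v^l_c."""
    cav = cav / np.linalg.norm(cav)
    cos = (grad_acts / np.linalg.norm(grad_acts, axis=1, keepdims=True)) @ cav
    # Desensitizing drives sensitivity toward zero; sensitizing rewards alignment.
    return float(np.mean(np.abs(cos)) if desensitize else np.mean(1.0 - cos))

def total_loss(l_orig, l_concept, lam=1.0):
    return l_orig + lam * l_concept  # L = L_o + lambda * L_c
```

With `desensitize=True` the loss is minimized when the class score is locally invariant to movement along the concept direction; with `desensitize=False` it instead rewards alignment with the CAV.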

TODOs

  • Clean codebase.
  • Add concept examples for example code run.

Acknowledgements


This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.



Issues

Difficulty Understanding Code utils_train.py (Lines 146-424)

I'm encountering difficulty understanding the meaning of certain terms used within the get_pairs function located in utils_train.py. These terms appear within dictionaries defining concept pairs, but their purpose and relationship to the concepts themselves are unclear.

Examples of Unclear Terminology:

  • randoms
  • randoms20
  • 7_texture_align
  • biased_color1
    ... (and others)

Request for Clarification:

Could you provide a description or definition for some of these unclear terms?
