


Home Page: https://luanademi.github.io/toumei/

License: GNU General Public License v3.0




toumei (透明)

An interpretability library for pytorch

Table of Contents
  1. About The Project
  2. Getting Started
  3. Usage
  4. Contributing
  5. License
  6. Contact
  7. References

About The Project

This project is in active development, so the README may be outdated and does not list everything that is currently implemented. See the dev branch for more information.

toumei is a small side project of mine that tries to combine state-of-the-art interpretability and model editing methods into a pythonic library. The goal is to compile useful methods into a coherent toolchain and to make complex methods accessible through an intuitive syntax.

Interpretability methods have become quite powerful, and therefore useful, over the last couple of years, which is why I want to provide a library that makes them available for broader use.

The following methods are currently implemented or planned:

I plan to add new methods as I learn about them, so this project essentially mirrors my progress in the field of AI interpretability.

Built With

PyTorch, scikit-learn, NetworkX, Seaborn, tqdm, NumPy, Hugging Face

Getting Started

toumei cannot be installed using pip. To use toumei, either by running the experiments or by adding it to your own project, follow the guide below.

Prerequisites

Make sure the following libraries are installed, or install them using:

pip install torch torchvision tqdm matplotlib transformers seaborn scikit-learn networkx
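If you are unsure which of these packages are already available, the small helper below (a hypothetical convenience script, not part of toumei) checks whether each one can be imported in the current environment. Note that scikit-learn is imported under the name sklearn.

```python
import importlib.util

# Import names of the pip packages listed above (scikit-learn is imported as "sklearn").
REQUIRED_MODULES = [
    "torch", "torchvision", "tqdm", "matplotlib",
    "transformers", "seaborn", "sklearn", "networkx",
]


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level modules from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules()
if missing:
    print("Missing:", ", ".join(missing))
else:
    print("All prerequisites are installed.")
```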

Installation

  1. Clone the repo
    git clone https://github.com/LuanAdemi/toumei.git
  2. Run the experiments
    cd toumei/experiments
    python <experiment>.py
  3. Move the library to your project
    cd ..
    cp -r toumei <path_to_your_project>


Usage

Simple feature visualization

To perform feature visualization on a convolutional model, we need two things: an image parameterization method and an objective.

These are located in the toumei.cnns package:

import torch
import torchvision.transforms as T

# import toumei
import toumei.cnns.objectives as obj
import toumei.cnns.parameterization as param

Next, we import a model to perform feature visualization on:

from toumei.models import Inception5h

# the model we want to analyze
model = Inception5h(pretrained=True)

To counter noise in the optimization process, we define a transform function used for transformation robustness regularization:

# compose the image transformations used for transformation robustness regularization
transform = T.Compose([
    T.Pad(12),
    T.RandomRotation((-10, 11)),
    T.Lambda(lambda x: x*255 - 117)  # inception needs this
])
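Transformation robustness means the objective is optimized while the image is randomly transformed at every step, so only patterns that survive small shifts and rotations are reinforced, rather than high-frequency noise that exploits exact pixel positions. The NumPy sketch below illustrates the idea under simplifying assumptions: the hypothetical jitter_and_normalize helper replaces the random rotation with a random translation (jitter) and then applies the same x*255 - 117 rescaling, which maps pixel values from [0, 1] to [-117, 138].

```python
import numpy as np


def jitter_and_normalize(img, pad=12, max_shift=8, rng=None):
    """Pad a (c, h, w) image in [0, 1], shift it randomly and rescale it for Inception.

    Simplified stand-in for the torchvision pipeline above: a random translation
    replaces RandomRotation, but the padding and the `x * 255 - 117` step match.
    """
    rng = np.random.default_rng() if rng is None else rng
    # zero padding on the spatial axes, like T.Pad(12) with its default fill
    padded = np.pad(img, ((0, 0), (pad, pad), (pad, pad)), mode="constant")
    # draw a new random shift on every call, i.e. on every optimization step
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    shifted = np.roll(padded, shift=(dy, dx), axis=(1, 2))
    return shifted * 255 - 117


image = np.full((3, 224, 224), 0.5)
out = jitter_and_normalize(image, rng=np.random.default_rng(0))
print(out.shape)  # (3, 248, 248)
```

Calling it once per optimization step yields a differently shifted copy each time; with the default padding a 224 x 224 input becomes 248 x 248, the same size T.Pad(12) produces.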

We can now define our objective pipeline using an image parameterization method (here FFT) and our objective (visualize channel mixed3a:74):

# define a feature visualization pipeline
fv = obj.Pipeline(
    # the image generator object
    param.Transform(param.FFTImage(1, 3, 224, 224), transform),

    # the objective function
    obj.Channel("mixed3a:74")
)
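FFTImage presumably parameterizes the image in Fourier space rather than pixel space, as popularized by the Lucid library: the optimizer updates a complex spectrum, and each frequency is scaled by roughly 1/frequency before the inverse FFT, so gradient steps are not dominated by high-frequency noise. The sketch below follows that approach in NumPy; it is an assumption that toumei's FFTImage (whose arguments 1, 3, 224, 224 presumably denote batch size, channels, height and width) works the same way, and the helper names are hypothetical.

```python
import numpy as np


def rfft2d_freqs(h, w):
    """Spatial-frequency magnitude of every bin of a real 2D FFT of an h x w image."""
    fy = np.fft.fftfreq(h)[:, None]   # shape (h, 1)
    fx = np.fft.rfftfreq(w)[None, :]  # shape (1, w // 2 + 1)
    return np.sqrt(fx * fx + fy * fy)


def fft_to_image(spectrum, h, w, decay_power=1.0):
    """Turn a complex spectrum of shape (c, h, w // 2 + 1) into a real image (c, h, w)."""
    freqs = rfft2d_freqs(h, w)
    # scale each bin by ~1/frequency; the floor 1/max(h, w) avoids dividing by zero at DC
    scale = 1.0 / np.maximum(freqs, 1.0 / max(h, w)) ** decay_power
    return np.fft.irfft2(spectrum * scale, s=(h, w))


# the spectrum (real and imaginary parts) is what an optimizer would update
rng = np.random.default_rng(0)
c, h, w = 3, 224, 224
shape = (c, h, w // 2 + 1)
spectrum = 0.01 * (rng.standard_normal(shape) + 1j * rng.standard_normal(shape))
image = fft_to_image(spectrum, h, w)
print(image.shape)  # (3, 224, 224)
```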

Finally, we optimize the pipeline and plot the results:

# attach the pipeline to the model
fv.attach(model)

# move the pipeline to the GPU if one is available
fv.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# optimize the objective
fv.optimize()

# plot the results
fv.plot()
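Conceptually, optimizing a channel objective is gradient ascent on the input image: a forward hook captures the activations of the target layer, and the loss is the negative mean activation of the chosen channel. The PyTorch sketch below shows this loop on a small, randomly initialized toy CNN standing in for Inception5h; visualize_channel is a hypothetical helper, not toumei's API, and for brevity it optimizes raw pixels instead of an FFT parameterization and skips the transformation robustness step.

```python
import torch
import torch.nn as nn


def visualize_channel(model, layer, channel, steps=64, size=64, lr=0.05, seed=0):
    """Gradient ascent on an image to maximize the mean activation of one channel."""
    torch.manual_seed(seed)
    activations = {}

    def hook(module, inputs, output):
        # keep the layer's output so the loss can read it after the forward pass
        activations["out"] = output

    handle = layer.register_forward_hook(hook)
    # plain pixel parameterization for brevity
    image = torch.rand(1, 3, size, size, requires_grad=True)
    optimizer = torch.optim.Adam([image], lr=lr)
    history = []
    try:
        for _ in range(steps):
            optimizer.zero_grad()
            model(image)
            # maximizing the activation is the same as minimizing its negative
            loss = -activations["out"][:, channel].mean()
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                image.clamp_(0, 1)  # keep pixels in a valid range
            history.append(-loss.item())
    finally:
        handle.remove()  # always detach the hook again
    return image.detach(), history


# toy CNN standing in for Inception5h
torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
)
img, history = visualize_channel(model, model[2], channel=5)
print(f"mean activation: {history[0]:.3f} -> {history[-1]:.3f}")
```

For a real network you would pass the submodule corresponding to the layer of interest instead of model[2].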

Causal Tracing

If we want to locate factual knowledge in GPT-like models, we can use causal tracing. toumei implements this in the toumei.transformers.rome package.

from toumei.transformers.rome.tracing import CausalTracer

This imports everything we need to perform causal tracing. Using Hugging Face's transformers library, we can easily load a model to trace:

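import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load GPT-2 XL from the Hugging Face hub
# (float16 halves GPU memory; some operations may lack float16 support on the CPU, so use float32 there)
dtype = torch.float16 if device.type == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained("gpt2-xl", torch_dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")

model.to(device)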

After defining a prompt and specifying its subject, we can create a CausalTracer object and trace the model using the prompt:

# specify a prompt and its subject for causal tracing
prompt = "Karlsruhe Institute of Technology is located in the country of"
subject = "Karlsruhe Institute of Technology"

# perform causal tracing
tracer = CausalTracer(model, tokenizer)
tracer.trace(prompt, subject, verbose=True)
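Causal tracing, as introduced with ROME, runs the model in three ways: a clean run, a corrupted run in which Gaussian noise is added to the subject tokens' embeddings, and corrupted runs in which a single hidden state (one layer, one token) is restored to its clean value. The increase in the probability of the clean answer caused by such a restoration is that state's indirect effect. The sketch below illustrates this on a tiny, randomly initialized toy model instead of GPT-2 XL; ToyLM and causal_trace are hypothetical, and the noise scale of three standard deviations of the embedding matrix is an assumption loosely following the ROME paper, not necessarily what CausalTracer does.

```python
import torch
import torch.nn as nn


class ToyLM(nn.Module):
    """Hypothetical tiny residual 'language model' used only to illustrate causal tracing."""

    def __init__(self, vocab=20, d=16, n_layers=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, d)
        self.layers = nn.ModuleList([nn.Linear(d, d) for _ in range(n_layers)])
        self.unembed = nn.Linear(d, vocab)

    def forward(self, tokens, noise=None, patch=None):
        """Return next-token probabilities at the last position and all hidden states.

        noise: optional (T, d) tensor added to the input embeddings (the corruption).
        patch: optional (layer, position, state) that overwrites one hidden state.
        """
        h = self.embed(tokens)
        if noise is not None:
            h = h + noise
        # position t averages positions 0..t, so information flows causally to the end
        counts = torch.arange(1, len(tokens) + 1, dtype=h.dtype).unsqueeze(1)
        states = []
        for i, layer in enumerate(self.layers):
            h = h + torch.tanh(layer(h.cumsum(0) / counts))
            if patch is not None and patch[0] == i:
                h = h.clone()
                h[patch[1]] = patch[2]  # restore the clean hidden state
            states.append(h)
        return self.unembed(h[-1]).softmax(-1), states


@torch.no_grad()
def causal_trace(model, tokens, subject_positions, target, noise_scale=3.0, seed=0):
    """Indirect effect on P(target) of restoring each (layer, position) hidden state."""
    torch.manual_seed(seed)
    p_clean, clean_states = model(tokens)
    # corrupt only the subject tokens' embeddings with Gaussian noise
    noise = torch.zeros(len(tokens), model.embed.embedding_dim)
    sigma = model.embed.weight.std()
    noise[subject_positions] = noise_scale * sigma * torch.randn(len(subject_positions), noise.shape[1])
    p_corrupt, _ = model(tokens, noise=noise)

    effects = torch.zeros(len(model.layers), len(tokens))
    for layer in range(len(model.layers)):
        for pos in range(len(tokens)):
            state = clean_states[layer][pos]
            p_patched, _ = model(tokens, noise=noise, patch=(layer, pos, state))
            effects[layer, pos] = p_patched[target] - p_corrupt[target]
    return p_clean[target].item(), p_corrupt[target].item(), effects


torch.manual_seed(0)
model = ToyLM()
tokens = torch.tensor([3, 7, 1, 9, 4])    # treat the first two tokens as the subject
target = int(model(tokens)[0].argmax())     # the clean model's prediction
p_clean, p_corrupt, effects = causal_trace(model, tokens, [0, 1], target)
print(f"P(target): clean {p_clean:.3f}, corrupted {p_corrupt:.3f}")
print(effects)  # rows = layers, columns = token positions
```

Because the output is read from the last position of the final layer, restoring that state recovers the clean probability exactly, while restoring other positions in the final layer has no effect; the informative structure lies in the earlier layers.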

Measuring Model Modularity

toumei implements the modularity metric derived from graph theory for common deep learning architectures such as MLPs and CNNs.

We start by converting the model into a graph on which we can run graph algorithms. toumei provides a separate wrapper for each supported architecture, which can be imported from the misc package:

from toumei.misc import MLPGraph, CNNGraph

Next, we import and initialize the models whose modularity we want to measure:

from toumei.models import SimpleMLP, SimpleCNN

# create the models
mlp = SimpleMLP(4, 4)
cnn = SimpleCNN(1, 10)

Wrapping these models with the imported classes builds the corresponding weighted graph of each model:

# create graph from the model
mlp_graph = MLPGraph(mlp)
cnn_graph = CNNGraph(cnn)
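One common way to turn an MLP into a weighted graph is to make every neuron a node and to connect neurons in consecutive layers with an edge whose weight is the absolute value of the connecting weight. The NumPy sketch below builds such an adjacency matrix; mlp_adjacency is hypothetical, and it is an assumption that MLPGraph uses the same construction.

```python
import numpy as np


def mlp_adjacency(weights):
    """Symmetric weighted adjacency matrix of an MLP.

    weights: list of arrays in nn.Linear layout, i.e. shape (out_features, in_features),
    e.g. [m.weight.detach().numpy() for m in mlp.modules() if isinstance(m, torch.nn.Linear)]
    (assuming the model is built from nn.Linear layers).
    Nodes are all neurons (inputs, hidden units, outputs); edge weight = |w|.
    """
    sizes = [weights[0].shape[1]] + [w.shape[0] for w in weights]
    offsets = np.concatenate([[0], np.cumsum(sizes)])
    n = offsets[-1]
    adj = np.zeros((n, n))
    for k, w in enumerate(weights):
        assert w.shape == (sizes[k + 1], sizes[k]), "consecutive layers must match"
        src = slice(offsets[k], offsets[k + 1])      # neurons of layer k
        dst = slice(offsets[k + 1], offsets[k + 2])  # neurons of layer k + 1
        adj[dst, src] = np.abs(w)
        adj[src, dst] = np.abs(w).T
    return adj


rng = np.random.default_rng(0)
weights = [rng.standard_normal((8, 4)), rng.standard_normal((4, 8))]  # 4 -> 8 -> 4
adj = mlp_adjacency(weights)
print(adj.shape)  # (16, 16)
```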

The wrapper lets us run all sorts of graph algorithms on the model. To obtain the modularity, the graph is partitioned into $n$ communities using spectral clustering, and the modularity of this partition is then calculated.

All of this is done internally by calling:

# calculate the modularity
print(mlp_graph.get_model_modularity())
print(cnn_graph.get_model_modularity())
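Assuming toumei uses the standard (Newman) modularity, the score of a partition into communities $c_i$ is $Q = \frac{1}{2m}\sum_{ij}\left(A_{ij} - \frac{k_i k_j}{2m}\right)\delta(c_i, c_j)$, where $A$ is the weighted adjacency matrix, $k_i = \sum_j A_{ij}$ is the weighted degree of node $i$ and $2m = \sum_{ij} A_{ij}$. The NumPy sketch below computes $Q$ and, as a simplified stand-in for general spectral clustering, splits the graph into $n = 2$ communities using the sign of the Fiedler vector of the normalized Laplacian. Both helper names are hypothetical; for arbitrary $n$, scikit-learn's SpectralClustering with affinity="precomputed" can be applied to the same adjacency matrix.

```python
import numpy as np


def modularity(adj, labels):
    """Newman modularity Q of a partition of a weighted, undirected graph."""
    adj = np.asarray(adj, dtype=float)
    labels = np.asarray(labels)
    k = adj.sum(axis=1)                        # weighted degree of every node
    two_m = adj.sum()                          # 2m: each undirected edge is counted twice
    same = labels[:, None] == labels[None, :]  # delta(c_i, c_j)
    return float(((adj - np.outer(k, k) / two_m) * same).sum() / two_m)


def spectral_bipartition(adj):
    """Split a connected graph into n = 2 communities via the Fiedler vector."""
    adj = np.asarray(adj, dtype=float)
    d = adj.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(np.where(d > 0, d, 1.0))
    # normalized Laplacian L = I - D^{-1/2} A D^{-1/2}
    lap = np.eye(len(adj)) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    _, vecs = np.linalg.eigh(lap)      # eigenvalues in ascending order
    fiedler = d_inv_sqrt * vecs[:, 1]  # second-smallest eigenvector, rescaled
    return (fiedler > 0).astype(int)


# two 4-node cliques joined by a single weak edge
adj = np.zeros((8, 8))
adj[:4, :4] = 1
adj[4:, 4:] = 1
np.fill_diagonal(adj, 0)
adj[3, 4] = adj[4, 3] = 0.1
labels = spectral_bipartition(adj)
print(labels)
print(round(modularity(adj, labels), 3))  # → 0.492
```

On this toy graph the split recovers the two cliques and $Q = 11.9 / 24.2 \approx 0.492$; NetworkX's networkx.algorithms.community.modularity can serve as a cross-check.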

Other

See the experiments folder for more examples.


Contributing

You are more than welcome to contribute to this project or to propose new interpretability methods I could add. Just open an issue or a pull request, as you would on any other GitHub repo.


License

Distributed under the GPL-3.0 License. See LICENSE.txt for more information.


Contact

Luan Ademi - [email protected]

Project Link: https://github.com/LuanAdemi/toumei


References

This section lists resources that I recommend and that I used myself while building this project.

What is interpretability and why should I care?

Feature Visualization and Circuit-based interpretability

Unified Feature Attribution

Rank-One model editing

Modularity


Contributors

chickenprop, davanchama, luanademi
