fursovia / tcav_nlp

"Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV)" paper implementation

Home Page: https://arxiv.org/abs/1711.11279

Languages: Python 1.80%, Jupyter Notebook 98.20%
Topics: nlp, interpretability, tensorflow, deep-learning

tcav_nlp's Introduction

Quantitative Testing with Concept Activation Vectors (TCAV) in NLP

Data preparation

We use the Lenta.Ru dataset in our experiments.

  1. Create a data_path folder and put the lenta-ru-news.csv file in it
  2. Choose the labels to experiment with and split the data into train and test sets by running lenta_dataset.ipynb
  3. Finally, run python data_prep.py -dd data_path. This command saves full.csv, train.csv, eval.csv, and vocabs.txt to data_path

Training

  1. Modify the build_model function if you want to change the model architecture (a rough sketch follows this list)
  2. Create an experiment_path folder and copy experiments/config.yaml into it
  3. Modify the hyperparameters in experiment_path/config.yaml
  4. Run python train.py -dd data_path -md experiment_path. This command trains the model and saves checkpoints to experiment_path
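
For orientation, here is a minimal sketch of the kind of architecture build_model could implement, assuming a TF 1.x graph from token ids to logits with a hidden bottleneck layer for the later CAV steps; the signature and layer sizes are illustrative, not the repo's actual ones:

import tensorflow as tf

def build_model(token_ids, vocab_size, emb_dim, num_labels):
    # token_ids: int32 tensor of shape [batch, time]
    embeddings = tf.get_variable('embeddings', [vocab_size, emb_dim])
    embedded = tf.nn.embedding_lookup(embeddings, token_ids)      # [batch, time, emb_dim]
    pooled = tf.reduce_mean(embedded, axis=1)                     # mean pooling over time
    hidden = tf.layers.dense(pooled, 128, activation=tf.nn.relu)  # bottleneck used for CAVs
    logits = tf.layers.dense(hidden, num_labels)
    return logits, hidden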

Create concepts

  1. Choose the words you would like to experiment with. For example, Москва, ООН, and Жириновский would be good choices.
  2. Run python collect_concepts.py -dd data_path -md experiment_path --ngrams 3. This command generates several files:
  • concepts.pkl -- for each concept (e.g. Москва), we search for sentences in which the word occurs, then extract n-grams of size n from those sentences (e.g. лето в Москва "summer in Moscow", Москва слезам не верит "Moscow does not believe in tears") and call these the concepts. We also collect random samples from the data for each concept to serve as non-concept examples.
  • cav_bottlenecks.pkl -- the concept texts converted into hidden representations by the model stored in the experiment_path folder
  • cavs.pkl -- a hyperplane for each concept, obtained by fitting a logistic regression that separates concept from non-concept examples; the LR is trained on hidden representations of the data (see the sketch after this list)
  • grads.pkl -- directional derivatives (see the paper for more details)
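
As a reference for the cavs.pkl step, here is a minimal sketch of how such a hyperplane can be fitted, assuming the concept and random activations have already been collected as arrays (fit_cav and its shapes are illustrative, not the repo's API):

import numpy as np
from sklearn.linear_model import LogisticRegression

def fit_cav(concept_acts, random_acts):
    # Stack concept (label 1) and random (label 0) hidden representations
    X = np.vstack([concept_acts, random_acts])
    y = np.concatenate([np.ones(len(concept_acts)), np.zeros(len(random_acts))])
    clf = LogisticRegression(max_iter=1000).fit(X, y)
    cav = clf.coef_[0]                # normal vector of the separating hyperplane
    return cav / np.linalg.norm(cav)  # unit-length concept activation vector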

Calculate TCAV scores

  1. Run python calculate_tcav.py -dd data_path. This command saves a scores.pkl file containing the TCAV score of each concept against every label (the score itself is sketched below).
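
Following the paper's definition, the score is the fraction of examples whose directional derivative along the CAV is positive. A minimal sketch, where grads stands in for the contents of grads.pkl and the names are illustrative:

import numpy as np

def tcav_score(grads, cav):
    # grads: [num_examples, hidden_dim] gradients of a class logit
    # with respect to the bottleneck activations; cav: unit-length concept vector
    directional_derivatives = grads @ cav
    # TCAV score = fraction of examples with a positive directional derivative
    return float(np.mean(directional_derivatives > 0))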

Plot graphs

Run TCAV.ipynb to compare the results.
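
If you prefer to plot outside the notebook, here is a minimal sketch, assuming scores.pkl holds a {concept: {label: score}} mapping (the actual structure may differ; TCAV.ipynb contains the repo's own plots):

import pickle
import matplotlib.pyplot as plt

with open('data_path/scores.pkl', 'rb') as f:
    scores = pickle.load(f)

for concept, per_label in scores.items():
    plt.plot(list(per_label.keys()), list(per_label.values()), marker='o', label=concept)
plt.ylabel('TCAV score')
plt.legend()
plt.show()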

tcav_nlp's People

Contributors

fursovia

tcav_nlp's Issues

Integrating pre-trained embeddings & Using input_fn inside collect_concepts

Hi Ivan, first of all, thank you for sharing your code.
I am starting to run a few experiments and wanted to integrate pre-trained embeddings.

So far I was thinking about adding it to the model_fn and using tf.nn.embedding_lookup.
Alternatively, one could integrate it directly into the vectorize function inside input_fn.py.

The following is the snippet I am using to load the embeddings.

import numpy as np

def loadPretrainedModel(file):
    # Expects GloVe-style text: one word per line followed by its float vector
    print("Loading Pretrained Model")
    model = {}
    with open(file, 'r') as f:
        for line in f:
            splitLine = line.split()
            word = splitLine[0]
            embedding = np.array([float(val) for val in splitLine[1:]])
            model[word] = embedding
    print("Done.", len(model), "words loaded!")
    return model

I was considering something like the following:

word_model = loadPretrainedModel(file)
variable = tf.Variable(word_model, dtype=tf.float32, trainable=False)
embeddings = tf.nn.embedding_lookup(variable, features['x'], name='emb_matrix_lookup')

Somehow this is not really working though.
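
My current guess is that tf.Variable cannot be built from a Python dict, so the vectors probably need to be packed into a matrix whose rows follow the model's vocabulary ids. A sketch of what I mean, assuming vocabs.txt (produced by data_prep.py) holds one token per line in id order:

import numpy as np

word_model = loadPretrainedModel(file)
emb_dim = len(next(iter(word_model.values())))

with open('data_path/vocabs.txt') as vf:
    vocab = [line.strip() for line in vf]

matrix = np.zeros((len(vocab), emb_dim), dtype=np.float32)
for i, word in enumerate(vocab):
    if word in word_model:
        matrix[i] = word_model[word]  # OOV tokens stay zero-initialized

variable = tf.Variable(matrix, dtype=tf.float32, trainable=False)
embeddings = tf.nn.embedding_lookup(variable, features['x'], name='emb_matrix_lookup')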

Alternatively, one could integrate the gensim library, which might also speed things up.
I would really appreciate any thoughts you have on this.

––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––

In addition to this, I wanted to ask whether you think it is possible to reuse the input_fn.py mapping to compute the gradients inside collect_concepts.py as well.

Currently the code is using

grads = mw.calculate_grad(sess, labels, train['text'].tolist())

However, this doesn't work with input sequences longer than 10, since it doesn't automatically create the sequences.
One way around it is to transform the whole dataset into sequences that respect seq_length (a rough version of this workaround is sketched below), but ideally one could reuse the same mapping from input_fn that is used for training the model. If you have any ideas on how this could work with TensorFlow's dataset mapping, that would be great.
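
For now I am using something like the following, which simply pads or truncates every example to the fixed length before calling calculate_grad (the pad token and the length of 10 are guesses from my config):

def pad_or_truncate(tokens, seq_length, pad_token='<PAD>'):
    # Force every example to exactly seq_length tokens
    return tokens[:seq_length] + [pad_token] * max(0, seq_length - len(tokens))

texts = [' '.join(pad_or_truncate(t.split(), 10)) for t in train['text'].tolist()]
grads = mw.calculate_grad(sess, labels, texts)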

In any case: thanks again for sharing, and I would appreciate any help you can provide.

labs_mapping.pkl unknown

In calculate_tcav.py (lines 17-18) you load a pickle file called labs_mapping.pkl. I do not see this file being created anywhere in the code. Could you help me understand where it comes from?
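
My guess is that it maps label names to integer ids. If that is right, something like the following could rebuild it from the prepared data (the 'label' column name is a guess on my part):

import pickle
import pandas as pd

train = pd.read_csv('data_path/train.csv')
labs_mapping = {lab: i for i, lab in enumerate(sorted(train['label'].unique()))}
with open('data_path/labs_mapping.pkl', 'wb') as f:
    pickle.dump(labs_mapping, f)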
