m-nauta / prototree

ProtoTrees: Neural Prototype Trees for Interpretable Fine-grained Image Recognition, published at CVPR2021

License: MIT License

Python 100.00%
pytorch computer-vision explainable-ai interpretability interpretable-machine-learning deep-neural-networks explainable-ml fine-grained-visual-categorization explainability interpretable-deep-learning

prototree's Introduction

ProtoTrees: Neural Prototype Trees for Interpretable Fine-grained Image Recognition

This repository presents the PyTorch code for Neural Prototype Trees (ProtoTrees), published at CVPR 2021: "Neural Prototype Trees for Interpretable Fine-grained Image Recognition".

A ProtoTree is an intrinsically interpretable deep learning method for fine-grained image recognition. It includes prototypes in an interpretable decision tree to faithfully visualize the entire model. Each internal node in our binary tree contains a trainable prototypical part, and the leaves contain class distributions. The presence or absence of a node's prototype in an image determines the routing through that node. Decision making is therefore similar to human reasoning: Does the bird have a red throat? And an elongated beak? Then it's a hummingbird!
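The routing described above can be sketched as a soft decision tree. This is a minimal NumPy illustration, not the repository's implementation: as we read the paper, we assume the probability of taking the "present" (right) branch is exp(-d), where d is the distance between the node's prototype and its closest latent patch, and that each leaf holds a class distribution. The function names and the fixed depth of 2 are hypothetical.

```python
import numpy as np

def routing_prob(latent_patches, prototype):
    """Probability of taking the 'present' (right) branch at one node.

    latent_patches: (num_patches, D) patch embeddings of one image.
    prototype: (D,) prototype vector of the node.
    Assumption: similarity = exp(-min Euclidean distance to any patch).
    """
    dists = np.linalg.norm(latent_patches - prototype, axis=1)
    return float(np.exp(-dists.min()))

def predict_depth2(latent_patches, prototypes, leaf_dists):
    """Soft prediction of a hypothetical depth-2 tree (3 internal nodes, 4 leaves).

    prototypes: [root, left_child, right_child] prototype vectors.
    leaf_dists: (4, C) class distributions, leaves ordered left to right.
    Each leaf is weighted by the product of routing probabilities on its path.
    """
    p_root = routing_prob(latent_patches, prototypes[0])
    p_left = routing_prob(latent_patches, prototypes[1])
    p_right = routing_prob(latent_patches, prototypes[2])
    path_probs = np.array([
        (1 - p_root) * (1 - p_left),  # root absent, left child absent
        (1 - p_root) * p_left,        # root absent, left child present
        p_root * (1 - p_right),       # root present, right child absent
        p_root * p_right,             # root present, right child present
    ])
    return path_probs @ leaf_dists  # (C,) mixture of leaf distributions
```

Because the four path probabilities sum to 1, the output is itself a class distribution; a prototype that exactly matches some patch gives a routing probability of 1 at that node.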

Figure: Example of a ProtoTree. A ProtoTree is a globally interpretable model that faithfully explains its entire behaviour (left, partially shown); additionally, the reasoning process for a single prediction can be followed (right): the presence of a red chest and a black wing, and the absence of a black stripe near the eye, identifies a Scarlet Tanager.
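The single-prediction reasoning in the figure can be mimicked with a deterministic walk down the tree: at each node, move to the "present" child if the prototype is judged present, otherwise to the "absent" child. The toy tree below is hypothetical and hand-built to mirror the figure's Scarlet Tanager path; in a real ProtoTree the presence decision comes from the learned similarity (here we assume thresholding the routing probability at 0.5), not from hand-set values.

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class Node:
    prototype: str                  # human-readable label of the prototype
    absent: Union["Node", str]      # subtree, or class label, if prototype absent
    present: Union["Node", str]     # subtree, or class label, if prototype present

def explain(node, similarity, threshold=0.5):
    """Follow one path greedily; return (predicted class, reasoning steps).

    similarity: dict mapping prototype label -> routing probability in [0, 1].
    """
    steps = []
    while isinstance(node, Node):
        is_present = similarity[node.prototype] > threshold
        steps.append(f"{node.prototype}: {'present' if is_present else 'absent'}")
        node = node.present if is_present else node.absent
    return node, steps

# Hypothetical toy tree mirroring the figure's path (not a trained model).
tree = Node("red chest",
            absent="other species",
            present=Node("black wing",
                         absent="other species",
                         present=Node("black stripe near eye",
                                      absent="Scarlet Tanager",
                                      present="other species")))

label, steps = explain(tree, {"red chest": 0.9, "black wing": 0.8,
                              "black stripe near eye": 0.1})
print(label)  # Scarlet Tanager
print(steps)
```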

Prerequisites

General

  • Python 3
  • PyTorch >= 1.5 and <= 1.7!
  • Optional: CUDA
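A quick way to check the PyTorch requirement is sketched below. It is a hypothetical helper, not part of the repository; we assume "<= 1.7" means any 1.7.x patch release is acceptable, so only major.minor is compared.

```python
def parse_version(version):
    """Turn a string like '1.7.1+cu110' into (1, 7, 1); drops local/pre-release suffixes."""
    core = version.split("+")[0]
    parts = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)

def torch_version_supported(version):
    """True if major.minor lies in the README's range 1.5 <= v <= 1.7 (assumed inclusive of 1.7.x)."""
    major, minor, _ = parse_version(version)
    return (1, 5) <= (major, minor) <= (1, 7)

if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        ok = torch_version_supported(torch.__version__)
        print(torch.__version__, "OK" if ok else "outside 1.5-1.7")
        print("CUDA available:", torch.cuda.is_available())
```

Versions are compared as integer tuples rather than strings, so that "1.10" is correctly treated as newer than "1.7".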

Required Python Packages:

  • numpy
  • pandas
  • opencv
  • tqdm
  • scipy
  • matplotlib
  • requests (to download the CARS dataset, or download it manually)
  • gdown (to download the CUB dataset, or download it manually)
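The package list above can be checked before running anything. The sketch below is a hypothetical helper that only looks up import names; note that opencv is imported as cv2 (and is usually installed with pip as opencv-python), while requests and gdown are only needed for the automatic downloads.

```python
import importlib.util

# Package name (as listed in the README) -> module name used in imports.
REQUIRED = {
    "numpy": "numpy",
    "pandas": "pandas",
    "opencv": "cv2",
    "tqdm": "tqdm",
    "scipy": "scipy",
    "matplotlib": "matplotlib",
    "requests": "requests",  # only needed to download CARS automatically
    "gdown": "gdown",        # only needed to download CUB automatically
}

def missing_packages(required=REQUIRED):
    """Return the package names whose import module cannot be found."""
    return [pkg for pkg, module in required.items()
            if importlib.util.find_spec(module) is None]

if __name__ == "__main__":
    missing = missing_packages()
    print("Missing:", ", ".join(missing) if missing else "none")
```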

Data

The code can be applied to the CUB-200-2011 dataset with 200 bird species, or the Stanford Cars dataset with 196 car types.

The folder preprocess_data contains Python code to download, extract and preprocess these datasets.

Preprocessing CUB

  1. create a folder ./data/CUB_200_2011
  2. download ResNet50 pretrained on iNaturalist2017 (Filename on Google Drive: BBN.iNaturalist2017.res50.180epoch.best_model.pth) and place it in the folder features/state_dicts.
  3. from the main ProtoTree folder, run python preprocess_data/download_birds.py
  4. from the main ProtoTree folder, run python preprocess_data/cub.py to create training and test sets
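The CUB steps above can be scripted roughly as follows. This is a hypothetical helper, not part of the repository: step 2 (downloading the iNaturalist weights from Google Drive) stays manual, so the sketch only warns if the file is missing, and with dry_run=True it has no side effects and just returns the commands it would run.

```python
import subprocess
import sys
from pathlib import Path

# Filename given in the README; the Google Drive link itself is not reproduced here.
WEIGHTS = Path("features/state_dicts/BBN.iNaturalist2017.res50.180epoch.best_model.pth")

def prepare_cub(root=".", dry_run=True):
    """Follow the CUB preprocessing steps from the main ProtoTree folder.

    Returns the commands that are (or would be) run, in order.
    """
    root = Path(root)
    commands = [
        [sys.executable, "preprocess_data/download_birds.py"],  # step 3
        [sys.executable, "preprocess_data/cub.py"],             # step 4
    ]
    if not (root / WEIGHTS).exists():  # step 2 is manual
        print(f"Place the pretrained weights at {root / WEIGHTS} before training")
    if not dry_run:
        (root / "data" / "CUB_200_2011").mkdir(parents=True, exist_ok=True)  # step 1
        for cmd in commands:
            subprocess.run(cmd, cwd=root, check=True)
    return commands
```

For example, prepare_cub(dry_run=False) run from the main ProtoTree folder creates the data folder and runs both scripts, stopping at the first one that fails.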

Preprocessing CARS

  1. create a folder ./data/cars
  2. from the main ProtoTree folder, run python preprocess_data/download_cars.py
  3. from the main ProtoTree folder, run python preprocess_data/cars.py to create training and test sets
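After preprocessing, a simple sanity check is to count the class folders per split. This sketch assumes one sub-folder per car type, as suggested by the test path in the local explanation example below (./data/cars/dataset/test/Dodge_Sprinter_Cargo_Van_2009/...); the name of the training folder ("train") is an assumption.

```python
from pathlib import Path

def count_classes(split_dir):
    """Count class sub-folders in a split directory (0 if it does not exist)."""
    split_dir = Path(split_dir)
    if not split_dir.is_dir():
        return 0
    return sum(1 for p in split_dir.iterdir() if p.is_dir())

def check_cars(root="./data/cars/dataset", expected=196):
    """Report, per split, whether it contains the 196 Stanford Cars classes.

    'test' matches the path used in the README; 'train' is an assumed name.
    """
    return {split: count_classes(Path(root) / split) == expected
            for split in ("train", "test")}

if __name__ == "__main__":
    print(check_cars())
```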

Training a ProtoTree

  1. create a folder ./runs

A ProtoTree can be trained by running main_tree.py with arguments. An example for CUB:

python main_tree.py --epochs 100 --log_dir ./runs/protoree_cub --dataset CUB-200-2011 --lr 0.001 --lr_block 0.001 --lr_net 1e-5 --num_features 256 --depth 9 --net resnet50_inat --freeze_epochs 30 --milestones 60,70,80,90,100

To speed up the training process, you can increase the number of workers of the DataLoaders by setting num_workers to a positive integer value (the suitable number depends on your available memory).
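To run variations of the CUB example (for example another backbone or more epochs) without retyping the whole command, the arguments can be kept in a dictionary. This is a hypothetical helper that uses only the flags shown in the README; num_workers is not passed because the README does not describe it as a command-line flag.

```python
import sys

# Hyperparameters from the README's CUB example.
CUB_ARGS = {
    "epochs": 100,
    "log_dir": "./runs/protoree_cub",
    "dataset": "CUB-200-2011",
    "lr": 0.001,
    "lr_block": 0.001,
    "lr_net": 1e-5,
    "num_features": 256,
    "depth": 9,
    "net": "resnet50_inat",
    "freeze_epochs": 30,
    "milestones": "60,70,80,90,100",
}

def build_command(script="main_tree.py", args=CUB_ARGS, **overrides):
    """Build the argv list for a ProtoTree run; keyword overrides replace defaults."""
    merged = {**args, **overrides}  # overrides keep the original key order
    cmd = [sys.executable, script]
    for key, value in merged.items():
        cmd += [f"--{key}", str(value)]
    return cmd
```

Usage: subprocess.run(build_command(epochs=150, net="vgg16"), check=True). Note that str(1e-5) renders as "1e-05", which argparse parses to the same float.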

Check your --log_dir to keep track of the training progress. This directory contains log_epoch_overview.csv, which records per epoch the test accuracy, the mean training accuracy and the mean loss. The file log_train_epochs_losses.csv records the loss value and training accuracy per batch iteration. The file log.txt logs additional info.
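To find the epoch with the highest test accuracy in log_epoch_overview.csv, a small reader like the one below can help. The exact header of the file is not documented here, so this sketch assumes one header row followed by rows whose first two columns are the epoch and the test accuracy, in the order the README lists them.

```python
import csv

def best_epoch(lines):
    """Return (epoch, test_acc) for the row with the highest test accuracy.

    lines: iterable of CSV lines. Assumes one header row, then rows whose
    first two columns are epoch and test accuracy (column order is an assumption).
    """
    rows = [r for r in csv.reader(lines) if r][1:]  # drop empty lines and header
    best = max(rows, key=lambda r: float(r[1]))
    return best[0], float(best[1])

def best_epoch_from_file(path):
    """Convenience wrapper, e.g. for ./runs/protoree_cub/log_epoch_overview.csv."""
    with open(path, newline="") as f:
        return best_epoch(f)
```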

The resulting visualized ProtoTree (i.e. the global explanation) is saved as a PDF at pruned_and_projected/treevis.pdf inside your --log_dir. NOTE: this PDF can get very large, which Adobe Acrobat Reader does not support. Open it with, e.g., Google Chrome or Apple Preview.

To train and evaluate an ensemble of ProtoTrees, run main_ensemble.py with the same arguments as for main_tree.py, and add --nr_trees_ensemble to set the number of trees in the ensemble.
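How main_ensemble.py combines the trees is not described in this README. A common choice, assumed in the sketch below, is to average the class distributions predicted by each tree and take the most probable class; the function name is hypothetical.

```python
import numpy as np

def ensemble_predict(tree_probs):
    """Average class distributions of several trees and pick the top class.

    tree_probs: (num_trees, num_samples, num_classes) per-tree predicted
    class distributions. Averaging is an assumed combination rule.
    """
    mean_probs = np.mean(tree_probs, axis=0)      # (num_samples, num_classes)
    return mean_probs, mean_probs.argmax(axis=1)  # distributions, predicted labels

probs = np.array([[[0.6, 0.4]],
                  [[0.1, 0.9]],
                  [[0.7, 0.3]]])  # 3 trees, 1 sample, 2 classes
mean_probs, labels = ensemble_predict(probs)
print(labels)  # [1]
```

Averaging and majority voting can disagree: in this example two of the three trees prefer class 0, but the confident second tree pulls the average towards class 1.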

Local explanations

A trained ProtoTree is intrinsically interpretable and globally explainable. It can also locally explain a prediction. For example, run the following command to explain a single test image:

main_explain_local.py --log_dir ./runs/protoree_cars --dataset CARS --sample_dir ./data/cars/dataset/test/Dodge_Sprinter_Cargo_Van_2009/04003.jpg --prototree ./runs/protoree_cars/checkpoints/pruned_and_projected

The visualized local explanation is saved as predvis.pdf in the local_explanations folder inside your --log_dir.
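The command above explains one image per call. To explain every image in a test class folder, one option is to call the script once per image, as in this hypothetical wrapper that uses only the flags shown in the README. Check whether repeated runs write to separate subfolders; the README only states that predvis.pdf ends up under local_explanations.

```python
import subprocess
import sys
from pathlib import Path

def explain_commands(image_paths, log_dir, dataset, prototree):
    """Build one main_explain_local.py command per test image."""
    return [
        [sys.executable, "main_explain_local.py",
         "--log_dir", log_dir,
         "--dataset", dataset,
         "--sample_dir", str(path),
         "--prototree", prototree]
        for path in image_paths
    ]

def explain_folder(folder, log_dir, dataset, prototree, pattern="*.jpg"):
    """Run a local explanation for every matching image; return the image count."""
    images = sorted(Path(folder).glob(pattern))
    for cmd in explain_commands(images, log_dir, dataset, prototree):
        subprocess.run(cmd, check=True)
    return len(images)
```

Usage: explain_folder("./data/cars/dataset/test/Dodge_Sprinter_Cargo_Van_2009", "./runs/protoree_cars", "CARS", "./runs/protoree_cars/checkpoints/pruned_and_projected").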


prototree's Issues

Model cannot be loaded

First of all, thanks a lot for your interesting work! I was able to reproduce the training process successfully. However, there seems to be a problem loading the model. When I use a new file to load the model (using tree.load() or tree.load_state_dict()), the loaded tree performs poorly on the original training dataset. Did I do something wrong? Could you help me fix this problem?

[Screenshots: load_code, result]

Upsample issue: multiple bounding boxes

Hi Meike,

Since I am not a collaborator on this repository, I am unable to push any changes. The problem for upsample.py where too large bounding boxes are drawn for a prototype image can be fixed by replacing lines 56-58 with the following code:
# save the highly activated patch
masked_similarity_map = np.zeros(similarity_map.shape)
prototype_index = prototype_info['patch_ix']
masked_similarity_map[prototype_index // 7, prototype_index % 7] = 1

Instead of finding the maximum value of the latent heatmap, the project_info is used for finding the location of the patch/prototype. The problem with the previous code was that the heatmap can contain multiple maxima.
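The difference can be seen on a toy 7x7 similarity map with two tied maxima. We assume the previous code marked every cell equal to the maximum (the issue only says it searched for the maximum value); a bounding box drawn around all marked cells would then span both regions. The helper names are hypothetical, and np.unravel_index generalizes the hard-coded // 7 and % 7 to any map shape.

```python
import numpy as np

def mask_from_max(similarity_map):
    """Assumed old behaviour: mark every cell that attains the maximum."""
    return (similarity_map == similarity_map.max()).astype(float)

def mask_from_patch_index(similarity_map, patch_ix):
    """Proposed behaviour: mark only the patch recorded during projection.

    np.unravel_index maps a flat index to (row, col) for any map shape,
    which equals (patch_ix // 7, patch_ix % 7) on a 7x7 map.
    """
    mask = np.zeros(similarity_map.shape)
    mask[np.unravel_index(patch_ix, similarity_map.shape)] = 1
    return mask

sim = np.zeros((7, 7))
sim[1, 2] = sim[5, 6] = 0.9  # two tied maxima far apart
print(mask_from_max(sim).sum())                      # 2.0: two regions highlighted
print(mask_from_patch_index(sim, 1 * 7 + 2).sum())   # 1.0: only the projected patch
```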

Kind regards,
Guido

Model inference question

Hi, thanks for this wonderful work. I was wondering if you could provide a tutorial on how to use a trained model for inference with model.pth, model_state.pth and tree.pkl.

I would greatly appreciate the help.

Projection of prototypes

Hi, thanks for the great repo; it was really easy to reproduce.

I just have a general question about the projection step: you seem to use the bounding boxes for projection, right? That is, you crop the images to just the box around each bird and use those as the data for the loader used for projection.

Why is that? Does it not work as well with the full images, or is there a technical reason? Thanks in advance.
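For context on what cropping affects: in ProtoPNet-style models, which we assume applies to the projection step here, each prototype is replaced by its nearest latent patch from the projection images, so cropping to bounding boxes changes which patches are candidates. The generic NumPy sketch below is not the repository's code; it computes all prototype-patch distances at once, which is fine for small inputs but memory-hungry for a full dataset.

```python
import numpy as np

def project_prototypes(prototypes, latent_maps):
    """Replace each prototype with its nearest latent patch.

    prototypes: (P, D) array.
    latent_maps: (N, H, W, D) latent feature maps of the projection images.
    Returns projected prototypes (P, D) and, per prototype, the
    (image index, flat patch index) it was projected onto.
    """
    n, h, w, d = latent_maps.shape
    patches = latent_maps.reshape(n * h * w, d)
    # Squared Euclidean distance between every prototype and every patch.
    dists = ((prototypes[:, None, :] - patches[None, :, :]) ** 2).sum(axis=2)
    nearest = dists.argmin(axis=1)  # (P,) indices into the flattened patches
    sources = [(int(i) // (h * w), int(i) % (h * w)) for i in nearest]
    return patches[nearest], sources  # fancy indexing returns a copy
```

The flat patch index row * W + col plays the same role as the patch_ix value discussed in the upsample issue above.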

Training time

First of all, congrats on the very interesting paper!

I was wondering what architecture you used to train the model, and how much time it took. I'm trying to train on Colab with a GPU and it's taking ~40s per iteration (or ~5h per epoch). Is that the expected training time?

My final acc is only 76%

```
Eval Epoch 95: 100% 91/91 [01:49<00:00,  1.21s/it, Batch [91/91], Acc: 0.559]
Train Epoch 96: 100% 469/469 [05:52<00:00,  1.33it/s, Batch [469/469], Loss: 0.3
Eval Epoch 96: 100% 91/91 [01:51<00:00,  1.22s/it, Batch [91/91], Acc: 0.618]
Train Epoch 97: 100% 469/469 [05:55<00:00,  1.32it/s, Batch [469/469], Loss: 0.2
Eval Epoch 97: 100% 91/91 [01:47<00:00,  1.18s/it, Batch [91/91], Acc: 0.529]
Train Epoch 98: 100% 469/469 [05:52<00:00,  1.33it/s, Batch [469/469], Loss: 0.3
Eval Epoch 98: 100% 91/91 [01:48<00:00,  1.19s/it, Batch [91/91], Acc: 0.588]
Train Epoch 99: 100% 469/469 [05:56<00:00,  1.32it/s, Batch [469/469], Loss: 0.4
Eval Epoch 99: 100% 91/91 [01:46<00:00,  1.17s/it, Batch [91/91], Acc: 0.618]
Train Epoch 100: 100% 469/469 [05:53<00:00,  1.33it/s, Batch [469/469], Loss: 0.
Eval Epoch 100: 100% 91/91 [01:47<00:00,  1.18s/it, Batch [91/91], Acc: 0.647]
Eval Epoch pruned: 100% 91/91 [01:39<00:00,  1.10s/it, Batch [91/91], Acc: 0.647
Projection: 100% 375/375 [02:51<00:00,  2.19it/s, Batch: 375/375]
Eval Epoch pruned_and_projected: 100% 91/91 [01:36<00:00,  1.06s/it, Batch [91/9
Eval Epoch pruned_and_projected: 100% 91/91 [01:33<00:00,  1.03s/it, Batch [91/9
Eval Epoch pruned_and_projected: 100% 91/91 [02:16<00:00,  1.50s/it, Batch [91/9
Fidelity: 100% 91/91 [02:47<00:00,  1.84s/it, Batch [91/91]]
```
In the overview table:

| Epoch | Test acc. | Mean train acc. | Mean loss   |
| ----- | --------- | --------------- | ----------- |
| 85    | 0.764411  | 0.927753346     | 0.315042642 |
| 86    | 0.766483  | 0.929271055     | 0.312246734 |
| 87    | 0.768208  | 0.929870736     | 0.310522149 |
| 88    | 0.77028   | 0.928586235     | 0.310088791 |
| 89    | 0.76631   | 0.929519071     | 0.306573199 |
| 90    | 0.763721  | 0.929437633     | 0.307344964 |
| 91    | 0.764066  | 0.929333985     | 0.301229446 |
| 92    | 0.765447  | 0.929681947     | 0.299217476 |
| 93    | 0.765274  | 0.930585169     | 0.298738088 |
| 94    | 0.768381  | 0.929467247     | 0.298132296 |
| 95    | 0.765965  | 0.929104478     | 0.296845926 |
| 96    | 0.76631   | 0.929918858     | 0.295910276 |
| 97    | 0.764066  | 0.930070629     | 0.295259658 |
| 98    | 0.763548  | 0.92906746      | 0.296624792 |
| 99    | 0.764411  | 0.929082267     | 0.296345691 |
| 100   | 0.761823  | 0.929915156     | 0.294409203 |

Accuracy of ProtoTree ensemble 5

Thanks for your excellent work!
When I reproduce the ProtoTree ensemble result with 5 trees, my accuracy is 83%, lower than the 87.2% reported in the paper. Here is my run command:

python main_ensemble.py --epochs 100 --log_dir ./runs/protoree_cub --dataset CUB-200-2011 --lr 0.001 --lr_block 0.001 --lr_net 1e-5 --num_features 256 --depth 9 --net resnet50_inat --freeze_epochs 30 --milestones 60,70,80,90,100 --nr_trees_ensemble 5

How can I solve that?

Using ProtoTree for multi-label classification

Thank you for your interesting work.
I was wondering if this work can be applied in a multi-label classification setting, say in the domain of medical image analysis.
If not, could you let me know what shortcomings prevent it from being applied there?

Connecting to Swin Transformer

Hi, thanks for this wonderful work. I was wondering if the author could provide code for using Swin Transformer as the feature extractor. Thanks, I deeply appreciate the support.

Is fine-tuning available?

Hi, thanks for this awesome paper. I was just wondering whether fine-tuning a Neural ProtoTree is possible?

Acc lower than paper claims

Hi @M-Nauta, thank you for your very interesting paper and this very well-written repo. However, when I follow the README to reproduce the results on CUB_200_2011, I got 72% and 78% accuracy in two runs. Did I miss anything? How should I fix this?

Accuracy of non-iNat Networks

Hi,

I was able to recreate the accuracy of ResNet50 pre-trained on iNat using the suggestions I found in the closed issues (i.e., around 82%), but when I substitute another network (say VGG-16), I get an accuracy of around 11%. Similarly, ResNet50 pre-trained on ImageNet gets around 62%, and ResNet18 around 30%.

Is that normal, or are there other hyperparameters that need to be changed to boost accuracy on other networks? Thank you if you have time to offer suggestions.

For example, I might use this command:

python main_tree.py --epochs 150 --log_dir ./runs/protoree_cub --dataset CUB-200-2011 --lr 0.001 --lr_block 0.001 --lr_net 1e-5 --num_features 256 --depth 9 --net vgg16 --freeze_epochs 30 --milestones 60,80,100,120,140
