yolo-heatmaps's Introduction

yolo-heatmaps

A utility for generating explanatory heatmaps from YOLOv8 (https://github.com/ultralytics/ultralytics) object detection results using Layerwise Relevance Propagation (LRP/CRP) (https://iphome.hhi.de/samek/pdf/MonXAI19.pdf).

from ultralytics import YOLO
from lrp import YOLOv8LRP  # import path may differ depending on the repo layout

yolo = YOLO('yolov8x.pt')
detection = yolo(image)  # image is a C x H x W preprocessed tensor

...

lrp = YOLOv8LRP(yolo, power=2, eps=1e-05, device='cuda')

# Each explanation is a C x H x W tensor of per-pixel relevance
explanation_lrp_person = lrp.explain(image, cls='person', contrastive=False)
explanation_lrp_cat = lrp.explain(image, cls='cat', contrastive=False)
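To inspect a result, one possible approach (not part of the repo's API; the helper below is illustrative) is to collapse the explanation over channels and overlay it on the input image with matplotlib:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so this also runs without a display
import matplotlib.pyplot as plt
import torch

def overlay_heatmap(image, explanation, path, cmap='magma', alpha=0.5):
    """Overlay a C x H x W relevance tensor on a C x H x W image tensor and save it."""
    heat = explanation.sum(dim=0)                      # collapse channels -> H x W
    plt.imshow(image.permute(1, 2, 0).cpu().numpy())   # input image as H x W x C
    plt.imshow(heat.detach().cpu().numpy(), cmap=cmap, alpha=alpha)  # translucent heatmap
    plt.axis('off')
    plt.savefig(path, bbox_inches='tight')
    plt.close()

# e.g. overlay_heatmap(image, explanation_lrp_person, 'lrp_person.png')
```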

LRP Heatmaps

(figure: example LRP heatmaps)

CRP Heatmaps

(figure: example CRP heatmaps)

If you plan to use this repo in your research, kindly cite the following work:

@INPROCEEDINGS{9827744,
  author={Karasmanoglou, Apostolos and Antonakakis, Marios and Zervakis, Michalis},
  booktitle={2022 IEEE International Conference on Imaging Systems and Techniques (IST)}, 
  title={Heatmap-based Explanation of YOLOv5 Object Detection with Layer-wise Relevance Propagation}, 
  year={2022},
  volume={},
  number={},
  pages={1-6},
  doi={10.1109/IST55454.2022.9827744}
}


yolo-heatmaps's Issues

Kernels relevancy score

Thanks for your implementation. I was wondering if you could give a hint on where I can find or calculate a relevance value per kernel (or even per neuron)?
Thank you in advance for your support.
Best

Please add venv installation

Hi! If I try to install the requirements with a plain venv and pip, I see this:

(screenshot of the pip error)

There is a hardcoded path in the requirements file.

It would be great if you could clean up the requirements so that a venv install works.

Expected condition, x and y to be on the same device

Hi, I tried to run the example notebook; however, when I call the explain() function, it reports that in 'lrp/yolo.py' the three arguments to torch.where() are not on the same device.

RuntimeError: Expected condition, x and y to be on the same device, but condition is on cuda:0 and x and y are on cuda:0 and CPU respectively

I think the problem is that on line 359:
explanation = torch.where(outlier_mask, lrp_p, torch.tensor(0.0))
the torch.tensor(0.0) should be put on the GPU.

This is how I managed to run the code:

if self.device == "cuda":
  explanation = torch.where(outlier_mask, lrp_p, torch.tensor(0.0).cuda())
else:
  explanation = torch.where(outlier_mask, lrp_p, torch.tensor(0.0))

A similar solution could be implemented in your code as well.
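A device-agnostic variant, sketched here with stand-in tensors for outlier_mask and lrp_p, avoids the if/else branch entirely by building the fill value from lrp_p itself:

```python
import torch

# Stand-ins for the tensors in lrp/yolo.py; shapes are illustrative.
lrp_p = torch.randn(3, 4, 4)
outlier_mask = lrp_p.abs() > 1.0

# torch.zeros_like(lrp_p) inherits lrp_p's device and dtype, so the same
# line works unchanged on CPU and GPU.
explanation = torch.where(outlier_mask, lrp_p, torch.zeros_like(lrp_p))
# Equivalent: torch.where(outlier_mask, lrp_p, torch.tensor(0.0, device=lrp_p.device))
```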

Regardless, thanks for sharing the code, it is helpful.

Paper release

Hi,

this LRP implementation for YOLO is great!
I would really like to use it for my master thesis.
Will you publish a paper regarding your work, so I can cite it properly?

Gradual increase in processing time

While running the code on multiple images/videos, the average processing time seems to increase considerably over time. For example,
video 1/1 (41/71) <path>: 384x640 1 class_name, Done. (0.818s)
video 1/1 (102/282) <path>: 384x640 class_name, Done. (2.980s)
Could you give any insight into the reason?

Python lib versions

Heyo

I have been having trouble using this repo in a Colab environment. I have not been able to run the explain script because of incompatible modules; this is likely just due to the naming conventions changing between yolov5 versions.

I am also not able to run the training script. There are a lot of reasons for this, but the root cause I attribute to having the wrong requirements installed.

To help me, could you let me know the version of yolov5 you used, and the pytorch/torchvision versions?

Thanks in advance

IndexError in lrp.utils::LayerRelevance.scatter(-1).size(1) caused in lrp.common::prop_C3 using Custom YOLOv5 v6.1/6.2

Hello there,

First off, thanks for your work on this repo!

I am trying to run your explainer on a custom trained YOLOv5 under late v6.1 (pulled from master shortly ahead of v6.2 release). After some tinkering, I got the code framework to run with my setup. The showcase example set by the default parameter values for the explain.py script works just fine.

I am then calling the explain.py script with custom paths for the --weights and --source flags. All other params are still the default.

Unfortunately, I get the following error message (I am logging relevance ahead of the message):

LayerRelevance(0.0, cache=(17, 0.65351) (20, 0.2363) (23, 0.11019), contrastive=False)

Traceback (most recent call last):
  File "/nn_training/explainer/yolo-heatmaps/explain.py", line 391, in <module>
    main(opt)
  File "/nn_training/explainer/yolo-heatmaps/explain.py", line 387, in main
    run(**vars(opt))
  File "/.conda/envs/YOLOexplainer/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/nn_training/explainer/yolo-heatmaps/explain.py", line 282, in run
    get_explanation(inn_model, init, contrastive, b1, b2, int(cls),
  File "/nn_training/explainer/yolo-heatmaps/explain.py", line 54, in get_explanation
    lrp_out = inn_model.innvestigate(in_tensor=None, initializer=init)
  File "/nn_training/explainer/yolo-heatmaps/innvestigator.py", line 351, in innvestigate
    relevance = self.inverter(layer, relevance)
  File "/nn_training/explainer/yolo-heatmaps/inverter.py", line 206, in __call__
    return self.invert(layer, relevance, **kwargs)
  File "/nn_training/explainer/yolo-heatmaps/inverter.py", line 199, in invert
    return self.inv_funcs[type(layer)](self, layer, relevance, **kwargs) 
  File "/nn_training/explainer/yolo-heatmaps/lrp/common.py", line 98, in prop_C3
    c_ = msg.size(1)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

relevance.scatter() returns several cuda tensors, though they look different from the tensors for the default yolov5s.pt net you provide. I also found that relevance.scatter(which=-1) returns an empty tensor([]) object.

explain.py fails after this point, and I don't know where to even start fixing the issue.

Can you give me advice as to how to fix this? Am I missing something in the parameter flags maybe?
Do you need more info on the issue?

Thanks in advance and have a wonderful day!
Best, Timm

NotImplementedError: Relevance propagation not implemented for layer type <class 'models.common.BottleneckCSP'>

Hi @akarasman ,
I plan to reproduce your code, but when I run the sample command I get the error below. Since I am not very familiar with the YOLO model, I would like to ask you for advice.
Error:

$python3 explain.py --source=data/images/me.png --weights=yolov5s.pt --explain-class='person'
explain: weights=['yolov5s.pt'], source=data/images/me.png, imgsz=[640, 640], conf_thres=0.25, iou_thres=0.45, max_det=1000, device=, view_img=False, save_txt=False, save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, visualize=False, update=False, project=runs/explain, name=exp, exist_ok=False, line_thickness=3, hide_labels=False, hide_conf=False, half=False, dnn=False, power=1, contrastive=False, b1=1.0, b2=1.0, explain_class=person, conf=False, max_class_only=True, box_xywh=None, smooth_ks=1, box_xyxy=None, cmap=magma
YOLOv5 🚀 2022-8-1 torch 1.10.1+cu111 CUDA:0 (NVIDIA GeForce RTX 3070, 7971MiB)

Fusing layers... 
Model Summary: 232 layers, 7459581 parameters, 0 gradients
Traceback (most recent call last):
  File "/home/why/Desktop/LRP/inverter.py", line 199, in invert
    return self.inv_funcs[type(layer)](self, layer, relevance, **kwargs) 
KeyError: <class 'models.common.BottleneckCSP'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "explain.py", line 391, in <module>
    main(opt)
  File "explain.py", line 387, in main
    run(**vars(opt))
  File "/home/why/anaconda3/envs/lrp/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "explain.py", line 282, in run
    get_explanation(inn_model, init, contrastive, b1, b2, int(cls),
  File "explain.py", line 54, in get_explanation
    lrp_out = inn_model.innvestigate(in_tensor=None, initializer=init)
  File "/home/why/Desktop/LRP/innvestigator.py", line 351, in innvestigate
    relevance = self.inverter(layer, relevance)
  File "/home/why/Desktop/LRP/inverter.py", line 206, in __call__
    return self.invert(layer, relevance, **kwargs)
  File "/home/why/Desktop/LRP/inverter.py", line 201, in invert
    raise NotImplementedError(f'Relevance propagation not implemented for layer type {type(layer)}')
NotImplementedError: Relevance propagation not implemented for layer type <class 'models.common.BottleneckCSP'>

And here's the list of dependencies

$ conda list
# packages in environment at /home/why/anaconda3/envs/lrp:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main    defaults
absl-py                   2.0.0                     <pip>
asttokens                 2.4.0                     <pip>
backcall                  0.2.0                     <pip>
Bottleneck                1.3.7                     <pip>
ca-certificates           2023.08.22           h06a4308_0    defaults
cachetools                5.3.1                     <pip>
certifi                   2023.7.22                 <pip>
charset-normalizer        3.3.0                     <pip>
contourpy                 1.1.1                     <pip>
cycler                    0.12.1                    <pip>
decorator                 5.1.1                     <pip>
executing                 2.0.0                     <pip>
fonttools                 4.43.1                    <pip>
google-auth               2.23.3                    <pip>
google-auth-oauthlib      1.0.0                     <pip>
grpcio                    1.59.0                    <pip>
idna                      3.4                       <pip>
importlib-metadata        6.8.0                     <pip>
importlib-resources       6.1.0                     <pip>
ipython                   8.12.3                    <pip>
jedi                      0.19.1                    <pip>
joblib                    1.3.2                     <pip>
kiwisolver                1.4.5                     <pip>
ld_impl_linux-64          2.38                 h1181459_1    defaults
libffi                    3.3                  he6710b0_2    defaults
libgcc-ng                 9.1.0                hdf63c60_0    defaults
libstdcxx-ng              9.1.0                hdf63c60_0    defaults
Markdown                  3.5                       <pip>
MarkupSafe                2.1.3                     <pip>
matplotlib                3.7.3                     <pip>
matplotlib-inline         0.1.6                     <pip>
ncurses                   6.3                  h7f8727e_2    defaults
numpy                     1.24.4                    <pip>
oauthlib                  3.2.2                     <pip>
opencv-python             4.8.1.78                  <pip>
openssl                   1.1.1w               h7f8727e_0    defaults
packaging                 23.2                      <pip>
pandas                    2.0.3                     <pip>
parso                     0.8.3                     <pip>
pexpect                   4.8.0                     <pip>
pickleshare               0.7.5                     <pip>
Pillow                    10.1.0                    <pip>
pip                       23.2.1           py38h06a4308_0    defaults
ply                       3.11                      <pip>
prompt-toolkit            3.0.39                    <pip>
protobuf                  4.24.4                    <pip>
psutil                    5.9.6                     <pip>
ptyprocess                0.7.0                     <pip>
pure-eval                 0.2.2                     <pip>
pyasn1                    0.5.0                     <pip>
pyasn1-modules            0.3.0                     <pip>
Pygments                  2.16.1                    <pip>
pyparsing                 3.1.1                     <pip>
python                    3.8.13               h12debd9_0    defaults
python-dateutil           2.8.2                     <pip>
pytz                      2023.3.post1              <pip>
PyYAML                    6.0.1                     <pip>
readline                  8.1.2                h7f8727e_1    defaults
requests                  2.31.0                    <pip>
requests-oauthlib         1.3.1                     <pip>
rsa                       4.9                       <pip>
scikit-learn              1.3.1                     <pip>
scipy                     1.10.1                    <pip>
seaborn                   0.13.0                    <pip>
setuptools                53.0.0                    <pip>
setuptools                68.0.0           py38h06a4308_0    defaults
sip                       6.7.12                    <pip>
six                       1.16.0                    <pip>
sqlite                    3.38.2               hc218d9a_0    defaults
stack-data                0.6.3                     <pip>
tensorboard               2.14.0                    <pip>
tensorboard-data-server   0.7.1                     <pip>
thop-0.1.1                2209072238                <pip>
threadpoolctl             3.2.0                     <pip>
tk                        8.6.11               h1ccaba5_0    defaults
tomli                     2.0.1                     <pip>
torch                     1.10.1+cu111              <pip>
torchaudio                0.10.1+rocm4.1            <pip>
torchvision               0.11.2+cu111              <pip>
tqdm                      4.66.1                    <pip>
traitlets                 5.11.2                    <pip>
typing_extensions         4.8.0                     <pip>
tzdata                    2023.3                    <pip>
urllib3                   2.0.7                     <pip>
wcwidth                   0.2.8                     <pip>
Werkzeug                  3.0.0                     <pip>
wheel                     0.41.2           py38h06a4308_0    defaults
xz                        5.2.5                h7f8727e_1    defaults
zipp                      3.17.0                    <pip>
zlib                      1.2.11                        0    https://mirrors.sjtug.sjtu.edu.cn/anaconda/pkgs/free
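Given the dict-based dispatch visible in the traceback (inverter.py indexes self.inv_funcs by layer type), one conceivable workaround is to register a propagation rule for the unsupported layer type before running. The sketch below imitates that dispatch abstractly; the names (Inverter, prop_BottleneckCSP) are illustrative, not the repo's actual API:

```python
# Minimal sketch of dict-based relevance-rule dispatch, and how a missing
# rule could be registered for an unsupported layer type.
class BottleneckCSP:          # stand-in for models.common.BottleneckCSP
    pass

class Inverter:
    def __init__(self):
        self.inv_funcs = {}   # maps layer type -> relevance propagation rule

    def invert(self, layer, relevance):
        try:
            return self.inv_funcs[type(layer)](self, layer, relevance)
        except KeyError:
            raise NotImplementedError(
                f'Relevance propagation not implemented for layer type {type(layer)}')

def prop_BottleneckCSP(inverter, layer, relevance):
    # A real rule would route relevance through the block's sub-modules;
    # here it is passed through unchanged just to show the registration.
    return relevance

inverter = Inverter()
inverter.inv_funcs[BottleneckCSP] = prop_BottleneckCSP  # register before explaining
```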


Question about prop_C3 and prop_bottleneck

I have a question about your implementation of the backward relevance propagation for the C3 block and the bottleneck block.

From my understanding, all the relevance should be propagated backwards through the network; that's why in innvestigator.py you have rev_model = self.inverter.module_list[::-1]

However, that logic doesn't seem to be strictly followed in prop_C3, as you call the inverter on each block of the bottleneck in order; moreover, cv3 seems to be ignored, as indicated below:

def prop_C3(*args):

    inverter, mod, relevance = args
    msg = relevance.scatter(which=-1)

    ## invert mod.cv3 

    c_ = msg.size(1)

    msg_cv1 = msg[:, : (c_ // 2), ...]
    msg_cv2 = msg[:, (c_ // 2) :, ...]

    for m1 in mod.m: <--- RIGHT HERE 
        msg_cv1 = inverter(m1, msg_cv1)
    
    msg = inverter(mod.cv1, msg_cv1) + inverter(mod.cv2, msg_cv2)

    relevance.gather([(-1, msg)])

    return relevance

The same is the case in prop_Bottleneck, where you invert cv1, then cv2. Also, how do you deal with the addition, given that the forward function is:

y = x + conv2( conv1 ( x ))

def prop_Bottleneck(*args):

    inverter, mod, relevance_in = args

    ar = mod.cv2.conv.out_tensor.abs()
    ax = mod.cv1.conv.in_tensor.abs()

    relevance = relevance_in
    relevance = inverter(mod.cv1, relevance)
    relevance = inverter(mod.cv2, relevance)

    return relevance

My question is: why is this the case? Is it taken care of in another part of the code?
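For reference, one common way LRP handles a residual addition y = x + f(x) is to split the incoming relevance between the two summands in proportion to their absolute contributions, which may be what the (currently unused) ar/ax tensors in prop_Bottleneck were intended for. A sketch, not the repo's actual rule:

```python
import torch

def split_residual_relevance(relevance_y, x, fx, eps=1e-9):
    """Split relevance at y = x + f(x) proportionally to |x| and |f(x)|.

    relevance_y: relevance arriving at the sum's output
    x:  input to the block (identity / skip branch)
    fx: output of the conv path, i.e. conv2(conv1(x))
    """
    ax, ar = x.abs(), fx.abs()
    total = ax + ar + eps               # eps guards against division by zero
    r_skip = relevance_y * ax / total   # relevance routed to the skip branch
    r_conv = relevance_y * ar / total   # relevance routed through the conv path
    return r_skip, r_conv

x = torch.randn(1, 8, 4, 4)
fx = torch.randn(1, 8, 4, 4)
r_y = torch.randn(1, 8, 4, 4)
r_skip, r_conv = split_residual_relevance(r_y, x, fx)
# Conservation: the two branches sum back (up to eps) to the incoming relevance.
```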
