
Grad-CAM with PyTorch

PyTorch implementation of Grad-CAM (Gradient-weighted Class Activation Mapping) [1]. Grad-CAM localizes and highlights discriminative regions that a convolutional neural network-based model activates to predict visual concepts. This repository only supports image classification models.
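As a rough illustration of the weighting described in [1]: take the target layer's activations A (K channels of size H×W) and the gradients of the class score y^c with respect to them. Each channel weight is the spatial mean of that channel's gradient, and the map is the ReLU of the weighted channel sum. The NumPy sketch below assumes the activations and gradients have already been extracted. The function name is hypothetical and is not this repository's API.

```python
import numpy as np

def grad_cam_from_maps(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Combine feature maps and their gradients into a Grad-CAM map.

    activations, gradients: arrays of shape (K, H, W) for a single image,
    where gradients = d(score of the target class) / d(activations).
    Returns an (H, W) map scaled to [0, 1].
    """
    # alpha_k: global-average-pooled gradient of each channel
    weights = gradients.mean(axis=(1, 2))  # (K,)
    # Weighted sum over channels, keeping only positive evidence (ReLU)
    cam = np.maximum((weights[:, None, None] * activations).sum(axis=0), 0.0)
    # Normalize to [0, 1] for visualization (guard against an all-zero map)
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```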

Dependencies

  • Python 2.7+/3.6+
  • PyTorch 0.4.1+
  • torchvision 0.2.2+
  • click
  • opencv
  • tqdm

Basic usage

python main.py demo1 [OPTIONS]

Options:

  • -i, --image-paths: path to an input image; can be given multiple times (required)
  • -a, --arch: a model name from torchvision.models, e.g. "resnet152" (required)
  • -t, --target-layer: name of the module to visualize, e.g. "layer4.2" (required)
  • -k, --topk: number of top-predicted classes to visualize (default: 3)
  • -o, --output-dir: directory to store results (default: ./results)
  • --cuda/--cpu: run on GPU or CPU
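The -t value is a dotted module name. For example, "layer4.2" is the third (last) block of layer4 in torchvision's ResNet-152. The sketch below shows one way to resolve such a name to a submodule using PyTorch's named_modules(). The helper name find_layer is hypothetical, and the repository may resolve names differently.

```python
import torch.nn as nn

def find_layer(model: nn.Module, name: str) -> nn.Module:
    """Return the submodule registered under a dotted name like "layer4.2"."""
    modules = dict(model.named_modules())
    if name not in modules:
        # Show a few valid names to help pick a correct --target-layer
        candidates = [n for n in modules if n][:10]
        raise ValueError(f"Unknown layer {name!r}; e.g. {candidates}")
    return modules[name]
```

For example, find_layer(torchvision.models.resnet152(), "layer4.2") returns that block's Bottleneck module.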

For each of the top-k predicted classes, the command above generates:

  • Gradients by vanilla backpropagation
  • Gradients by guided backpropagation [2]
  • Gradients by deconvnet [2]
  • Grad-CAM [1]
  • Guided Grad-CAM [1]
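The three gradient-based maps differ only in how the gradient passes backward through each ReLU [2]. Guided Grad-CAM multiplies the guided-backpropagation map by the Grad-CAM map after upsampling it to the input resolution [1]. The NumPy sketch below implements the per-ReLU rules. In each function, x is the ReLU's forward input and g is the gradient arriving from the layer above. The function names are illustrative only.

```python
import numpy as np

def relu_backward_vanilla(x, g):
    # Standard backprop: pass the gradient where the forward input was positive
    return g * (x > 0)

def relu_backward_deconvnet(x, g):
    # Deconvnet: ignore the forward input and pass only positive gradients
    return g * (g > 0)

def relu_backward_guided(x, g):
    # Guided backprop: both conditions must hold
    return g * (x > 0) * (g > 0)

def guided_grad_cam(guided_grads, cam):
    # Element-wise product of guided gradients (C, H, W) with a CAM (H, W)
    # that has already been upsampled to the input resolution
    return guided_grads * cam[None, :, :]
```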

In this code, the guided-* methods support only nn.ReLU, not F.relu. For instance, the off-the-shelf inception_v3 cannot cut off negative gradients during the backward pass (issue #2).
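This restriction comes from how guided backpropagation is typically implemented in PyTorch. The ReLU backward pass is modified through module hooks, and a hook can only be attached to an nn.Module instance. A call to F.relu inside forward() is therefore never intercepted. The sketch below uses register_full_backward_hook, which requires PyTorch 1.8 or newer, so it may differ from the hook API the repository uses. The helper name is hypothetical.

```python
import torch
import torch.nn as nn

def add_guided_relu_hooks(model: nn.Module):
    """Make every nn.ReLU module pass only positive gradients backward.

    F.relu calls inside forward() are not modules, so they are left untouched.
    Returns the hook handles; call handle.remove() to restore normal backprop.
    """
    def clamp_negative_grads(module, grad_input, grad_output):
        # grad_input[0] is already zero where the ReLU input was <= 0;
        # clamping also zeroes negative gradients (guided backpropagation).
        return (torch.clamp(grad_input[0], min=0.0),)

    handles = []
    for module in model.modules():
        if isinstance(module, nn.ReLU):
            # Full backward hooks do not support in-place ReLU outputs.
            module.inplace = False
            handles.append(module.register_full_backward_hook(clamp_negative_grads))
    return handles
```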

demo2, demo3, and demo4 are hard-coded examples.

Examples

Demo 1

Generate all kinds of visualization maps given a torchvision model, a target layer, and images.

python main.py demo1 -a resnet152 \
                     -t layer4 \
                     -i samples/cat_dog.png

You can specify multiple images as follows:

python main.py demo1 -a resnet152 \
                     -t layer4 \
                     -i samples/cat_dog.png \
                     -i samples/vegetables.jpg
[Figure: demo1 results. Predicted classes: #1 boxer, #2 bull mastiff, #3 tiger cat. Rows: Grad-CAM [1], vanilla backpropagation, "deconvnet" [2], guided backpropagation [2], guided Grad-CAM [1].]
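Torchvision's pretrained ImageNet classifiers expect RGB input normalized with the ImageNet channel mean and standard deviation. OpenCV, however, loads images in BGR order. The NumPy sketch below performs that conversion, assuming the image has already been resized (for example, to 224×224). The repository's own preprocessing may differ in its resizing and cropping details.

```python
import numpy as np

# ImageNet statistics used by torchvision's pretrained classifiers (RGB order)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_bgr(image_bgr: np.ndarray) -> np.ndarray:
    """Convert an (H, W, 3) uint8 BGR image (as from cv2.imread) to a
    normalized (3, H, W) float32 RGB array ready for torch.from_numpy."""
    rgb = image_bgr[:, :, ::-1].astype(np.float32) / 255.0  # BGR -> RGB, [0, 1]
    normalized = (rgb - IMAGENET_MEAN) / IMAGENET_STD  # per-channel standardize
    return np.ascontiguousarray(normalized.transpose(2, 0, 1))  # HWC -> CHW
```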

Grad-CAM [1] with different models for the "bull mastiff" class, with the target layer used for each model:

  • resnet152: layer4
  • vgg19: features
  • vgg19_bn: features
  • densenet201: features
  • squeezenet1_1: features
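Only the target layer changes between architectures, so a single routine can compute Grad-CAM for any of these models given the layer name. The PyTorch sketch below assumes a model that returns class logits and is in eval mode. It captures the target layer's output with a forward hook and differentiates the class logit with torch.autograd.grad. The function name is hypothetical and is not this repository's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def grad_cam(model: nn.Module, layer_name: str, image: torch.Tensor, class_idx: int) -> torch.Tensor:
    """Return an (H, W) Grad-CAM map in [0, 1] for one (1, C, H, W) image."""
    layer = dict(model.named_modules())[layer_name]
    captured = {}

    def save_output(module, inputs, output):
        captured["act"] = output  # feature maps A^k of the target layer

    # Make the input a leaf that requires grad so the target layer's output
    # is part of the autograd graph even if the weights are frozen.
    image = image.detach().requires_grad_(True)
    handle = layer.register_forward_hook(save_output)
    try:
        logits = model(image)
    finally:
        handle.remove()

    act = captured["act"]  # (1, K, h, w)
    grads, = torch.autograd.grad(logits[0, class_idx], act)  # d y^c / d A^k
    weights = grads.mean(dim=(2, 3), keepdim=True)  # channel weights
    cam = F.relu((weights * act).sum(dim=1, keepdim=True))  # (1, 1, h, w)
    # Upsample to the input resolution and scale to [0, 1]
    cam = F.interpolate(cam, size=image.shape[-2:], mode="bilinear", align_corners=False)
    cam = cam[0, 0]
    return (cam / cam.max().clamp(min=1e-8)).detach()
```

For a torchvision ResNet, call it as grad_cam(model.eval(), "layer4", x, class_idx). For the VGG, DenseNet, and SqueezeNet models listed above, pass "features" instead.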

Demo 2

Generate Grad-CAM at different layers of ResNet-152 for the "bull mastiff" class.

python main.py demo2 -i samples/cat_dog.png
[Figure: Grad-CAM [1] at layers relu, layer1, layer2, layer3, and layer4.]

Demo 3

Generate Grad-CAM with custom (non-torchvision) models. Here we use Xception v1 from my other repo and visualize the response at the last convolution layer (see demo3() for details). If you want to adapt your own model, make sure it uses only nn.ReLU, not F.relu.

python main.py demo3 -i samples/cat_dog.png
[Figure: Grad-CAM [1] results. Predicted classes: #1 bull mastiff, #2 tiger cat, #3 boxer.]
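Before you adapt your own model, you can look for functional ReLU calls by tracing the model with torch.fx (PyTorch 1.8+) and listing the ReLU calls that are not nn.ReLU modules. This is a heuristic sketch. It works only for models that torch.fx can trace symbolically, and the helper name is hypothetical.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F

# Functional and method forms of ReLU that module hooks cannot intercept
FUNCTIONAL_RELUS = {F.relu, torch.relu, torch.relu_}
METHOD_RELUS = {"relu", "relu_"}

def find_functional_relus(model: nn.Module) -> list:
    """Return names of graph nodes that apply ReLU without an nn.ReLU module."""
    graph = torch.fx.symbolic_trace(model).graph
    hits = []
    for node in graph.nodes:
        if node.op == "call_function" and node.target in FUNCTIONAL_RELUS:
            hits.append(node.name)
        elif node.op == "call_method" and node.target in METHOD_RELUS:
            hits.append(node.name)
    return hits
```

An empty list suggests that every ReLU in the traced forward pass is an nn.ReLU module, which the guided-* methods can hook.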

Demo 4

Generate occlusion sensitivity maps [1, 3] based on logit scores. Red and blue regions indicate, respectively, a relative increase and decrease in the score compared with the non-occluded image: the blue regions are critical!

python main.py demo4 -a resnet152 -i samples/cat_dog.png
[Figure: occlusion sensitivity maps for "boxer", "bull mastiff", and "tiger cat" with patch sizes 10x10, 15x15, 25x25, 35x35, 45x45, and 90x90.]

This demo is slow because it computes logits for an occluded copy of the image at every sampled position. You can control the resolution by changing the sampling stride (--stride), or increase the batch size to fit your GPUs (--n-batches). The model is wrapped with torch.nn.DataParallel so that it runs on multiple GPUs by default.
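The procedure can be sketched in NumPy with a generic scoring function standing in for the network. This sketch makes several assumptions that may not match the repository. Patches are placed by their top-left corner without padding, and occluded pixels are filled with zero. The stride and batch_size parameters play the roles of --stride and --n-batches. All names are hypothetical.

```python
import numpy as np

def occlusion_sensitivity(image, score_fn, patch=10, stride=1, batch_size=64, fill=0.0):
    """Map of score changes when a square patch occludes each position.

    image: (C, H, W) array. score_fn: maps an (N, C, H, W) batch to (N,)
    target-class logits. Negative values mean occlusion lowered the score.
    """
    _, height, width = image.shape
    base = score_fn(image[None])[0]  # score of the non-occluded image
    ys = list(range(0, height - patch + 1, stride))
    xs = list(range(0, width - patch + 1, stride))
    positions = [(y, x) for y in ys for x in xs]
    diffs = []
    for start in range(0, len(positions), batch_size):
        batch = []
        for y, x in positions[start:start + batch_size]:
            occluded = image.copy()
            occluded[:, y:y + patch, x:x + patch] = fill  # hide one patch
            batch.append(occluded)
        diffs.append(score_fn(np.stack(batch)) - base)
    return np.concatenate(diffs).reshape(len(ys), len(xs))
```

A smaller stride gives a denser map at the cost of more forward passes. A larger batch_size reduces the number of passes but needs more GPU memory.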

References

  1. R. R. Selvaraju, A. Das, R. Vedantam, M. Cogswell, D. Parikh, and D. Batra. Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization. In ICCV, 2017.
  2. J. T. Springenberg, A. Dosovitskiy, T. Brox, and M. Riedmiller. Striving for Simplicity: The All Convolutional Net. arXiv, 2014.
  3. M. D. Zeiler and R. Fergus. Visualizing and Understanding Convolutional Networks. In ECCV, 2014.

Contributors

  • kazuto1011
