
cnnvis's Introduction


Welcome

CNNVis is a high-level convolutional neural network (CNN) visualization API built on top of Keras. The intention behind this project aligns with the intention of Keras: "Being able to go from idea to result with the least possible delay is key to doing good research".

Use CNNVis to print model summaries and to visualize the following aspects of a CNN: saliency maps, feature maps, mean activations, max activations, and kernels. Your CNN model must be a keras.models.Sequential or keras.models.Model instance.

Feel free to email me at [email protected] about any additional features that you would like to visualize!

The main resources that helped the development of this project tremendously:

  • Chapter 5 (computer vision) of the book Deep Learning with Python by François Chollet
  • The official Keras documentation: https://keras.io/

Important Note: This library only supports Keras with the TensorFlow backend.

Getting Started: 30 seconds to CNNVis

First, make sure that all dependencies are installed:

  • prettytable
  • NumPy
  • matplotlib
  • PIL (provided by the Pillow package)
  • Keras
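
A quick way to confirm that the packages above are importable is sketched below. The helper name missing_dependencies is hypothetical and not part of CNNVis; it only uses the standard library and the usual import names of the listed packages.

```python
import importlib.util

# Import names for the dependencies listed above (PIL is provided by Pillow).
REQUIRED_MODULES = ["prettytable", "numpy", "matplotlib", "PIL", "keras"]


def missing_dependencies(module_names=REQUIRED_MODULES):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in module_names
            if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_dependencies()
    if missing:
        print("Missing dependencies:", ", ".join(missing))
    else:
        print("All dependencies found.")
```

Anything reported as missing can then be installed with your package manager of choice.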

First Step: Instantiate a Visualizer instance

To instantiate a Visualizer instance for a VGG16 network:

import keras
from cnnvis import Visualizer

vgg16_model = keras.applications.VGG16(weights='imagenet', include_top=True)
visualizer = Visualizer(model=vgg16_model, image_shape=(224, 224, 3), batch_size=1, preprocess_style='vgg16')

Print summary

To print the default summary:

visualizer.summary(style='default')

To print the "cnn style" summary:

visualizer.summary(style='cnn')
CNN Style Model Summary
+--------------+--------------+------------+-------------+----------------+---------------------+
|  Layer Name  |  Layer Type  | Kernel Num | Kernel Size | Kernel Padding |     Output Shape    |
+--------------+--------------+------------+-------------+----------------+---------------------+
| block1_conv1 |    Conv2D    |     64     |    (3, 3)   |      same      |  (1, 224, 224, 64)  |
| block1_conv2 |    Conv2D    |     64     |    (3, 3)   |      same      |  (1, 224, 224, 64)  |
| block1_pool  | MaxPooling2D |     /      |    (2, 2)   |       /        |   (1, 112, 112, 3)  |
| block2_conv1 |    Conv2D    |    128     |    (3, 3)   |      same      |  (1, 224, 224, 128) |
| block2_conv2 |    Conv2D    |    128     |    (3, 3)   |      same      |  (1, 224, 224, 128) |
| block2_pool  | MaxPooling2D |     /      |    (2, 2)   |       /        |   (1, 112, 112, 3)  |
| block3_conv1 |    Conv2D    |    256     |    (3, 3)   |      same      |  (1, 224, 224, 256) |
| block3_conv2 |    Conv2D    |    256     |    (3, 3)   |      same      |  (1, 224, 224, 256) |
| block3_conv3 |    Conv2D    |    256     |    (3, 3)   |      same      |  (1, 224, 224, 256) |
| block3_pool  | MaxPooling2D |     /      |    (2, 2)   |       /        |   (1, 112, 112, 3)  |
| block4_conv1 |    Conv2D    |    512     |    (3, 3)   |      same      |  (1, 224, 224, 512) |
| block4_conv2 |    Conv2D    |    512     |    (3, 3)   |      same      |  (1, 224, 224, 512) |
| block4_conv3 |    Conv2D    |    512     |    (3, 3)   |      same      |  (1, 224, 224, 512) |
| block4_pool  | MaxPooling2D |     /      |    (2, 2)   |       /        |   (1, 112, 112, 3)  |
| block5_conv1 |    Conv2D    |    512     |    (3, 3)   |      same      |  (1, 224, 224, 512) |
| block5_conv2 |    Conv2D    |    512     |    (3, 3)   |      same      |  (1, 224, 224, 512) |
| block5_conv3 |    Conv2D    |    512     |    (3, 3)   |      same      |  (1, 224, 224, 512) |
| block5_pool  | MaxPooling2D |     /      |    (2, 2)   |       /        |   (1, 112, 112, 3)  |
|   flatten    |   Flatten    |     /      |      /      |       /        |     (1, 150528)     |
|     fc1      |    Dense     |     /      |      /      |       /        | (1, 224, 224, 4096) |
|     fc2      |    Dense     |     /      |      /      |       /        | (1, 224, 224, 4096) |
| predictions  |    Dense     |     /      |      /      |       /        | (1, 224, 224, 1000) |
+--------------+--------------+------------+-------------+----------------+---------------------+
Training set: 1000 Categories of ImageNet
Number of Conv2D layers: 13
Number of MaxPooling2D layers: 5
Number of Dense layers: 3
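
For cross-checking the Output Shape column of a summary, the shapes implied by the standard VGG16 architecture can be derived by simple arithmetic. The sketch below is independent of CNNVis (the function name vgg16_output_shapes is hypothetical) and assumes 3x3 'same' convolutions, which keep the spatial size, and 2x2 max pooling with stride 2, which halves it.

```python
def vgg16_output_shapes(height=224, width=224, batch_size=1):
    """Derive layer output shapes for the standard VGG16 layout by arithmetic."""
    # (number of conv layers, number of kernels) for each of the five blocks
    blocks = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]
    shapes = []
    channels = 3  # RGB input
    for block_idx, (num_convs, num_kernels) in enumerate(blocks, start=1):
        for conv_idx in range(1, num_convs + 1):
            # 'same' padding keeps height and width; channels become num_kernels
            channels = num_kernels
            shapes.append((f"block{block_idx}_conv{conv_idx}",
                           (batch_size, height, width, channels)))
        # 2x2 pooling with stride 2 halves height and width
        height, width = height // 2, width // 2
        shapes.append((f"block{block_idx}_pool",
                       (batch_size, height, width, channels)))
    shapes.append(("flatten", (batch_size, height * width * channels)))
    for name, units in [("fc1", 4096), ("fc2", 4096), ("predictions", 1000)]:
        shapes.append((name, (batch_size, units)))
    return shapes


if __name__ == "__main__":
    for layer_name, shape in vgg16_output_shapes():
        print(f"{layer_name:>12}  {shape}")
```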

Plot saliency map

To plot saliency maps:

from matplotlib import pyplot as plt

# Compute one saliency map per input image
img_paths = ['fish.jpg', 'bird.jpg', 'elephants.jpg']
saliency_maps = visualizer.get_saliency_map(img_paths)

fig = plt.figure()
fig.add_subplot(1, 3, 1)
plt.axis('off')
plt.imshow(saliency_maps[0])

fig.add_subplot(1, 3, 2)
plt.axis('off')
plt.imshow(saliency_maps[1])

fig.add_subplot(1, 3, 3)
plt.axis('off')
plt.imshow(saliency_maps[2])

plt.show()
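
For more than a handful of images, the repeated subplot calls above can be folded into a loop. The following is a minimal sketch; grid_shape and show_images are hypothetical helpers, not part of CNNVis, and they assume each map is an array that plt.imshow can display.

```python
import math


def grid_shape(num_images, max_cols=3):
    """Return (rows, cols) for a grid holding num_images with at most max_cols columns."""
    cols = max(1, min(num_images, max_cols))
    rows = max(1, math.ceil(num_images / cols))
    return rows, cols


def show_images(images, max_cols=3):
    """Plot each image (e.g. a saliency map) in its own axis-free subplot."""
    # Imported here so that grid_shape stays usable without matplotlib installed.
    from matplotlib import pyplot as plt

    rows, cols = grid_shape(len(images), max_cols)
    fig = plt.figure()
    for index, image in enumerate(images, start=1):
        fig.add_subplot(rows, cols, index)
        plt.axis('off')
        plt.imshow(image)
    plt.show()
```

With the maps computed above, calling show_images(saliency_maps) would produce the same one-row, three-column layout.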


Plot feature map

To plot the feature map of a specific layer for a specific image (e.g. a giraffe):


feature_map = visualizer.get_feature_maps(['block5_conv3'], ['giraffe.png'])
import numpy as np
from matplotlib import pyplot as plt

plt.matshow(np.mean(feature_map[0, 0], axis=-1))
plt.show()


Plot mean activations

To plot the mean activations of multiple layers for multiple images (e.g. a cat image and a dog image, where img_path_cat and img_path_dog are the paths to those images):

mean_activation = visualizer.get_mean_activations(['block5_conv2', 'block5_conv3'], [img_path_cat, img_path_dog])
from matplotlib import pyplot as plt

plt.plot(mean_activation[0, 0], label='Cat', alpha=0.6)
plt.plot(mean_activation[0, 1], label='Dog', alpha=0.6)
plt.xlabel('Kernel Index')
plt.ylabel('Mean Activation')
plt.title('Mean activations of block5_conv2')
plt.legend()
plt.show()


from matplotlib import pyplot as plt

plt.plot(mean_activation[1, 0], label='Cat', alpha=0.6)
plt.plot(mean_activation[1, 1], label='Dog', alpha=0.6)
plt.xlabel('Kernel Index')
plt.ylabel('Mean Activation')
plt.title('Mean activations of block5_conv3')
plt.legend()
plt.show()
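
The indexing used above suggests that mean_activation is arranged as [layer, image, kernel]. One plausible way to obtain one value per kernel, assuming a feature map of shape (height, width, kernels), is to average over the two spatial axes. This is an assumption about what "mean activation" means here, not a confirmed description of the CNNVis implementation; the function name is hypothetical and the array is random stand-in data.

```python
import numpy as np


def mean_activation_per_kernel(feature_map):
    """Average a (height, width, kernels) feature map over its spatial axes."""
    feature_map = np.asarray(feature_map)
    if feature_map.ndim != 3:
        raise ValueError("expected a feature map of shape (height, width, kernels)")
    return feature_map.mean(axis=(0, 1))


# Random stand-in for a VGG16 block5_conv3 feature map (14 x 14 spatial, 512 kernels).
rng = np.random.default_rng(0)
fake_feature_map = rng.random((14, 14, 512))

per_kernel = mean_activation_per_kernel(fake_feature_map)
print(per_kernel.shape)  # (512,): one mean value per kernel
```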


Plot max activation

To plot the max activations of specific kernels in a specific layer:

max_activation = visualizer.get_max_activations('block3_conv1', [12, 123], 2)
from matplotlib import pyplot as plt

plt.imshow(max_activation[0])
plt.axis('off')
plt.show()


from matplotlib import pyplot as plt

plt.imshow(max_activation[1])
plt.axis('off')
plt.show()


Plot kernel

To plot the kernels of a specific layer:

kernels = visualizer.get_kernels('block2_conv1')
from matplotlib import pyplot as plt
import numpy as np

plt.matshow(np.mean(kernels[:, :, :, 1], axis=-1))
plt.show()
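
The slicing in the example above is easier to follow with the array layout spelled out. Keras stores Conv2D kernels as (kernel_height, kernel_width, input_channels, output_kernels); the sketch below assumes get_kernels returns an array in that layout, which matches the indexing used above but is not confirmed here. The array is random stand-in data with the shape of the VGG16 block2_conv1 kernels.

```python
import numpy as np

# Random stand-in with the Keras Conv2D kernel layout for VGG16 block2_conv1:
# (kernel_height, kernel_width, input_channels, output_kernels).
rng = np.random.default_rng(0)
fake_kernels = rng.standard_normal((3, 3, 64, 128))

# Select one output kernel (index 1): one 3x3 slice per input channel.
single_kernel = fake_kernels[:, :, :, 1]

# Average over the input-channel axis to get a single 3x3 image for plotting.
kernel_image = np.mean(single_kernel, axis=-1)

print(single_kernel.shape, kernel_image.shape)
```

Changing the last index selects a different kernel of the same layer.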


cnnvis's People

Contributors

zhihanyang2022
