cezannec / capsule_net_pytorch

License: MIT License

Topics: pytorch, pytorch-tutorial, capsule-network

capsule_net_pytorch's Introduction

Capsule Network

Readable implementation of a Capsule Network as described in "Dynamic Routing Between Capsules" [Hinton et al.]

In this notebook, I'll be building a simple Capsule Network that aims to classify MNIST images. The implementation is in PyTorch, and the notebook assumes that you are already familiar with convolutional and fully-connected layers.

What are Capsules?

A capsule is a small group of neurons that has a few key traits:

  • Each neuron in a capsule represents various properties of a particular image part; properties like a part's color, width, etc.
  • Every capsule outputs a vector, which has some magnitude (representing a part's existence) and orientation (representing a part's generalized pose); a small sketch of this idea follows the list.
  • A capsule network is made of multiple layers of capsules; during training, this network aims to learn the spatial relationships between the parts and the whole of an object (ex. how the position of eyes and a nose relate to the position of a whole face in an image).
  • Capsules represent relationships between the parts of a whole object by using dynamic routing to weight the connections between one layer of capsules and the next, creating strong connections between spatially-related object parts.
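
As a loose illustration of the second trait, in PyTorch (the tensor names and shapes here are made up for this sketch, not taken from the repo):

import torch

# a batch of capsule output vectors: (batch_size, num_capsules, capsule_dim)
caps_out = torch.randn(1, 10, 16)

# the vector's magnitude encodes the probability that the part exists
existence = caps_out.norm(dim=-1)                      # shape: (1, 10)

# the vector's direction (a unit vector) encodes the part's generalized pose
pose = caps_out / caps_out.norm(dim=-1, keepdim=True)  # shape: (1, 10, 16)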

You can read more about all of these traits in my blog post about capsules and dynamic routing.

Representing Relationships Between Parts

All of these traits allow capsules to communicate with each other and determine how data moves through them. Through this dynamic routing, a capsule network learns, during training, the spatial relationships between visual parts and their wholes (ex. between the eyes, nose, and mouth on a face). Compared to a vanilla CNN, this knowledge about spatial relationships makes it easier for a capsule network to identify an object no matter what orientation it is in. These networks are also, generally, better able to identify multiple, overlapping objects, and to learn from smaller sets of training data!


Model Architecture

The Capsule Network that I'll define is made of two main parts:

  1. A convolutional encoder
  2. A fully-connected, linear decoder

The architecture diagram shown in the notebook is taken from the original Capsule Network paper (Hinton et al.). The notebook follows the architecture described in that paper and tries to replicate some of the experiments, such as feature visualization, that the authors pursued.
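
A rough sketch of how these two parts might fit together in PyTorch (the class and layer definitions here are illustrative; the notebook's actual code may differ):

import torch.nn as nn

class CapsuleNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 1. convolutional encoder: a conv layer feeding primary and digit capsules
        self.conv_layer = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9)
        # (the primary- and digit-capsule layers are omitted here; they apply
        #  squashing and dynamic routing, as described in the paper)
        # 2. fully-connected, linear decoder that reconstructs the input image
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512), nn.ReLU(),
            nn.Linear(512, 1024), nn.ReLU(),
            nn.Linear(1024, 28 * 28), nn.Sigmoid(),
        )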


Running Code Locally

If you're interested in running this code on your own computer, there are thorough instructions for setting up Anaconda and installing PyTorch and the other necessary libraries in the readme of Udacity's deep learning repo. After installing the necessary libraries, you can clone and run this code as usual.


capsule_net_pytorch's Issues

There may be an error in computing the softmax for DigitCaps

Thank you for your contribution. I find in the code that c_ij = helpers.softmax(b_ij, dim=2), which computes the softmax over the PrimaryCaps dimension; however, in the original paper, c_ij is the softmax over the DigitCaps dimension. I think that for logits of shape [10, batch_size, 1152, 1, 16], the softmax should be applied along dim 0, i.e. probs = softmax(logits, dim=0), as the original paper presents.
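
To make the dimension question concrete, a small standalone sketch (the tensor shape is the one quoted above; this is not code from the repo):

import torch
import torch.nn.functional as F

# coupling logits b_ij with the shape quoted above: (10, batch, 1152, 1, 16)
b_ij = torch.zeros(10, 2, 1152, 1, 16)

# softmax over dim=0 normalizes across the 10 digit capsules,
# which is what the paper's routing step describes
c_paper = F.softmax(b_ij, dim=0)
print(torch.allclose(c_paper.sum(dim=0), torch.ones(1)))  # True

# softmax over dim=2 instead normalizes across the 1152 primary capsules
c_repo = F.softmax(b_ij, dim=2)
print(torch.allclose(c_repo.sum(dim=2), torch.ones(1)))   # True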

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Hello,

I am very new to capsule networks and PyTorch in general. Thank you for the detailed and easy-to-understand explanations. While I was trying to run the code, I came across an error when training the model.

RuntimeError                              Traceback (most recent call last)
<ipython-input-16-ce644b7b7998> in <module>
      1 # training for 3 epochs
      2 n_epochs = 3
----> 3 losses = train(capsule_net, criterion, optimizer, n_epochs=n_epochs)

<ipython-input-15-54eb5db28cd7> in train(capsule_net, criterion, optimizer, n_epochs, print_every)
     34             optimizer.zero_grad()
     35             # get model outputs
---> 36             caps_output, reconstructions, y = capsule_net(images)
     37             # calculate loss
     38             loss = criterion(caps_output, target, images, reconstructions)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

<ipython-input-11-1bf5514185a1> in forward(self, images)
     16         primary_caps_output = self.primary_capsules(self.conv_layer(images))
     17         caps_output = self.digit_capsules(primary_caps_output).squeeze().transpose(0,1)
---> 18         reconstructions, y = self.decoder(caps_output)
     19         return caps_output, reconstructions, y
     20 

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

<ipython-input-10-ffae35828f12> in forward(self, x)
     44         x = x * y[:, :, None]
     45         # flatten image into a vector shape (batch_size, vector_dim)
---> 46         flattened_x = x.view(x.size(0), -1)
     47         # create reconstructed image vectors
     48         reconstructions = self.linear_layers(flattened_x)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

I have not changed any part of the code yet; I wanted to run it as-is before trying different things. Can you help me understand what caused this error and how to fix it?

Thank you!

EDIT: I just replaced the view function with reshape, as suggested in the error, and it works. Though I am still not sure of the difference between the two functions in this context.
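
The underlying behavior can be reproduced in isolation (this is standard PyTorch behavior, not specific to this repo):

import torch

x = torch.randn(4, 3, 2)
y = x.transpose(0, 1)          # transposing returns a non-contiguous view
print(y.is_contiguous())       # False

# y.view(y.size(0), -1)        # raises the RuntimeError quoted above
z = y.reshape(y.size(0), -1)   # reshape copies the data when a view is impossible
# equivalently: y.contiguous().view(y.size(0), -1)

.view() never copies data, so it requires a shape that is compatible with the tensor's existing memory layout; caps_output is non-contiguous after the transpose in forward, which is why .reshape() (which falls back to a copy when needed) works.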

Squash function

Very nice tutorial, though I want to point out that the squash formula in the notebook differs from the paper. Instead of

[the squash formula as rendered in the notebook]

it should be the formula from the paper:

v_j = (||s_j||^2 / (1 + ||s_j||^2)) * (s_j / ||s_j||)

so the first fraction is a factor slightly below 1, and the second one normalizes the vector coordinates by the magnitude.

As far as I can see, the implementation follows the second formula and it seems to be correct, except that I am not sure about the normalization dimension for primary capsules. According to the explanations from the notebook, each primary capsule outputs a vector of size 32 * 6 * 6. Then these vectors are stacked and, considering the batch dimension, we get a tensor of the shape

(batch_size, num_nodes_in_capsule = 32 * 6 * 6, num_capsules = 8)

Finally, these vectors are normalized, i.e. their magnitudes are squashed to be in the range from 0 to 1. If I understand correctly, you are talking about the magnitude of the (32 * 6 * 6)-dimensional vectors. So if we want to ensure that the length of these vectors is in the range [0; 1], we would have to divide each of the (32 * 6 * 6) coordinates by the square root of the sum of squares of these coordinates. Right? In fact, the implementation divides each coordinate by the magnitude of a vector comprised of the coordinates in the same positions of all capsule vectors. Note that dim is set to -1 when calculating squared_norm, i.e. it sums up the same features, but from different capsules.

Please, consider the following example:

import torch
import numpy as np

def squash(input_tensor):
    '''Squashes an input Tensor so it has a magnitude between 0-1.
        param input_tensor: a stack of capsule inputs, s_j
        return: a stack of normalized, capsule output vectors, v_j
        '''
    squared_norm = (input_tensor ** 2).sum(dim=-1, keepdim=True)    
    scale = squared_norm / (1 + squared_norm) # normalization coeff
    output_tensor = scale * input_tensor / torch.sqrt(squared_norm)    
    return output_tensor

np.random.seed(1)
torch.manual_seed(1)
batch_size = 15
dim=13
n_caps = 7
u = [torch.tensor(np.random.rand(batch_size, dim, 1)) for i in range(n_caps) ] 
#print(u)
u = torch.cat(u, dim=-1)
print("u:", u)

u_squash = squash(u)
print("u_squash:", u_squash)

mag = torch.sqrt( (u_squash **2).sum(dim=-2) )
print("mag: ", mag)

Here I create a randomly filled tensor of shape (batch_size, dim, n_caps), i.e. similar to those produced by the primary capsules. The tensor is squashed by the same function used in the notebook. It can be seen from the output that the magnitudes of the vectors exceed the range [0; 1]:

mag:  tensor([[0.6629, 1.0954, 0.9715, 0.7817, 1.0211, 0.7117, 0.8847],
        [1.0202, 0.9313, 0.8816, 0.8383, 1.0355, 0.9926, 1.0803],
        [0.8864, 1.0694, 0.7617, 0.9194, 0.8355, 0.9432, 1.0051],
        [0.9630, 0.9198, 0.9078, 1.0516, 0.8845, 0.7888, 0.9238],
        [0.6996, 1.0998, 1.1319, 0.6556, 0.8243, 0.9571, 0.9614],
        [0.9705, 0.9879, 0.8915, 0.8308, 1.0063, 1.0607, 0.9306],
        [1.0569, 1.0294, 0.9268, 1.0508, 0.9768, 0.9505, 0.8103],
        [0.9545, 0.9655, 0.9052, 1.0720, 0.7246, 0.9666, 0.9669],
        [1.1237, 0.9768, 0.9749, 0.8128, 0.8935, 0.9216, 0.7607],
        [0.8785, 0.7155, 0.8306, 0.8913, 0.9764, 0.9692, 1.0892],
        [0.9691, 0.8658, 1.0399, 0.9774, 0.9309, 0.8950, 0.8872],
        [0.7124, 1.1386, 0.8535, 1.0913, 0.8478, 0.8779, 0.9850],
        [0.8909, 0.9851, 0.9247, 1.0239, 0.7927, 0.9618, 0.7925],
        [0.8764, 0.9524, 0.9294, 0.8517, 0.8385, 0.9380, 1.0824],
        [1.0076, 0.8668, 1.0051, 0.9030, 1.0067, 0.8850, 0.9519]],
       dtype=torch.float64)

It actually enforces the magnitudes of vectors comprised of particular coordinates from different capsule outputs to be in that range. But was that intended?
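
For comparison, if the intent is to bound the magnitude of each capsule vector, the sum would have to run over the coordinate dimension instead (dim=-2 in the shapes above). A sketch of that variant, reusing u from the example (my modification, not the repo's code):

def squash_over_coords(input_tensor, dim=-2):
    '''Squashes along the coordinate dimension, so that each capsule
    vector's magnitude lands strictly inside [0, 1).'''
    squared_norm = (input_tensor ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * input_tensor / torch.sqrt(squared_norm)

mag = torch.sqrt((squash_over_coords(u) ** 2).sum(dim=-2))
print(mag.max())  # strictly below 1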
