capsnet-pytorch's Introduction

CapsNet-Pytorch

A PyTorch implementation of CapsNet from the paper:
Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
The current average test error is 0.34% and the best test error is 0.30%.

Differences with the paper:

  • We use learning rate decay with decay factor = 0.9 and step = 1 epoch,
    while the paper did not give the detailed parameters (or perhaps did not use decay at all).
  • We only report the test errors after 50 epochs of training.
    In the paper, I suppose they trained for 1250 epochs, judging from Figure A.1.
  • We use MSE (mean squared error) as the reconstruction loss, with coefficient lam_recon=0.0005*784=0.392.
    This should be equivalent to using SSE (sum of squared errors) with lam_recon=0.0005 as in the paper.
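
The MSE/SSE equivalence in the last bullet can be checked numerically. The sketch below is framework-free and uses random stand-in pixel values; the helper names `sse` and `mse` are illustrative and not taken from the repo, and it assumes the MSE is averaged over the 784 pixels of a single image.

```python
import random

PIXELS = 28 * 28  # 784 pixels per MNIST image


def sse(x, x_recon):
    """Sum of squared errors over all pixels of one image."""
    return sum((a - b) ** 2 for a, b in zip(x, x_recon))


def mse(x, x_recon):
    """Mean of squared errors over all pixels of one image."""
    return sse(x, x_recon) / len(x)


random.seed(0)
x = [random.random() for _ in range(PIXELS)]        # stand-in for an image
x_recon = [random.random() for _ in range(PIXELS)]  # stand-in for its reconstruction

lam_sse = 0.0005            # coefficient used with SSE in the paper
lam_mse = lam_sse * PIXELS  # coefficient used with MSE here (0.392)

loss_paper = lam_sse * sse(x, x_recon)
loss_repo = lam_mse * mse(x, x_recon)
```

Up to floating-point rounding, `loss_paper` and `loss_repo` are the same number, because the factor of 784 moved into the coefficient cancels the division by 784 inside the mean.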

TODO

  • Conduct experiments on other datasets.
  • Explore interesting characteristics of CapsuleNet.

Contacts

  • Your contributions to the repo are always welcome. Open an issue or contact me by e-mail at [email protected] or on WeChat at wenlong-guo.

Usage

Step 1. Install PyTorch from source

I'm using the source code as of Nov 24, 2017. The newest code should work too.
Go to https://github.com/pytorch/pytorch for installation instructions.

Step 2. Clone this repository to your local machine.

git clone https://github.com/XifengGuo/CapsNet-Pytorch.git
cd CapsNet-Pytorch

Step 3. Train a CapsNet on MNIST

Training with default settings:

python capsulenet.py

Run the following command for detailed usage:

python capsulenet.py -h

Step 4. Test a pre-trained CapsNet model

Suppose you have trained a model using the above command; the trained model will then be saved to result/trained_model.pkl. Now just run the following command to get test results.

python capsulenet.py --testing --weights result/trained_model.pkl

It will output the testing accuracy and show the reconstructed images. The testing data is the same as the validation data. To test on new data, change the code as needed.

You can also just download a model I trained from https://pan.baidu.com/s/1dFLFtT3

Results

Test Errors

CapsNet classification test error on MNIST. The average and standard deviation are computed over 3 trials. The results can be reproduced by running the following commands.

python capsulenet.py --routings 1 #CapsNet-v1
python capsulenet.py --routings 3 #CapsNet-v2

Method      Routing  Reconstruction  MNIST (%)     Paper
Baseline    --       --              --            0.39
CapsNet-v1  1        yes             0.36 (0.016)  0.29 (0.011)
CapsNet-v2  3        yes             0.34 (0.029)  0.25 (0.005)

Losses and accuracies:

Training Speed

About 73s / epoch on a single GTX 1070 GPU.
About 43s / epoch on a single GTX 1080Ti GPU.

Reconstruction result

The result of CapsNet-v2, obtained by running

python capsulenet.py --testing --weights result/trained_model.pkl

The digits in the top 5 rows are real images from MNIST and the digits in the bottom rows are the corresponding reconstructed images.

capsnet-pytorch's Issues

Problem with keras.backend.batch_dot

I previously cloned and ran this repo, and it worked fine. I wrote some code based on your implementation of the capsule layer. However, when I run the code today, I get this strange error. I traced it, and the problem appears to be in batch_dot. Do you have any idea how to fix it? Even a fresh clone of this repository gives me the same error, both on Google Colab and on my local machine.
Thanks

ValueError: Can not do batch_dot on inputs with shapes (None, 10, 10, 1152, 16) and (None, 10, None, 1152, 16) with axes=[2, 3]. x.shape[2] != y.shape[3] (10 != 1152).

in python 2.7 and 3.5

c = F.softmax(b, dim=1)

TypeError: softmax() got an unexpected keyword argument 'dim'

A bug in computing val_acc

val_acc is always zero because "correct" is a tensor with an integer type. You should convert "correct" to a float type before computing "correct / len(test_loader.dataset)".

For example,
add this at line 127 of capsulenet.py: correct = correct.numpy().astype(float)
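
The same idea can be shown independently of the tensor library: keep the running count of correct predictions as a plain Python integer and do the final division in floating point. This is only a sketch; the function name `accuracy` and the toy labels are illustrative and do not come from capsulenet.py.

```python
def accuracy(predictions, labels):
    """Return the fraction of matching entries as a Python float."""
    correct = 0  # plain int, so it cannot truncate or overflow like a small integer tensor
    for pred, label in zip(predictions, labels):
        correct += int(pred == label)
    # float() forces true division, also on Python 2
    return float(correct) / len(labels)


preds = [3, 1, 4, 1, 5, 9, 2, 6]
labels = [3, 1, 4, 1, 5, 9, 2, 7]
val_acc = accuracy(preds, labels)  # 7 of 8 entries match
```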

RuntimeError: value cannot be converted to type uint8_t without overflow: 10000

I use Python 2.7 and CUDA 9.1 to run the code. I got some warnings and an error.

capsulenet.py:165: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  training_loss += loss.data[0] * x.size(0) # record the batch loss
capsulenet.py:121: UserWarning: volatile was removed and now has no effect. Use with torch.no_grad(): instead.
  x, y = Variable(x.cuda(), volatile=True), Variable(y.cuda())
capsulenet.py:123: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  test_loss += caps_loss(y, y_pred, x, x_recon, args.lam_recon).data[0] * x.size(0) # sum up batch loss
Traceback (most recent call last):
  File "capsulenet.py", line 255, in
    train(model, train_loader, test_loader, args)
  File "capsulenet.py", line 169, in train
    val_loss, val_acc = test(model, test_loader, args)
  File "capsulenet.py", line 129, in test
    return test_loss, correct / len(test_loader.dataset)
RuntimeError: value cannot be converted to type uint8_t without overflow: 10000

Do you have any idea how to solve it?
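
One plausible reading of the error, offered as an assumption rather than a confirmed diagnosis: `correct` ends up as a tensor of type uint8, so dividing it by `len(test_loader.dataset)` = 10000 requires representing 10000 as a uint8, which only covers 0 to 255. The standard-library sketch below shows the same range limit using `struct`; `fits_in_uint8` is a hypothetical helper, not part of the repo or of PyTorch.

```python
import struct


def fits_in_uint8(value):
    """Return True if value can be stored in one unsigned byte (0..255)."""
    try:
        struct.pack("B", value)  # "B" is the unsigned char format
        return True
    except struct.error:
        return False


dataset_size = 10000  # the value named in the RuntimeError
can_store_divisor = fits_in_uint8(dataset_size)
can_store_small = fits_in_uint8(255)
```

Converting the count to a Python number before dividing, for example with the `tensor.item()` call that the warnings recommend, would avoid the uint8 arithmetic altogether.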

For 7 number of classes

def test(model, test_loader, args):
    model.eval()
    test_loss = 0
    correct = 0
    for x, y in test_loader:
        y = torch.zeros(y.size(0), 10).scatter_(1, y.view(-1, 1), 1.)
        x, y = Variable(x.cuda(), volatile=True), Variable(y.cuda())

Please check the statements in the for loop. How should the first line, which builds y, be changed for 7 classes?
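
In the quoted line, the `10` passed to `torch.zeros` is the width of the one-hot label matrix, so that is the number that depends on the class count. A framework-free sketch of the same one-hot encoding with the width as a parameter follows; `one_hot` and `num_classes` are illustrative names, the labels are assumed to be integers from 0 to 6, and the network's output layer would also have to produce 7 classes, which this sketch does not cover.

```python
def one_hot(labels, num_classes):
    """Build one row per label with 1.0 at the label's index and 0.0 elsewhere."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError("label %d is outside [0, %d)" % (label, num_classes))
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


# Three example labels encoded for a 7-class problem.
y = one_hot([0, 6, 3], num_classes=7)
```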

Construction of PrimaryCaps

I believe this line:

outputs = outputs.view(x.size(0), -1, self.dim_caps)

is not constructing the capsules correctly. In theory, if we do:

outputs = self.conv2d(x)
outputs_2 = outputs.view(x.size(0), -1, self.dim_caps)

then outputs[0,0:8,0,0] should be equal to outputs_2[0,0,0:8],

but applying the view in this way does not guarantee that.
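
The concern can be illustrated without running the network. For a contiguous tensor of shape (N, C, H, W), a view to (N, -1, dim_caps) cuts each sample's flat memory into consecutive runs of dim_caps elements. The sketch below maps flat offsets back to (channel, row, column) indices; the sizes C=256 and H=W=6 follow the PrimaryCaps layer described in the paper and are an assumption about this repo's configuration, and both helper names are hypothetical.

```python
C, H, W = 256, 6, 6  # assumed conv output size for one sample
DIM_CAPS = 8


def unravel(flat, channels, height, width):
    """Map a flat offset within one sample to (channel, row, col) in NCHW layout."""
    if not 0 <= flat < channels * height * width:
        raise IndexError("flat offset out of range")
    c, rest = divmod(flat, height * width)
    h, w = divmod(rest, width)
    return (c, h, w)


def capsule_members(capsule_index):
    """List the source positions that a contiguous view puts into one capsule."""
    start = capsule_index * DIM_CAPS
    return [unravel(start + k, C, H, W) for k in range(DIM_CAPS)]


first_capsule = capsule_members(0)
# What the issue expects instead: channels 0..7 at spatial position (0, 0).
expected_by_issue = [(c, 0, 0) for c in range(DIM_CAPS)]
```

Under these assumptions the first capsule holds eight neighbouring spatial positions of channel 0 rather than channels 0 to 7 at position (0, 0). Whether this affects accuracy is a separate question; reordering the axes before the view would be one way to obtain the per-position grouping.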

Something about loading the model

After training finished, the model was saved as a .pkl file, as your code shows. When I load it with load_state_dict(), the keys are not compatible, as shown below:
RuntimeError: Error(s) in loading state_dict for CapsuleNet:
Missing key(s) in state_dict: "conv1.bias", "conv1.weight", "primarycaps.conv2d.bias", "primarycaps.conv2d.weight", "digitcaps.weight", "decoder.0.bias", "decoder.0.weight", "decoder.2.bias", "decoder.2.weight", "decoder.4.bias", "decoder.4.weight".
Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias", "module.primarycaps.conv2d.weight", "module.primarycaps.conv2d.bias", "module.digitcaps.weight", "module.decoder.0.weight", "module.decoder.0.bias", "module.decoder.2.weight", "module.decoder.2.bias", "module.decoder.4.weight", "module.decoder.4.bias".
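
The `module.` prefix on every unexpected key is what `torch.nn.DataParallel` adds when a wrapped model's state is saved. One common workaround is to rename the keys before calling `load_state_dict`. The sketch below operates on a plain dictionary; `strip_module_prefix` is a hypothetical helper, and the string values are placeholders standing in for tensors.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Placeholder checkpoint using a few of the key names from the error message.
saved = OrderedDict([
    ("module.conv1.weight", "w0"),
    ("module.conv1.bias", "b0"),
    ("module.digitcaps.weight", "w1"),
])
cleaned = strip_module_prefix(saved)
```

With a real checkpoint, the cleaned dictionary would then be passed to the model's `load_state_dict`. The alternative is to wrap the model in `torch.nn.DataParallel` before loading, so that the key names already match.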
