prototypical-networks-tensorflow's Introduction

Prototypical Networks for Few-shot Learning

TensorFlow implementation of the NIPS 2017 paper Prototypical Networks for Few-shot Learning [1].

This code has been ported from the official implementation in PyTorch (jakesnell/prototypical-networks) and may be buggy.
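
For readers who want the core idea without opening the notebooks, below is a minimal NumPy sketch of the prototypical classification step described in the paper (this is not the notebook code, and the embed function stands in for the convolutional encoder): class prototypes are the means of the embedded support examples, and queries are classified by a softmax over negative squared Euclidean distances to the prototypes.

import numpy as np

def prototypical_predict(embed, support, query):
    # support: [n_way, n_shot, ...] raw examples; query: [n_query, ...]
    # embed: maps a batch of examples to [batch, dim] features (stand-in for the encoder).
    n_way, n_shot = support.shape[0], support.shape[1]
    z_support = embed(support.reshape((n_way * n_shot,) + support.shape[2:]))
    prototypes = z_support.reshape(n_way, n_shot, -1).mean(axis=1)   # [n_way, dim]
    z_query = embed(query)                                           # [n_query, dim]
    diff = z_query[:, None, :] - prototypes[None, :, :]              # [n_query, n_way, dim]
    dists = np.sum(diff ** 2, axis=-1)                               # [n_query, n_way]
    logits = -dists
    logits -= logits.max(axis=1, keepdims=True)                      # numerical stability
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    return probs                                                     # [n_query, n_way]

For example, probs = prototypical_predict(lambda x: x.reshape(len(x), -1), support, query) classifies on raw pixels; in the actual implementation the encoder is a convolutional network.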

Usage

Omniglot Dataset

  • Download the Omniglot dataset by executing download_omniglot.sh
  • Use the IPython Notebook ProtoNet-Omniglot.ipynb (a rough sketch of the episodic sampling used by the notebooks follows this list)
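
An episode is formed by drawing n_way classes, then n_shot support and n_query query examples per class. The sketch below illustrates this; it assumes the Omniglot data has already been loaded into an array of shape [n_classes, n_examples, height, width], and the names are illustrative rather than taken from the notebook.

import numpy as np

def sample_episode(dataset, n_way, n_shot, n_query):
    # dataset: [n_classes, n_examples, H, W]; returns one N-way K-shot episode.
    n_classes, n_examples = dataset.shape[:2]
    epi_classes = np.random.permutation(n_classes)[:n_way]
    support = np.zeros((n_way, n_shot) + dataset.shape[2:], dtype=np.float32)
    query = np.zeros((n_way, n_query) + dataset.shape[2:], dtype=np.float32)
    for i, cls in enumerate(epi_classes):
        selected = np.random.permutation(n_examples)[:n_shot + n_query]
        support[i] = dataset[cls, selected[:n_shot]]
        query[i] = dataset[cls, selected[n_shot:]]
    # Query labels are the episode-local class indices 0..n_way-1.
    labels = np.tile(np.arange(n_way)[:, np.newaxis], (1, n_query)).astype(np.uint8)
    return support, query, labels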

mini ImageNet v2

Downloading Images

  • Create an account on image-net.org with your institutional ID.
  • Replace <username> and <accesskey> in download_miniimagenet.sh with the username and accesskey you receive upon registration.
  • Run download_miniimagenet.sh, which will download 84 ImageNet classes (64 train + 20 test) from ILSVRC2011.

Testing on miniImageNet

  • Run create_miniimagenet.py, which will generate mini-imagenet-train.npy and mini-imagenet-test.npy: numpy arrays of shapes 64 x 350 x 84 x 84 x 3 and 20 x 350 x 84 x 84 x 3, respectively (see the loading sketch after this list).
  • Use the IPython Notebook ProtoNet-MiniImageNet-v2.ipynb.
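
As a quick sanity check, the generated arrays can be loaded and their shapes verified. This is a minimal sketch and assumes the .npy files sit in the current working directory.

import numpy as np

train = np.load('mini-imagenet-train.npy')  # expected shape: (64, 350, 84, 84, 3)
test = np.load('mini-imagenet-test.npy')    # expected shape: (20, 350, 84, 84, 3)
print(train.shape, test.shape)
assert train.shape == (64, 350, 84, 84, 3)
assert test.shape == (20, 350, 84, 84, 3)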

NOTE: This miniImageNet dataset is not identical to the one used by Ravi et al. [2], who used images from ILSVRC2012, which can be downloaded from here. Ravi et al. used 100 classes (64 training + 16 validation + 20 test) with 600 examples from each class. The script provided in this repository downloads images from image-net.org, which currently (Feb 2018) serves images from ILSVRC2011; as a result, some of the classes suggested by Ravi et al. have fewer than 600 examples. For this reason, the number of examples per class has been reduced to 350. The scripts download images for 84 of the classes suggested by Ravi et al. (64 train + 20 test) and then randomly sample 350 examples for each class.

References

[1] Jake Snell, Kevin Swersky, and Richard S. Zemel. Prototypical Networks for Few-shot Learning. NIPS 2017.
[2] Sachin Ravi and Hugo Larochelle. Optimization as a Model for Few-shot Learning. ICLR 2017.

prototypical-networks-tensorflow's People

Contributors

abdulfatir


prototypical-networks-tensorflow's Issues

I have a problem with different-class queries.

Hi, I'm studying your code, and I notice that during the test process the proto net is very good at recognizing a new class (which it has just learned) among all the classes in the training set. However, when I need it to recognize the just-learned class against other unseen classes, it performs very badly.

# Modified test loop; it reuses variables defined in the notebook
# (sess, ce_loss, acc, x, q, y, test_dataset, n_test_*, im_height, im_width).
print('Testing NEGATIVE...')
avg_acc = 0.
for epi in range(n_test_episodes):
    epi_classes = np.random.permutation(n_test_classes)[:n_test_way+1]

    wrong = epi_classes[-1]  # I changed here, to generate a third new class for the test query; it's totally unseen before

    support = np.zeros([n_test_way, n_test_shot, im_height, im_width], dtype=np.float32)
    query = np.zeros([n_test_way, n_test_query, im_height, im_width], dtype=np.float32)
    for i, epi_cls in enumerate(epi_classes[:-1]):
        selected = np.random.permutation(n_examples)[:n_test_shot + n_test_query]
        support[i] = test_dataset[epi_cls, selected[:n_test_shot]]
        # Each query set mixes (n_test_query - 10) examples of the unseen 'wrong' class
        # with 10 examples of the correct class.
        query[i] = np.concatenate((test_dataset[wrong, selected[n_test_shot:-10]], test_dataset[epi_cls, selected[-10:]]))
    support = np.expand_dims(support, axis=-1)
    query = np.expand_dims(query, axis=-1)
    labels = np.tile(np.arange(n_test_way)[:, np.newaxis], (1, n_test_query)).astype(np.uint8)
    ls, ac = sess.run([ce_loss, acc], feed_dict={x: support, q: query, y: labels})
    avg_acc += ac
    if (epi+1) % 50 == 0:
        print('[test episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(epi+1, n_test_episodes, ls, ac))
avg_acc /= n_test_episodes
print('Average Test Accuracy: {:.5f}'.format(avg_acc))

The performance is:

Testing NEGATIVE...
Average Test Accuracy: 0.54737


Please correct me if I have misunderstood something.
Looking forward to your reply, and thank you for your help in advance.

I have run this network and the accuracy I got is only 0.19. Is this normal?

Hi,
I have run this network and the accuracy I got is only 0.19. Is this result normal?
The log is below.
[epoch 97/100, episode 50/100] => loss: 2.65600, acc: 0.18667
[epoch 97/100, episode 100/100] => loss: 2.57098, acc: 0.21333
[epoch 98/100, episode 50/100] => loss: 2.66116, acc: 0.17000
[epoch 98/100, episode 100/100] => loss: 2.66873, acc: 0.18000
[epoch 99/100, episode 50/100] => loss: 2.58527, acc: 0.24667
[epoch 99/100, episode 100/100] => loss: 2.60647, acc: 0.20333
[epoch 100/100, episode 50/100] => loss: 2.50875, acc: 0.26000
[epoch 100/100, episode 100/100] => loss: 2.62107, acc: 0.18000
(20, 350, 84, 84, 3)
Testing...
[test episode 50/600] => loss: 1.73374, acc: 0.17333
[test episode 100/600] => loss: 1.73245, acc: 0.16000
[test episode 150/600] => loss: 1.65328, acc: 0.21333
[test episode 200/600] => loss: 1.68278, acc: 0.17333
[test episode 250/600] => loss: 1.64363, acc: 0.25333
[test episode 300/600] => loss: 1.77591, acc: 0.12000
[test episode 350/600] => loss: 1.58832, acc: 0.25333
[test episode 400/600] => loss: 1.66476, acc: 0.17333
[test episode 450/600] => loss: 1.71281, acc: 0.16000
[test episode 500/600] => loss: 1.64050, acc: 0.24000
[test episode 550/600] => loss: 1.71050, acc: 0.21333
[test episode 600/600] => loss: 1.72494, acc: 0.20000
Average Test Accuracy: 0.19098

I didn't use download_miniimagenet.sh to download the dataset because of a 404 error. Instead, I downloaded ILSVRC2012_img_train_t3 from {http://www.image-net.org/challenges/LSVRC/2012/nonpub-downloads}, selected 64 classes from it and put them into {prototypical-networks-tensorflow/data/mini-imagenet/data/train}, and selected 20 classes and put them into {prototypical-networks-tensorflow/data/mini-imagenet/data/test}. Could this cause a problem?
PS: I ran the .py file instead of the .ipynb file.
Thank you so much! I am looking forward to your reply.

Reproducing Paper's Results

Dear Abdul,

Your implementation of the mini-imagenet experiment yields an accuracy of 0.61 while it is reported to be 0.68 in the paper. What do you think the issue is? Have you tried to close the gap?

Thanks;
Ali

For other datasets

Assalamualaykum,
I wanted to run the code on Tiny ImageNet 200 from Stanford University. Can you please guide me on how to generate the .npy files?

Hello, I have the same problem. Have you solved it?
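
The repository only ships create_miniimagenet.py, but a rough sketch of how an analogous .npy file could be built for another dataset is given below. It assumes a root directory of per-class image folders, resizes every image to 84 x 84 RGB, and keeps a fixed number of examples per class; the directory layout, the n_examples value, and the function name are illustrative, not part of this repository.

import os
import numpy as np
from PIL import Image

def build_npy(root_dir, out_path, n_examples=350, size=(84, 84)):
    # Pack images from root_dir/<class>/<image files> into an array of shape
    # [n_classes, n_examples, size[0], size[1], 3] and save it to out_path.
    classes = sorted(os.listdir(root_dir))
    data = np.zeros((len(classes), n_examples, size[0], size[1], 3), dtype=np.uint8)
    for ci, cls in enumerate(classes):
        files = sorted(os.listdir(os.path.join(root_dir, cls)))
        files = list(np.random.permutation(files))[:n_examples]
        for ei, fname in enumerate(files):
            img = Image.open(os.path.join(root_dir, cls, fname)).convert('RGB')
            data[ci, ei] = np.array(img.resize(size))
    np.save(out_path, data)

# e.g. build_npy('data/tiny-imagenet/train', 'tiny-imagenet-train.npy')
# The mini-ImageNet notebook loads arrays shaped [n_classes, n_examples, 84, 84, 3], as produced above.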

I ran Proto-MiniImagenet and got the following:
[epoch 93/100, episode 100/100] => loss: 2.99573, acc: 0.05000
[epoch 94/100, episode 50/100] => loss: 2.99573, acc: 0.05000
[epoch 94/100, episode 100/100] => loss: 2.99573, acc: 0.05000
[epoch 95/100, episode 50/100] => loss: 2.99573, acc: 0.05000
[epoch 95/100, episode 100/100] => loss: 2.99573, acc: 0.05000
[epoch 96/100, episode 50/100] => loss: 2.99573, acc: 0.05000
[epoch 96/100, episode 100/100] => loss: 2.99573, acc: 0.05000
[epoch 97/100, episode 50/100] => loss: 2.99573, acc: 0.05000
[epoch 97/100, episode 100/100] => loss: 2.99573, acc: 0.05000
[epoch 98/100, episode 50/100] => loss: 2.99573, acc: 0.05000
[epoch 98/100, episode 100/100] => loss: 2.99573, acc: 0.05000
[epoch 99/100, episode 50/100] => loss: 2.99573, acc: 0.05000
[epoch 99/100, episode 100/100] => loss: 2.99573, acc: 0.05000
[epoch 100/100, episode 50/100] => loss: 2.99573, acc: 0.05000
[epoch 100/100, episode 100/100] => loss: 2.99573, acc: 0.05000

The loss and accuracy are the same for every episode.
The weird thing is that this wasn't happening for Proto-Omniglot.
My TensorFlow is the GPU build, version 1.3.

Originally posted by @themis0888 in #1 (comment)

How to detect an unknown face?

Hello, thank you for sharing your work in detail. Could you please help me understand how to identify an unknown face among the query images, given that there is no threshold on the softmax outputs?

Thanks in advance
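
The notebooks do not implement open-set rejection, but one common approach with prototypical networks is to threshold the distance to the nearest prototype: if a query embedding is farther than some calibrated threshold from every prototype, flag it as unknown. Below is a minimal sketch under that assumption; the embeddings, prototypes, and threshold are placeholders, and the threshold would need to be tuned on held-out data.

import numpy as np

def classify_with_rejection(z_query, prototypes, threshold):
    # z_query: [n_query, dim] query embeddings; prototypes: [n_way, dim].
    # Returns the predicted class per query, or -1 ('unknown') when the
    # squared distance to the nearest prototype exceeds `threshold`.
    diff = z_query[:, None, :] - prototypes[None, :, :]
    dists = np.sum(diff ** 2, axis=-1)          # [n_query, n_way]
    nearest = dists.argmin(axis=1)
    return np.where(dists.min(axis=1) <= threshold, nearest, -1)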

the "layers.batch_norm"

Awesome job! I have reproduced the same result on the MNIST dataset, but some issues still annoy me.

  1. I wonder whether the "is_training" parameter of "layers.batch_norm" should be set to False during inference (see the sketch below).
  2. Is an L2 penalty or any other regularizer needed?

Thx!!!!!!!
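
For reference, a common TF 1.x pattern is to feed is_training as a boolean placeholder (True during training, False at inference, so the moving averages are used) and to add L2 through a weights_regularizer. This is a sketch under those assumptions, not the code from this repository:

import tensorflow as tf  # TF 1.x style, matching the repository

is_training = tf.placeholder(tf.bool, name='is_training')
images = tf.placeholder(tf.float32, [None, 28, 28, 1])

# Conv layer with an L2 penalty on its weights, followed by batch norm.
net = tf.contrib.layers.conv2d(
    images, 64, 3, activation_fn=None,
    weights_regularizer=tf.contrib.layers.l2_regularizer(1e-4))
net = tf.contrib.layers.batch_norm(net, is_training=is_training,
                                   updates_collections=tf.GraphKeys.UPDATE_OPS)
net = tf.nn.relu(net)

loss = tf.reduce_mean(net)  # stand-in for the real episodic loss
loss += tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

# Ensure the batch-norm moving averages are updated at each training step;
# at inference time, feed is_training=False so those averages are used.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)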

I have a problem.

I am currently using your ProtoNet-Omniglot.ipynb code.

I have not changed the code, but the accuracy and loss values do not change.

I use TensorFlow 1.3.
