jakesnell / prototypical-networks

Code for the NeurIPS 2017 Paper "Prototypical Networks for Few-shot Learning"

License: MIT License

Python 98.05% Shell 1.95%
few-shot deep-learning metric-learning nips-2017 omniglot pytorch

prototypical-networks's Introduction

Prototypical Networks for Few-shot Learning

Code for the NIPS 2017 paper Prototypical Networks for Few-shot Learning.

If you use this code, please cite our paper:

@inproceedings{snell2017prototypical,
  title={Prototypical Networks for Few-shot Learning},
  author={Snell, Jake and Swersky, Kevin and Zemel, Richard},
  booktitle={Advances in Neural Information Processing Systems},
  year={2017}
}

Training a prototypical network

Install dependencies

  • This code has been tested on Ubuntu 16.04 with Python 3.6 and PyTorch 0.4.
  • Install PyTorch and torchvision.
  • Install torchnet by running pip install git+https://github.com/pytorch/tnt.git@master.
  • Install the protonets package by running python setup.py install or python setup.py develop.

Set up the Omniglot dataset

  • Run sh download_omniglot.sh.

Train the model

  • Run python scripts/train/few_shot/run_train.py. This will run training and place the results into results.
    • You can specify a different output directory by passing in the option --log.exp_dir EXP_DIR, where EXP_DIR is your desired output directory.
    • If you are running on a GPU you can pass in the option --data.cuda.
  • Re-run in trainval mode with python scripts/train/few_shot/run_trainval.py. This will save your model into results/trainval by default. (A sketch of what one training episode computes follows this list.)
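
For orientation, the core computation inside each training episode looks roughly like the sketch below: embed the support and query points, average the support embeddings into one prototype per class, and minimize the negative log-softmax over negative squared distances. This is a simplified distillation, not the repo's exact code; the tensor shapes, the encoder interface, and the use of torch.cdist are assumptions.

    import torch
    import torch.nn.functional as F

    def episode_loss(encoder, support, query):
        # support: [n_class, n_support, C, H, W]; query: [n_class, n_query, C, H, W]
        n_class, n_support = support.shape[:2]
        n_query = query.shape[1]

        # Embed everything, then average the support embeddings into one
        # prototype per class.
        z_support = encoder(support.flatten(0, 1))                   # [n_class * n_support, D]
        prototypes = z_support.view(n_class, n_support, -1).mean(1)  # [n_class, D]
        z_query = encoder(query.flatten(0, 1))                       # [n_class * n_query, D]

        # Squared Euclidean distances, turned into class log-probabilities.
        dists = torch.cdist(z_query, prototypes) ** 2
        log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)

        # Each query's target is the class it was sampled from.
        target = torch.arange(n_class).view(n_class, 1, 1).expand(n_class, n_query, 1)
        return -log_p_y.gather(2, target).squeeze(-1).mean()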

Evaluate

  • Run evaluation as: python scripts/predict/few_shot/run_eval.py --model.model_path results/trainval/best_model.pt.

prototypical-networks's People

Contributors

jakesnell

prototypical-networks's Issues

How do I make a prediction

Hi!
Thanks for your work. Could you explain how inference should be done on a real test set? I have a trained model and I am able to load it from the checkpoint, but how can I 'extract' the class predictions for my test dataset?
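
(For anyone landing here: a minimal inference sketch. It assumes the checkpoint stores the whole model, that model.encoder maps a batch of images to embedding vectors as in few_shot.py, and that you prepare the support and query tensors yourself; variable names are illustrative, not the repo's API.)

    import torch

    def predict(model_path, support, query):
        # support: [n_class, n_support, C, H, W] labeled examples defining the classes
        # query:   [n_query, C, H, W] unlabeled images to classify
        model = torch.load(model_path)
        model.eval()
        n_class, n_support = support.shape[:2]
        with torch.no_grad():
            z_support = model.encoder(support.flatten(0, 1))
            prototypes = z_support.view(n_class, n_support, -1).mean(dim=1)
            z_query = model.encoder(query)
            dists = torch.cdist(z_query, prototypes)  # distance to each class prototype
        return dists.argmin(dim=1)  # index of the nearest prototype

    # e.g. labels = predict('results/trainval/best_model.pt', support, query)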

What if it's an event

I would like to ask: this paper is about classifying images. If one wants to apply it to events, that is, to text, what should be done?

Error while training

Hi,
I am getting an error like this while training the network.

OSError                                   Traceback (most recent call last)
/content/drive/My Drive/prototypical-network-pytorch-master/train.py in <module>()
     66 ta = Averager()
     67
---> 68 for i, batch in enumerate(train_loader, 1):
     69     data, _ = [_.cuda() for _ in batch]
     70     p = args.shot * args.train_way

(3 intermediate frames omitted)

/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
    427     # have message field
    428     raise self.exc_type(message=msg)
--> 429 raise self.exc_type(msg)
    430
    431

OSError: Caught OSError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/content/drive/My Drive/prototypical-network-pytorch-master/mini_imagenet.py", line 48, in __getitem__
    image = self.transform(Image.open(path).convert('RGB'))
  File "/usr/local/lib/python3.7/dist-packages/PIL/Image.py", line 2843, in open
    fp = builtins.open(filename, "rb")
OSError: [Errno 5] Input/output error: './materials/images/n0450941700000967.jpg'

Classification

I have trained the weights. When evaluating, how do I feed in a single picture and have the model tell me which category it belongs to?

How to train a new dataset?

Thank you very much for sharing your code.

I want to train models on a dataset other than miniImageNet and Omniglot. Could you guide me on how to arrange the dataset (main folder, sub-folders), and how to train models on it?

Thanks a lot.
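
(Not an official answer, but as a starting point, here is a generic episodic-sampling sketch over a hypothetical class-per-folder layout my_dataset/<class_name>/<image>.png. This is not the repo's loader; the layout, names, and shapes are illustrative assumptions.)

    import os
    import random
    import torch
    from PIL import Image
    from torchvision import transforms

    root = 'my_dataset'  # hypothetical layout: my_dataset/<class_name>/<image>.png
    classes = {
        c: [os.path.join(root, c, f) for f in os.listdir(os.path.join(root, c))]
        for c in os.listdir(root)
    }
    to_tensor = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])

    def sample_episode(n_way=5, n_support=5, n_query=5):
        # Draw one few-shot episode: n_way classes with support and query examples each.
        episode_classes = random.sample(sorted(classes), n_way)
        support, query = [], []
        for c in episode_classes:
            paths = random.sample(classes[c], n_support + n_query)
            imgs = [to_tensor(Image.open(p).convert('L')) for p in paths]
            support.append(torch.stack(imgs[:n_support]))
            query.append(torch.stack(imgs[n_support:]))
        # Shapes: [n_way, n_support, C, H, W] and [n_way, n_query, C, H, W]
        return torch.stack(support), torch.stack(query)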

question of a detail

Thanks for your work, sir. I have a question about the following instruction:

Re-run in trainval mode python scripts/train/few_shot/run_trainval.py. This will save your model into results/trainval by default.

What does it mean? Does it restart training from scratch, or does it continue from the parameters of the first run in order to find a better solution?

Also, why are the results (embedding parameters) identical when I run your code twice? Is that due to the random seed? Without a fixed seed, can you guarantee the embedding results would be equal?

Can you release detailed configuration?

Hi Jake,
Prototypical networks is really a nice work.

I have run this code to reproduce the results in the NIPS paper. However, the results differ somewhat from those reported.

NIPS 2017 paper:

    5-way 1-shot   5-way 5-shot   20-way 1-shot   20-way 5-shot
    98.8           99.7           96.0            98.9

Reproduced results:

    5-way 1-shot   5-way 5-shot   20-way 1-shot   20-way 5-shot
    98.4           99.6           94.9            98.6

I ran this code several times and got similar results.
Can you release your hyper-parameter settings? Or is there any technical trick that may impact the performance?

Here are the commands I used in the 20-way 1-shot setting:

python scripts/train/few_shot/run_train.py --data.shot 1 --data.test_shot 1 --data.test_way 20 --data.cuda --log.exp_dir=results/20way1shot 
python scripts/predict/few_shot/run_eval.py --data.test_shot 1 --data.test_way 20 --model.model_path=results/20way1shot/best_model.t7 

Thanks.

accuracy when training with 1 query set

When training with 5 classes, if the query set is composed of 1 example per class, the training accuracy is computed as 0.2 and does not increase.
Even during evaluation, if the query set is composed of one example per class, the accuracy comes out as 0.2 +/- 0.
What is the cause of this problem, and how can it be solved?

About calculation of loss function

Hi, I have some questions about this loss function because there are some differences between the paper and the code.
First of all, the paper:
[screenshot of the loss equation from the paper]
The loss is the sum of two terms: the distance between the query point and its corresponding prototype, and a log-sum-exp over the negative distances to all of the prototypes.

But in the code:

       log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)
       loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()

The loss in the code appears to be only the negative log-softmax between the query point and its corresponding prototype.

I'm confused. Is the problem with my understanding of the code, or of the paper?
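
For what it's worth, the two forms are algebraically identical: for distances d and target class k, -log_softmax(-d)[k] = d[k] + log Σ_j exp(-d[j]), which is exactly the paper's two-term expression. A quick numerical check (illustrative only; the distances here are made up):

    import torch
    import torch.nn.functional as F

    # Distances from one query point to three prototypes; the true class is k = 1.
    dists = torch.tensor([[1.3, 0.2, 2.7]])
    k = 1

    # The paper's two-term loss: distance to the true prototype plus log-sum-exp.
    two_term = dists[0, k] + torch.logsumexp(-dists, dim=1)

    # The code's loss: negative log-softmax over negative distances.
    neg_log_softmax = -F.log_softmax(-dists, dim=1)[0, k]

    print(torch.allclose(two_term, neg_log_softmax))  # True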

omniglot: Same alphabet appears in "train" and "test"

In data/omniglot/splits/vinyals, one alphabet (Gurmukhi) appears in both train.txt and test.txt.

I realise that this may not be a mistake if it corresponds to the original splits used. However, it should probably be fixed moving forward.

Note that none of the characters appear in both train and test sets, just that some characters from that alphabet are in train and some are in test.

(Hopefully it didn't affect any of the results since there are only 4 characters from this alphabet in the test set -- not enough to do 5-way classification.)

Research project about PN

I sent you an email days ago about a research project.

In case the email ended up in your spam, check it out and contact me if you like.

I hope you can help me! ;)

Zero-shot learning meta-data

In the paper, they say: "Each class comes with meta-data giving a high-level description of the class rather than a small number of labeled examples". Where can we find this meta-data?

Some questions

Thanks for your code. While reading it, I ran into a question: in few_shot.py, is z = self.encoder.forward(x) the embedding process described in the paper? Could you explain the embedding process in your code to me? This question may be fundamental, because I am a beginner in machine learning. I hope for your reply, thanks.
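
For reference, the paper describes the embedding network as four blocks of {3x3 convolution with 64 filters, batch normalization, ReLU, 2x2 max-pooling}, and self.encoder.forward(x) applies a network of this kind to map each image to a flat embedding vector. A minimal sketch of such an encoder (illustrative, not the repo's exact module):

    import torch.nn as nn

    def conv_block(in_ch, out_ch):
        # One block: 3x3 conv -> batch norm -> ReLU -> 2x2 max-pool.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    # Four stacked blocks; a 1x28x28 Omniglot image comes out as a
    # 64-dimensional embedding after flattening.
    encoder = nn.Sequential(
        conv_block(1, 64),
        conv_block(64, 64),
        conv_block(64, 64),
        conv_block(64, 64),
        nn.Flatten(),
    )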

Installation of setup under windows

There is a problem running training.
FileNotFoundError: [Errno 2] No such file or directory: 'C:\Users\LENOVO\AppData\Local\conda\conda\envs\pytorch\lib\site-packages\protonets-0.0.1-py3.8.egg\protonets\data\../../data/omniglot\splits\vinyals\train.txt'

Alternative implementation

Hey guys!

I made an alternative implementation. I tried to keep it simple, but comparable.
There are scripts to automatically download Omniglot and mini-ImageNet.
And I followed the same experimental procedure.

The link is: https://github.com/giovcandido/prototypical-networks-project.

All you have to do is run sh exec_vanilla_omniglot.sh or sh exec_vanilla_mini_imagenet.sh.

I made a notebook as well. The link is: https://github.com/giovcandido/prototypical-networks-jupyter.

I hope you enjoy it!

Reproducing Mini-Imagenet Results

Dear Jake,

I have been trying to reproduce your results for mini-ImageNet, but there is a large gap between what I can get and what is reported in the paper. I get 47.21% for 5-way 1-shot and 62.63% for 5-way 5-shot, while they should be 49.42% and 68.20%. I have used both your code and my own implementation based on TensorFlow. The code here (https://github.com/abdulfatir/prototypical-networks-tensorflow/blob/master/ProtoNet-MiniImageNet-v2.ipynb) gets similar results as well.

Is there any trick that I am missing? Can you point to something in the above link that should be changed to improve results? I also tried learning-rate decay; it helped slightly, but a large gap remains.

Thanks in advance for your help;
Ali

An error in paper

[screenshot of the equation in question]

I think there is an error where I have drawn the line. Of course, maybe I have misunderstood.

Reproducing Caltech-UCSD Birds (CUB) Results

Hi Jake,

I have been trying to implement Prototypical Networks (PN) for zero-shot learning and to reproduce your results on Caltech-UCSD Birds (CUB) dataset. However, some details are missing in the paper.

Could you please release your PN source code for zero-shot learning and your splits of the CUB dataset?

Best regards,
Nhan.

dataset questions

Exception: No images found for omniglot class Atemayar_Qelisayer/character12/rot180 at e:\prototypical-networks-master\protonets\data../../data/omniglot\data\Atemayar_Qelisayer\character12. Did you run download_omniglot.sh first?

There is no ‘rot180’ directory in the Omniglot dataset files. What should I do?

ModuleNotFoundError: No module named 'protonets.utils'

Hello!

I am trying to replicate this experiment, but when I run the following command:
python scripts/train/few_shot/run_train.py --data.cuda --log.exp_dir results

I get the following error:

Traceback (most recent call last):
  File "scripts/train/few_shot/run_train.py", line 3, in <module>
    from train import main
  File "/gpfs/data/lcrawfo1/cnwizu/PythonSandbox/prototypical-networks-master/scripts/train/few_shot/train.py", line 16, in <module>
    import protonets.utils.data as data_utils
ModuleNotFoundError: No module named 'protonets.utils'

I followed the readme exactly. What is the solution?

Thanks in advance!

Miniimagenet files

Hello,

Thank you for sharing the code. I am now preparing to reproduce your results using miniImageNet. Could you let me know the details of the dataset? For instance, Ravi et al. released the WNIDs and their corresponding image lists. If you have anything similar, could you share it with me?
