jakesnell / prototypical-networks
Code for the NeurIPS 2017 Paper "Prototypical Networks for Few-shot Learning"
License: MIT License
Hi,
I get the following error while training the network:
OSError Traceback (most recent call last)
/content/drive/My Drive/prototypical-network-pytorch-master/train.py in <module>()
66 ta = Averager()
67
---> 68 for i, batch in enumerate(train_loader, 1):
69 data, _ = [_.cuda() for _ in batch]
70 p = args.shot * args.train_way
/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
427 # have message field
428 raise self.exc_type(message=msg)
--> 429 raise self.exc_type(msg)
430
431
OSError: Caught OSError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/content/drive/My Drive/prototypical-network-pytorch-master/mini_imagenet.py", line 48, in __getitem__
image = self.transform(Image.open(path).convert('RGB'))
File "/usr/local/lib/python3.7/dist-packages/PIL/Image.py", line 2843, in open
fp = builtins.open(filename, "rb")
OSError: [Errno 5] Input/output error: './materials/images/n0450941700000967.jpg'
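Errno 5 (EIO) is reported by the operating system when a low-level read fails, so on Colab it often points at the Google Drive mount rather than at the data-loading code. A minimal sketch, assuming the images sit flat in ./materials/images as the traceback suggests, that lists every image file whose bytes cannot be read so those files can be re-copied; find_unreadable is a hypothetical helper, not part of any repository.

```python
import os


def find_unreadable(image_dir, suffixes=(".jpg", ".jpeg", ".png")):
    """Return image paths in image_dir (non-recursive) that raise OSError when read."""
    bad = []
    for name in sorted(os.listdir(image_dir)):
        if not name.lower().endswith(suffixes):
            continue  # skip non-image files
        path = os.path.join(image_dir, name)
        try:
            # Read the whole file in 1 MiB chunks; an I/O failure surfaces as OSError.
            with open(path, "rb") as f:
                while f.read(1 << 20):
                    pass
        except OSError:
            bad.append(path)
    return bad
```

Calling find_unreadable("./materials/images") before training gives a list of files to re-copy; an empty list suggests the failure was transient and that copying the dataset to local Colab storage may avoid it.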
Exception: No images found for omniglot class Atemayar_Qelisayer/character12/rot180 at e:\prototypical-networks-master\protonets\data../../data/omniglot\data\Atemayar_Qelisayer\character12. Did you run download_omniglot.sh first?
There is no ‘rot180’ folder in the Omniglot dataset files. What should I do?
I realized that I had misunderstood.
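As far as I can tell from the repository's Omniglot loader, the rot000/rot090/rot180/rot270 suffix is not a folder on disk: the class string is split into alphabet, character and rotation, the images are read from the character folder, and each image is rotated by the given angle, so every character yields four classes. The message above therefore means that no images were found in the character folder itself at the printed path. The sketch below mimics that parsing with the standard library only; parse_omniglot_class and rotate_ccw are hypothetical helpers, and a list of lists stands in for an image.

```python
import os


def parse_omniglot_class(class_name, data_root):
    """Split 'Alphabet/characterNN/rotDDD' into an image folder and a rotation in degrees."""
    alphabet, character, rot = class_name.split("/")
    if not rot.startswith("rot"):
        raise ValueError(f"unexpected rotation token: {rot!r}")
    degrees = int(rot[3:])  # 'rot180' -> 180
    image_dir = os.path.join(data_root, "data", alphabet, character)
    return image_dir, degrees


def rotate_ccw(image, degrees):
    """Rotate a 2-D list counter-clockwise by a multiple of 90 degrees."""
    if degrees % 90 != 0:
        raise ValueError("only multiples of 90 degrees are supported")
    for _ in range((degrees // 90) % 4):
        # Transpose, then reverse the row order: one 90-degree counter-clockwise turn.
        image = [list(row) for row in zip(*image)][::-1]
    return [list(row) for row in image]
```

PIL's Image.rotate also turns counter-clockwise, so the direction matches; for rot180 the direction does not matter anyway.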
In the paper, they say: "Each class comes with meta-data giving a high-level description of the class rather than a small number of labeled examples". Where can we find this meta-data?
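In the paper's zero-shot setting, the prototype of class k is computed from its meta-data vector v_k by a separate embedding, c_k = g_ϑ(v_k), and a query is classified by its distance to these prototypes just as in the few-shot case; for CUB, the meta-data are the per-class attribute vectors that ship with the dataset. A minimal sketch, assuming a linear g given as a plain weight matrix W and an already computed query embedding; the function names and toy numbers are hypothetical. The paper reports that fixing the prototypes to unit length helped empirically, so normalization is included here as an option.

```python
import math


def linear_embed(W, v):
    """g(v) = W v, with W given as a list of rows (assumed linear g)."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in W]


def unit(x):
    """Scale x to unit Euclidean length (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(t * t for t in x))
    return [t / norm for t in x] if norm > 0 else list(x)


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def zero_shot_predict(W, class_meta, query_emb, normalize=True):
    """Return the class whose meta-data prototype is nearest to the query embedding."""
    protos = {}
    for label, v in class_meta.items():
        c = linear_embed(W, v)
        protos[label] = unit(c) if normalize else c
    return min(protos, key=lambda label: sq_dist(protos[label], query_emb))
```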
Hi!
Thanks for your work. Could you explain how to run inference on a real test set? I have a trained model and can load it from the checkpoint, but how can I 'extract' the predicted classes for my test dataset?
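Inference follows the same steps as an evaluation episode: embed a few labelled support images per class with the trained encoder, average them into one prototype per class, and assign each test image to the nearest prototype under squared Euclidean distance. The sketch below covers only that last part on already computed embeddings; running the loaded encoder on each image is assumed to have happened beforehand, and the function names are hypothetical.

```python
def mean_vector(vectors):
    """Element-wise mean of equally sized vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def build_prototypes(support):
    """support: {label: [embedding, ...]} -> {label: prototype (mean embedding)}."""
    return {label: mean_vector(embs) for label, embs in support.items()}


def predict(prototypes, query_emb):
    """Label of the prototype with the smallest squared Euclidean distance."""
    return min(prototypes, key=lambda label: sq_dist(prototypes[label], query_emb))
```

The support set only needs a handful of labelled images per class you want to recognize; those classes need not have been seen during training.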
There is a problem running training:
FileNotFoundError: [Errno 2] No such file or directory: 'C:\Users\LENOVO\AppData\Local\conda\conda\envs\pytorch\lib\site-packages\protonets-0.0.1-py3.8.egg\protonets\data\../../data/omniglot\splits\vinyals\train.txt'
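The path in this message is built relative to the installed protonets package, so when the package is installed as an egg into site-packages (as here), '../../data/omniglot' resolves inside site-packages instead of inside the cloned repository. The check below reproduces that resolution with posixpath; the directory names are illustrative. Installing the package in development mode (pip install -e . from the repository root) keeps the package files inside the repository, so the same relative path lands on the repository's data folder.

```python
import posixpath


def resolve_data_dir(module_dir, relative="../../data/omniglot"):
    """Join a package directory with a relative data path and normalize it."""
    return posixpath.normpath(posixpath.join(module_dir, relative))


# Installed as an egg: the result stays inside site-packages.
installed = resolve_data_dir("/env/site-packages/protonets-0.0.1-py3.8.egg/protonets/data")
# Development install: the result lands in the cloned repository.
develop = resolve_data_dir("/home/user/prototypical-networks/protonets/data")
print(installed)
print(develop)
```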
I sent you an email a few days ago about a research project.
In case the email ended up in your spam folder, please check it and contact me if you are interested.
I hope you can help me! ;)
Hi Jake,
I have been trying to implement Prototypical Networks (PN) for zero-shot learning and to reproduce your results on Caltech-UCSD Birds (CUB) dataset. However, some details are missing in the paper.
Could you please release your PN source code for zero-shot learning and your splits of the CUB dataset?
Best regards,
Nhan.
Thanks for your work, sir. I have a question about the following instruction:
Re-run in trainval mode python scripts/train/few_shot/run_trainval.py. This will save your model into results/trainval by default.
What does this mean? Does it restart training from scratch, or does it start from the first run and use its parameters to train the model further toward better parameters?
Also, why are the results (embedding parameters) identical when I run your code twice? Is it due to the random seed? If the seed is not fixed, can you guarantee that the embeddings are identical?
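Identical results across runs are what you would expect if the training script fixes the random seeds before sampling episodes and initializing weights: with a fixed seed, the pseudo-random sequence (which classes and examples go into each episode) repeats exactly, and with different seeds it does not. The sketch below shows that effect with Python's random module and a hypothetical episode sampler; in PyTorch the analogous calls are torch.manual_seed and torch.cuda.manual_seed. Even with fixed seeds, some GPU (cuDNN) kernels are nondeterministic, so bit-identical embeddings are not guaranteed in general.

```python
import random


def sample_episode(rng, classes, n_way, n_shot, n_examples=20):
    """Pick n_way classes, then n_shot example indices per class (hypothetical sampler).

    n_examples=20 assumes 20 images per class, as in Omniglot.
    """
    chosen = rng.sample(classes, n_way)
    return {c: rng.sample(range(n_examples), n_shot) for c in chosen}


def run(seed, n_episodes=3):
    rng = random.Random(seed)  # private generator, seeded once at start-up
    classes = [f"class{i:03d}" for i in range(100)]
    return [sample_episode(rng, classes, n_way=5, n_shot=1) for _ in range(n_episodes)]
```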
In data/omniglot/splits/vinyals, one alphabet (Gurmukhi) appears in both train.txt and test.txt.
I realise that this may not be a mistake if it corresponds to the original splits used. However, it should probably be fixed moving forward.
Note that none of the characters appear in both the train and test sets; it is just that some characters from that alphabet are in train and some are in test.
(Hopefully it didn't affect any of the results, since there are only 4 characters from this alphabet in the test set, not enough to do 5-way classification.)
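A quick way to check for this kind of overlap, assuming each line of train.txt and test.txt has the form Alphabet/characterNN, optionally followed by /rotDDD as in the class names the loader prints, is to compare the sets of alphabets and the sets of characters separately. split_overlap is a hypothetical helper.

```python
def split_overlap(train_lines, test_lines):
    """Return (shared alphabets, shared 'Alphabet/characterNN' pairs) between two splits."""

    def parse(lines):
        alphabets, characters = set(), set()
        for line in lines:
            line = line.strip()
            if not line:
                continue  # ignore blank lines
            parts = line.split("/")
            alphabets.add(parts[0])
            characters.add("/".join(parts[:2]))
        return alphabets, characters

    train_alpha, train_chars = parse(train_lines)
    test_alpha, test_chars = parse(test_lines)
    return train_alpha & test_alpha, train_chars & test_chars
```

Reading both files with open(...).readlines() and passing the lists in should report {'Gurmukhi'} as the shared alphabet and an empty set of shared characters, matching the observation above.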
Dear Jake,
I have been trying to reproduce your results on mini-ImageNet, but there is a large gap between what I get and what is reported in the paper. I get 47.21% for 5-way 1-shot and 62.63% for 5-way 5-shot, while the paper reports 49.42% and 68.2%. I have used both your code and my own TensorFlow implementation. The code here (https://github.com/abdulfatir/prototypical-networks-tensorflow/blob/master/ProtoNet-MiniImageNet-v2.ipynb) also gets similar results.
Is there any trick that I am missing? Can you point to anything in the above link that should be changed to improve the results? I also tried learning-rate decay; it helped slightly, but there is still a large gap.
Thanks in advance for your help,
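For reference, the paper reports training with Adam at an initial learning rate of 10^-3, halved every 2000 episodes, and training the mini-ImageNet models with more classes per episode than at test time (30-way episodes for 1-shot and 20-way episodes for 5-shot). The step schedule amounts to the function below; in PyTorch, torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.5), stepped once per episode, behaves the same way. This is a sketch of the schedule only, not a guarantee that it closes the gap.

```python
def step_lr(episode, base_lr=1e-3, step=2000, gamma=0.5):
    """Learning rate after `episode` episodes: base_lr * gamma ** (episode // step)."""
    return base_lr * gamma ** (episode // step)


for ep in (0, 1999, 2000, 4000, 10000):
    print(ep, step_lr(ep))
```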
Ali
I would like to ask: this paper is about classifying images. If the task is instead about processing events, that is, text, what should be done?
I've tried two versions of the prototypical-networks source code. Both give me the same result: the training loss is higher and the training accuracy is lower.
Why is it necessary to rotate the images of each class by different angles in the experiment? Are the rotated images treated as four separate categories?
Hello!
I am trying to replicate this experiment, but when I run the following command:
python scripts/train/few_shot/run_train.py --data.cuda --log.exp_dir results
I get the following error:
Traceback (most recent call last):
File "scripts/train/few_shot/run_train.py", line 3, in <module>
from train import main
File "/gpfs/data/lcrawfo1/cnwizu/PythonSandbox/prototypical-networks-master/scripts/train/few_shot/train.py", line 16, in <module>
import protonets.utils.data as data_utils
ModuleNotFoundError: No module named 'protonets.utils'
I followed the readme exactly. What is the solution?
Thanks in advance!
Hi, I have some questions about this loss function because there are some differences between the paper and the code.
First of all, the paper:
The loss is the sum of the average log-softmax term between the query point and the other prototypes and the average distance between the query point and its corresponding prototype.
But in the code:
log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)
loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
The loss in the code is only the negative LogSoftmax between each query point and its corresponding prototype, averaged over all query points.
I'm confused. Is it my understanding of the code or my understanding of the paper?
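Both forms describe the same quantity. Writing d(f_φ(x), c_k) for the distance from the embedded query x to the prototype of its true class k, the negative log-softmax over negated distances expands into the paper's two terms; note that the log-sum-exp term runs over all prototypes, including the true one. Averaging over the N_C · N_Q query points, which is what .mean() does in the code, gives the paper's loss J:

```latex
\begin{aligned}
-\log p_\phi(y = k \mid \mathbf{x})
  &= -\log \frac{\exp\big(-d(f_\phi(\mathbf{x}), \mathbf{c}_k)\big)}
                {\sum_{k'} \exp\big(-d(f_\phi(\mathbf{x}), \mathbf{c}_{k'})\big)} \\
  &= d(f_\phi(\mathbf{x}), \mathbf{c}_k)
     + \log \sum_{k'} \exp\big(-d(f_\phi(\mathbf{x}), \mathbf{c}_{k'})\big) \\[4pt]
J(\phi)
  &= \frac{1}{N_C N_Q} \sum_{k=1}^{N_C} \sum_{\mathbf{x} \in Q_k}
     \Big[ d(f_\phi(\mathbf{x}), \mathbf{c}_k)
     + \log \sum_{k'} \exp\big(-d(f_\phi(\mathbf{x}), \mathbf{c}_{k'})\big) \Big]
\end{aligned}
```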
Hey guys!
I made an alternative implementation. I tried to keep it simple, but comparable.
There are scripts to automatically download Omniglot and mini-ImageNet.
And I followed the same experimental procedure.
The link is: https://github.com/giovcandido/prototypical-networks-project.
All you have to do is to run: sh exec_vanilla_omniglot.sh. Or: sh exec_vanilla_mini_imagenet.sh.
I made a notebook as well. The link is: https://github.com/giovcandido/prototypical-networks-jupyter.
I hope you enjoy it!
Hello, where can I find main.py?
Thank you very much for sharing your code.
I want to train models on a dataset other than miniImageNet and Omniglot. Could you explain how to arrange the dataset (main folder, sub-folders) and how to train models on a new dataset?
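One common layout, assumed here rather than required by this repository, is one sub-folder per class under a root folder (root/class_name/image.jpg), with separate class lists or roots for train, validation and test. The sketch below samples an N-way K-shot episode with Q queries per class from such a folder tree; sample_episode and the layout are hypothetical, and the repository's own data loaders would still need to be adapted to read it.

```python
import os
import random


def sample_episode(root, n_way, n_shot, n_query, rng=random):
    """Return (support, query) dicts mapping class name -> list of image paths."""
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    support, query = {}, {}
    for cls in rng.sample(classes, n_way):
        folder = os.path.join(root, cls)
        files = sorted(os.listdir(folder))
        # Draw support and query images together so they never overlap within a class.
        picked = rng.sample(files, n_shot + n_query)
        support[cls] = [os.path.join(folder, f) for f in picked[:n_shot]]
        query[cls] = [os.path.join(folder, f) for f in picked[n_shot:]]
    return support, query
```

Passing a seeded random.Random instance as rng makes the episodes reproducible.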
Thanks a lot.
For example, the setup is 1-shot 5-way or something similar, and the backbone is ResNet-12 or something similar.
Thank you very much.
Hello, could you help me identify query images from a new class that the model was not trained on, and mark them as unknown?
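Prototypical Networks as described always assign a query to one of the N classes in the episode, so an "unknown" outcome has to be added on top. One simple heuristic, not part of the paper, is to reject a query whose squared distance to its nearest prototype exceeds a threshold tuned on held-out data; classify_or_reject and the threshold value are hypothetical.

```python
def classify_or_reject(prototypes, query_emb, threshold):
    """Return the nearest label, or 'unknown' if even the nearest prototype is too far."""

    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    best = min(prototypes, key=lambda label: sq_dist(prototypes[label], query_emb))
    if sq_dist(prototypes[best], query_emb) > threshold:
        return "unknown"
    return best
```

The threshold trades off false rejections of known classes against accepting unknown ones, so it is best chosen by looking at distance distributions for known and held-out classes.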
Thanks in advance
Thanks for your code. While reading it, I came up with a question. In few_shot.py, is z = self.encoder.forward(x) the embedding process described in the paper? Could you explain the embedding process in your code? This question may be basic because I am a beginner in machine learning. I hope to hear from you, thanks.
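As far as I can tell, yes: z = self.encoder.forward(x) computes the embeddings f_φ(x) of all support and query images, which are then used to form prototypes and distances. In the paper, the encoder is four blocks, each a 64-filter 3×3 convolution, batch normalization, a ReLU and 2×2 max-pooling, and the flattened output is the embedding. Assuming padding 1 (so each convolution keeps the spatial size) and floor division in the pooling, the embedding size works out as in the hypothetical helper below: 64 values for a 28×28 Omniglot image and 1600 for an 84×84 mini-ImageNet image.

```python
def conv4_embedding_dim(height, width, n_blocks=4, n_filters=64):
    """Flattened output size of n_blocks of [3x3 conv (padding 1), BN, ReLU, 2x2 max-pool]."""
    for _ in range(n_blocks):
        # A 3x3 convolution with padding 1 keeps H and W; 2x2 max-pooling halves them (floor).
        height, width = height // 2, width // 2
    return n_filters * height * width
```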
Hi guys,
I have an updated implementation that achieves over 79% accuracy on the MiniImageNet 5-shot task. It may be helpful to some people in 2020.
https://github.com/Franklin-Yao/StrongProtoNet
Hello,
Thank you for sharing the code. I am now preparing to reproduce your results on miniImageNet. Could you share the details of the dataset? For instance, Ravi et al. released the WNIDs and their corresponding image lists. If you have anything similar, could you share it with me?
Hi,
What parameters did you use to produce the t-SNE plot i.e. Figure 3 in the Appendix of https://arxiv.org/pdf/1703.05175.pdf
I have trained the weights. During evaluation, how do I feed in a single picture and find out which category it belongs to?
Hi Jake,
Prototypical Networks is really nice work.
I have run this code to reproduce the results in the NIPS paper. However, my results seem to differ somewhat from those in the paper.
NIPS 2017 paper (accuracy, %):
| 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot |
|---|---|---|---|
| 98.8 | 99.7 | 96.0 | 98.9 |
Reproduced results (accuracy, %):
| 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot |
|---|---|---|---|
| 98.4 | 99.6 | 94.9 | 98.6 |
I ran this code several times and got similar results.
Could you release your hyper-parameter settings? Or is there any technical trick that may affect the performance?
Here are the commands I used in the 20-way 1-shot setting:
python scripts/train/few_shot/run_train.py --data.shot 1 --data.test_shot 1 --data.test_way 20 --data.cuda --log.exp_dir=results/20way1shot
python scripts/predict/few_shot/run_eval.py --data.test_shot 1 --data.test_way 20 --model.model_path=results/20way1shot/best_model.t7
Thanks.
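When comparing numbers like these, it helps to know the spread across test episodes; for mini-ImageNet, for example, the paper reports accuracies with 95% confidence intervals. A minimal sketch of such an interval under the usual normal approximation (mean ± 1.96·s/√n over per-episode accuracies); mean_and_ci95 is a hypothetical helper, and it is only meaningful when fed the per-episode accuracies from an evaluation run.

```python
import math
import statistics


def mean_and_ci95(episode_accuracies):
    """Mean accuracy and half-width of a normal-approximation 95% confidence interval."""
    n = len(episode_accuracies)
    mean = statistics.mean(episode_accuracies)
    if n < 2:
        return mean, float("nan")  # the sample standard deviation needs at least 2 episodes
    half_width = 1.96 * statistics.stdev(episode_accuracies) / math.sqrt(n)
    return mean, half_width
```

If two settings' intervals overlap heavily, the difference between them may be within episode-sampling noise.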