oscarknagg / few-shot
Repository for few-shot learning machine learning projects
License: MIT License
Do I need to write the forward function for the matching network myself? Why is it just `pass` in the code? Or is it intended that we call model.f() and model.g() directly?
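For reference, the matching-network prediction rule (softmax attention over cosine similarities, then a weighted sum of the support labels) can be written directly on top of two sets of embeddings. This is a minimal sketch, not the repository's forward function: it assumes you already have support embeddings (for example from `model.g`) and query embeddings (for example from `model.f`), as the question suggests, and it omits Full Context Embeddings. `matching_predict` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def matching_predict(support_emb: torch.Tensor,
                     support_labels: torch.Tensor,
                     query_emb: torch.Tensor,
                     n_classes: int) -> torch.Tensor:
    """Hypothetical helper: matching-network class probabilities from precomputed embeddings.

    support_emb: (n_support, d), support_labels: (n_support,) int64, query_emb: (n_query, d).
    Returns an (n_query, n_classes) tensor whose rows sum to 1.
    """
    # Cosine similarity = dot product of L2-normalised embeddings.
    sims = F.normalize(query_emb, dim=1) @ F.normalize(support_emb, dim=1).t()
    # Attention over the support set for each query.
    attention = sims.softmax(dim=1)
    # Weighted sum of one-hot support labels gives a distribution over classes.
    one_hot = F.one_hot(support_labels, n_classes).to(attention.dtype)
    return attention @ one_hot


support = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
labels = torch.tensor([0, 1])
query = torch.tensor([[0.9, 0.1]])
print(matching_predict(support, labels, query, n_classes=2).argmax(dim=1))  # tensor([0])
```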
I used to believe that in k-way n-shot few-shot learning, k and n (the number of classes and the number of samples per class, respectively) must be the same in the train and test phases. But you use different numbers in the train and test phases (60-way for training and 5-way for testing):
```python
parser.add_argument('--dataset')
parser.add_argument('--distance', default='l2')
parser.add_argument('--n-train', default=1, type=int)
parser.add_argument('--n-test', default=1, type=int)
parser.add_argument('--k-train', default=60, type=int)
parser.add_argument('--k-test', default=5, type=int)
parser.add_argument('--q-train', default=5, type=int)
parser.add_argument('--q-test', default=1, type=int)
```
Are we allowed to do so?
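The flags above are independent of each other, so nothing in the argument parsing forces training and evaluation episodes to share k, n or q. The sketch below shows how arguments with these names parse. The run-name format is an assumption that mirrors the `omniglot_nt=1_kt=60_qt=5_nv=1_kv=5_qv=1` strings printed in logs elsewhere in this thread, and `run_name` is a hypothetical helper.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset')
    parser.add_argument('--distance', default='l2')
    # Training episodes: n shots, k ways, q queries per class.
    parser.add_argument('--n-train', default=1, type=int)
    parser.add_argument('--k-train', default=60, type=int)
    parser.add_argument('--q-train', default=5, type=int)
    # Evaluation episodes are configured separately.
    parser.add_argument('--n-test', default=1, type=int)
    parser.add_argument('--k-test', default=5, type=int)
    parser.add_argument('--q-test', default=1, type=int)
    return parser


def run_name(args: argparse.Namespace) -> str:
    # Hypothetical helper reproducing the run-name format seen in the logs.
    return (f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_'
            f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}')


args = build_parser().parse_args(['--dataset', 'omniglot'])
print(run_name(args))  # omniglot_nt=1_kt=60_qt=5_nv=1_kv=5_qv=1
```

Whatever way is used for training, the reported k-way n-shot accuracy is determined by the evaluation episodes (`--k-test`, `--n-test`).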
Thank you very much for sharing your code.
I want to train models on a dataset other than miniImageNet and Omniglot. Could you guide me on how to arrange the dataset (main folder, sub-folders) and how to train models on a new dataset?
Thanks a lot.
As the title says, I cannot get results from any model, and I cannot understand the purpose of this implementation.
Hi, I want to reproduce the test-set results after training MAML. Is there a module for testing?
On the 5-way 1-shot Omniglot task, there are two metrics in the CSV results file: categorical_accuracy and val_1-shot_5-way_acc.
I want to confirm which metric is used to evaluate the performance of the model on the dataset: categorical_accuracy or val_1-shot_5-way_acc?
Thanks for sharing this few-shot learning repository; it has helped me a lot.
I am curious and confused about why dummy data is used to calculate the loss and update the meta model at line 110 of few_shot/maml.py. Here is the code:
```python
# Dummy pass in order to create `loss` variable
# Replace dummy gradients with mean task gradients using hooks
logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
loss.backward()
optimiser.step()
```
As I understand it, the test data should be used to update the model; using dummy data to train the model seems unreasonable. I would appreciate it if someone could answer this question.
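The pattern in the quoted code can be reproduced in isolation. In this minimal sketch (a toy `nn.Linear`, made-up gradient values, and a `replace_grad` written here for illustration rather than copied from the repository), a hook on each parameter throws away the gradient of the dummy loss and returns a precomputed "mean task gradient" instead, so `optimiser.step()` ends up applying the precomputed values. The dummy input only has to produce a graph that touches every parameter; its loss value is never used.

```python
import torch
import torch.nn as nn


def replace_grad(new_grads, name):
    """Return a tensor hook that discards the incoming gradient and substitutes new_grads[name]."""
    def hook(grad):
        # Whatever a tensor hook returns is used in place of the computed gradient.
        return new_grads[name]
    return hook


torch.manual_seed(0)
model = nn.Linear(2, 1)
optimiser = torch.optim.SGD(model.parameters(), lr=1.0)

# Stand-in for the mean of per-task query-set gradients (made-up values).
mean_task_grads = {name: torch.full_like(p, 0.5) for name, p in model.named_parameters()}
before = {name: p.detach().clone() for name, p in model.named_parameters()}

hooks = [p.register_hook(replace_grad(mean_task_grads, name))
         for name, p in model.named_parameters()]

optimiser.zero_grad()
# Dummy forward/backward: its only job is to make autograd visit every parameter.
dummy_loss = model(torch.zeros(4, 2)).sum()
dummy_loss.backward()   # the hooks swap in mean_task_grads
optimiser.step()        # SGD applies p -= lr * 0.5

for h in hooks:
    h.remove()
```

If a parameter does not take part in the dummy forward pass, its hook never fires and its gradient stays empty, so the trick relies on the dummy pass touching the whole model.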
When I train the model, it raises `KeyError: 'class_name'`:
```
Traceback (most recent call last):
  File "/home/cc/anaconda3/envs/HowToTrainYourMAMLPytorch-master/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 2889, in get_loc
    return self._engine.get_loc(casted_key)
  File "pandas/_libs/index.pyx", line 70, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/index.pyx", line 97, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 1675, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1683, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'class_name'
```
Please tell me how to resolve this. Thank you very much.
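One plausible cause of `KeyError: 'class_name'` (a later traceback in this thread shows `Indexing background...` followed by `0it`, which points the same way) is that the indexing step found no image files, for example because `DATA_PATH` points to the wrong folder. A DataFrame built from an empty list of records has no columns at all, so any column lookup fails with exactly this `KeyError`. The sketch below indexes a hypothetical `root/<class_name>/<image>` layout and fails with a clearer message when nothing is found; `index_images` and the layout are assumptions, not the repository's indexer.

```python
import os

import pandas as pd

IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def index_images(root: str) -> pd.DataFrame:
    """Index a hypothetical root/<class_name>/<image file> layout into a DataFrame."""
    records = []
    # os.walk yields nothing (and raises nothing) if root does not exist.
    for dirpath, _, filenames in os.walk(root):
        for filename in filenames:
            if filename.lower().endswith(IMAGE_EXTENSIONS):
                records.append({
                    'class_name': os.path.basename(dirpath),
                    'filepath': os.path.join(dirpath, filename),
                })
    df = pd.DataFrame(records)
    if df.empty:
        # pd.DataFrame([]) has no columns, so df['class_name'] would raise KeyError: 'class_name'.
        raise FileNotFoundError(f'No images found under {root!r}; check the dataset path.')
    return df
```

The same one-folder-per-class layout is also a simple way to arrange a new dataset, as asked earlier in this thread, though the repository's own dataset classes may expect something different.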
Hi, thank you for a well-packed blog post and implementation on meta-learning!
I have a question about the implementation of the matching network.
I see that you don't resize the Omniglot images, so each example has a size of 105x105.
This results in a runtime error at the following line, because the LSTM's input size is 64 but the size of `embedding` (after `self.encoder`) is 2304. The size becomes 64 if the original image is resized to 28x28.
Line 69 in eab3c78
This implementation also resizes Omniglot to 28x28.
Don't you get this runtime error with the current version of the code? Should I resize the images in this part?
Line 39 in eab3c78
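The 2304 vs 64 mismatch is consistent with a standard four-block convolutional encoder (3x3 convolution with padding 1, batch norm, ReLU, 2x2 max-pool, 64 channels); that architecture is an assumption here, not checked against `self.encoder`. Each block halves the spatial size, rounding down: 105 → 52 → 26 → 13 → 6 gives 64 x 6 x 6 = 2304 features, while 28 → 14 → 7 → 3 → 1 gives 64. The sketch checks this numerically with a hypothetical encoder.

```python
import torch
import torch.nn as nn


def conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
    # 3x3 conv with padding 1 keeps the spatial size; 2x2 max-pool halves it (rounding down).
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    )


# Hypothetical four-block encoder for 1-channel images; not copied from the repository.
encoder = nn.Sequential(
    conv_block(1, 64), conv_block(64, 64), conv_block(64, 64), conv_block(64, 64),
    nn.Flatten(),
)
encoder.eval()


def embedding_dim(side: int) -> int:
    """Flattened feature size for a side x side greyscale input."""
    with torch.no_grad():
        return encoder(torch.zeros(1, 1, side, side)).shape[1]


print(embedding_dim(105), embedding_dim(28))  # 2304 64
```

Under this assumption, either the images are resized to 28x28 before they reach the encoder, or the LSTM's input size has to match the encoder's actual output.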
Dear Oscar Knagg,
First, thank you for your helpful code. However, when I downloaded your scripts and ran them on my computer, an error was raised at:
```
few_shot/core.py", line 80, in __iter__
    query = df[(df['class_id'] == k) & (~df['id'].isin(support_k[k]['id']))].sample(self.q)
ValueError: a must be greater than 0
```
Also, in the MAML code, I see this docstring:
order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the query set)
or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated weights on the query with respect to the original weights).
but the corresponding code is:
```python
if order == 1:
    if train:
        sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                              for k in task_gradients[0].keys()}
        hooks = []
        for name, param in model.named_parameters():
            hooks.append(
                param.register_hook(replace_grad(sum_task_gradients, name))
            )

        model.train()
        optimiser.zero_grad()
        # Dummy pass in order to create `loss` variable
        # Replace dummy gradients with mean task gradients using hooks
        logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
        loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
        loss.backward()
        optimiser.step()

        for h in hooks:
            h.remove()

    return torch.stack(task_losses).mean(), torch.cat(task_predictions)

elif order == 2:
    model.train()
    optimiser.zero_grad()
    meta_batch_loss = torch.stack(task_losses).mean()

    if train:
        meta_batch_loss.backward()
        optimiser.step()
```
Isn't the `order == 2` part the one doing the "update meta-learner weights with gradients of the updated weights on the query set" job?
And I can't understand your code in the `order == 1` part. Could you explain it in more detail?
Any response would be appreciated! Thanks for your time!
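The difference between first- and second-order MAML can be seen on a one-parameter toy problem. In this sketch (toy losses chosen for illustration, not the repository's code), the inner step uses `torch.autograd.grad`. With `create_graph=True` the meta-gradient flows back through the inner update to the original weight. With `create_graph=False` the inner gradient is treated as a constant, which gives the first-order approximation: the gradient of the query loss evaluated at the updated weight.

```python
import torch


def meta_gradient(w0: float, inner_lr: float, second_order: bool) -> float:
    """Meta-gradient of a toy MAML objective with train loss w**2 and query loss fast_w**2."""
    w = torch.tensor(w0, requires_grad=True)
    train_loss = w ** 2
    # Inner step: with create_graph=False the returned gradient carries no graph (a constant).
    (grad,) = torch.autograd.grad(train_loss, w, create_graph=second_order)
    fast_w = w - inner_lr * grad          # "fast weight" after one inner update
    query_loss = fast_w ** 2
    # Backpropagate the query loss all the way to the original parameter w.
    (meta_grad,) = torch.autograd.grad(query_loss, w)
    return meta_grad.item()


# Analytically fast_w = (1 - 2 * lr) * w, so for w = 1 and lr = 0.1:
#   second order: d/dw [((1 - 2 lr) w)^2] = 2 (1 - 2 lr)^2 w = 1.28
#   first order:  d/dfast_w [fast_w^2]   = 2 (1 - 2 lr) w   = 1.6
print(meta_gradient(1.0, 0.1, second_order=True))
print(meta_gradient(1.0, 0.1, second_order=False))
```

In both cases backpropagating the query loss reaches the original parameter `w`, because `fast_w` is computed from `w`; only the second-order version also differentiates through the inner gradient.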
I trained the model. Now I want to run inference with the trained model. Can you please help with this?
@oscarknagg do the needful
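One way to run inference with a trained prototypical network, sketched under some assumptions: `encoder` stands for your trained embedding model (after its weights are loaded), the episode's support labels are 0..k-1, and squared Euclidean distance is used (matching the `--distance l2` default shown earlier in this thread). `proto_predict` is a hypothetical helper, not the repository's evaluation code.

```python
import torch


@torch.no_grad()
def proto_predict(encoder: torch.nn.Module,
                  support: torch.Tensor,
                  support_labels: torch.Tensor,
                  query: torch.Tensor,
                  k_way: int) -> torch.Tensor:
    """Return (n_query, k_way) class probabilities for one few-shot episode."""
    encoder.eval()                      # disable dropout / use BatchNorm running stats
    support_emb = encoder(support)      # (n_support, d)
    query_emb = encoder(query)          # (n_query, d)
    # Prototype of each class = mean embedding of its support examples.
    prototypes = torch.stack([support_emb[support_labels == c].mean(dim=0)
                              for c in range(k_way)])              # (k_way, d)
    # Squared Euclidean distance from every query to every prototype.
    distances = ((query_emb.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).sum(dim=-1)
    # A closer prototype gives a higher probability.
    return (-distances).softmax(dim=1)


# Toy usage with an identity "encoder" on 2-D points.
probs = proto_predict(torch.nn.Identity(),
                      support=torch.tensor([[0.0, 0.0], [10.0, 10.0]]),
                      support_labels=torch.tensor([0, 1]),
                      query=torch.tensor([[1.0, 1.0], [9.0, 9.0]]),
                      k_way=2)
print(probs.argmax(dim=1))  # tensor([0, 1])
```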
Why does the ProtoNet parameter setting for 5-way 1-shot use k-train = 20 with 1 shot?
Hi,
Can you please share the hyperparameters used to reproduce the results on miniImageNet? I tried using the parameters from experiments.txt but couldn't get these numbers.
Thanks.
Hi Oscar,
First, thank you for sharing your code with all of us. I noticed you encountered the same trouble I had while trying to implement the f Full Context Embedding as described in the original paper.
Line 246 in 672de83
I built an LSTM model from scratch that could help you fix this problem.
You can find the model here: https://github.com/rtorrisi/LSTM-Custom-InOut-Hidden-Size
Let me know your thoughts.
Kind regards,
Riccardo
When installing the requirements, I got this error. My current Python version is 3.8.8.
There is a bug related to class names. Using your code in scripts/prepare_omniglot.py, I got class names like 'Alphabet_of_the_Magi.0.character03'. But I think 'Alphabet_of_the_Magi.0.character03' and 'Alphabet_of_the_Magi.90.character03' should be the same class, because they are augmented copies with the same label (class).
So I suggest changing the following line
https://github.com/oscarknagg/few-shot/blob/master/few_shot/datasets.py#L80
to
`alphabet = root.split('/')[-2].split('.')[0]`
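To see what the proposed expression does, here it is applied to two made-up directory paths for the same character under different augmentations. The layout `.../<alphabet>.<suffix>/<character>` is inferred from the class names above and is an assumption; `alphabet_from_root` is a hypothetical wrapper.

```python
def alphabet_from_root(root: str) -> str:
    """Proposed fix: drop the '.<suffix>' part of the alphabet directory name."""
    # Assumes POSIX '/' separators and alphabet names that contain no '.' themselves.
    return root.split('/')[-2].split('.')[0]


for root in ['data/Omniglot/images_background/Alphabet_of_the_Magi.0/character03',
             'data/Omniglot/images_background/Alphabet_of_the_Magi.90/character03']:
    print(alphabet_from_root(root))  # Alphabet_of_the_Magi (both times)
```

Because `split('/')` is used, the expression assumes forward-slash paths; on Windows, paths built with `os.path.join` would use backslashes instead.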
In the code, the model is passed to the evaluation phase using the set_model() function:
```python
def set_model(self, model):
    for callback in self.callbacks:
        callback.set_model(model)
```
Does it work correctly if we only copy the weights (state_dict) of the model and optimizer? (For example, if we save the model weights and then fine-tune on new tasks.) Is any information lost when we save only the model's state_dict?
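A common PyTorch pattern for saving and resuming is to store both state dicts in one checkpoint. A model's `state_dict()` holds its parameters and buffers (e.g. BatchNorm running statistics), and an optimiser's `state_dict()` holds its hyperparameters and per-parameter state (e.g. Adam's moment estimates); anything else, such as the epoch counter or training history, has to be saved separately. The helpers and dictionary keys below are hypothetical, and the sketch uses an in-memory buffer instead of a file.

```python
import io

import torch
import torch.nn as nn


def save_checkpoint(f, model: nn.Module, optimiser: torch.optim.Optimizer, epoch: int) -> None:
    torch.save({
        'epoch': epoch,                                   # not part of either state_dict
        'model_state_dict': model.state_dict(),           # parameters + buffers
        'optimiser_state_dict': optimiser.state_dict(),   # e.g. Adam moments, lr
    }, f)


def load_checkpoint(f, model: nn.Module, optimiser: torch.optim.Optimizer) -> int:
    checkpoint = torch.load(f, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimiser.load_state_dict(checkpoint['optimiser_state_dict'])
    return checkpoint['epoch']


# Train one step, save, then restore into fresh objects.
model = nn.Linear(3, 2)
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 3)).sum().backward()
optimiser.step()

buffer = io.BytesIO()
save_checkpoint(buffer, model, optimiser, epoch=7)
buffer.seek(0)

restored_model = nn.Linear(3, 2)
restored_optimiser = torch.optim.Adam(restored_model.parameters(), lr=1e-3)
restored_epoch = load_checkpoint(buffer, restored_model, restored_optimiser)
```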
How do I use pytest to run the tests in the tests directory and display the results?
```
!python /content/few-shot/scripts/prepare_omniglot.py
Traceback (most recent call last):
  File "/content/few-shot/scripts/prepare_omniglot.py", line 20, in <module>
    from few_shot.utils import mkdir, rmdir
ModuleNotFoundError: No module named 'few_shot.utils'
```
I am getting this error. Can anyone help me with this?
Line 98 in 672de83
When I run `python scripts/prepare_omniglot.py`, the above error appears, so I cannot import DATA_PATH from config. Does anyone know how to deal with this error? Thank you very much.
Hi, thanks for the great repo!
In MAML 2nd order training, what values do you use for the number of inner loop iterations for mini-Imagenet dataset? The paper uses 1 and 5 iterations for Omniglot and mini-Imagenet respectively in the meta-training stage. Here, in the default arguments (https://github.com/oscarknagg/few-shot/blob/master/experiments/maml.py), inner-train-steps is set to 1. I was curious to know if 1 iteration during training would still result in good performance for mini-Imagenet.
Hi,
Thank you for sharing the repo. In the callback classes derived from the base class, I noticed that `self.model` is used in `ModelCheckpoint`.
But the ONLY place where `self.model` is updated is `set_model` in the base class. (Note that fitting functions like `proto_net_episode` use another variable, which is defined here.)
This bug causes the model saved here not to be the trained model. (There's an error here too, but it isn't called in ProtoNet.)
My suggested solution is not to make the model a member variable of the base class (like `self.model`).
Anyway, the `Callback`-derived way of integrating operations is still commendable. Thank you.
Tests (optional)
After adding the datasets, run `pytest` in the root directory to run all tests.
In line 84 of few_shot/maml.py:
`loss.backward(retain_graph=True)`
I don't understand the purpose of this statement. I suspect it is meant to calculate the gradient of task_val_loss with respect to the initial parameters, but the result is not used in the subsequent code. Could you explain what this line does? Thanks a lot!
First of all, thanks a lot for posting this very interesting work.
```
"mtrand.pyx", line 1120, in mtrand.RandomState.choice
ValueError: a must be greater than 0
```
I ran into the error above when running maml.py. I believe it is a numpy error. How do you get around it?
Thanks a lot,
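In the `core.py` traceback quoted earlier in this thread, the failing call is `.sample(self.q)` on a DataFrame filtered to one class with its support examples removed. One plausible reading (an assumption, not a confirmed diagnosis) is that some class has fewer than n + q examples, or that the dataset index is empty, so there is nothing left to sample from. A check like the hypothetical `check_episode_feasible` below, using the `class_id` column that appears in that line, can confirm or rule this out before training.

```python
import pandas as pd


def check_episode_feasible(df: pd.DataFrame, n: int, k: int, q: int,
                           class_col: str = 'class_id') -> None:
    """Raise a descriptive error if n-shot, k-way episodes with q queries per class cannot be sampled."""
    counts = df[class_col].value_counts()
    if len(counts) < k:
        raise ValueError(f'Need at least {k} classes, found {len(counts)}.')
    # Any class may be drawn for an episode, so every class needs n support + q query examples.
    too_small = counts[counts < n + q]
    if len(too_small) > 0:
        raise ValueError(f'{len(too_small)} classes have fewer than n + q = {n + q} examples, '
                         f'e.g. class {too_small.index[0]} has {too_small.iloc[0]}.')


ok = pd.DataFrame({'id': range(4), 'class_id': [0, 0, 1, 1]})
check_episode_feasible(ok, n=1, k=2, q=1)      # passes silently

bad = pd.DataFrame({'id': range(5), 'class_id': [0, 0, 1, 1, 2]})
try:
    check_episode_feasible(bad, n=1, k=2, q=1)
except ValueError as err:
    print(err)
```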
I cloned your repository and then ran `pip install -r requirements.txt`. I downloaded miniImageNet and put it in the directory `./data/miniImageNet/images`, and changed `DATA_PATH` in `config.py` to `"./data/"`.
But when I run the command `python3 scripts/prepare_mini_imagenet.py` from the root directory, I get this error:
```
Traceback (most recent call last):
  File "scripts/prepare_mini_imagenet.py", line 16, in <module>
    from config import DATA_PATH
ModuleNotFoundError: No module named 'config'
```
@oscarknagg could you tell me where the error is?
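When a script is run as `python scripts/prepare_mini_imagenet.py`, Python puts the script's own directory (`scripts/`) at the front of `sys.path`, not the directory you launched it from, so a top-level module such as `config` is not found. Running from the repository root with `PYTHONPATH=.` set, or as a module (`python -m scripts.prepare_mini_imagenet`, in the same style as the `python -m experiments.proto_nets` command used elsewhere in this thread), avoids this. Alternatively, a script can add the root itself; the helper below is a hypothetical sketch that assumes scripts live exactly one directory below the repository root.

```python
import sys
from pathlib import Path


def add_repo_root_to_path(script_file: str, levels_up: int = 1) -> Path:
    """Insert an ancestor of the script's folder into sys.path.

    parents[0] is the script's own folder, so levels_up=1 maps /repo/scripts/prepare.py -> /repo.
    """
    root = Path(script_file).resolve().parents[levels_up]
    if str(root) not in sys.path:
        sys.path.insert(0, str(root))
    return root


# At the top of scripts/prepare_mini_imagenet.py, before `from config import DATA_PATH`:
# add_repo_root_to_path(__file__)
```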
It seems to me that `loss.backward(retain_graph=True)` computes the gradients, but the code computes them again a few lines below with `gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)`, so I feel that some work is being duplicated here.
Again, I am not sure if I am correct; any response would be appreciated!
Thanks for your time
```
  File "matching_nets.py", line 76, in <module>
    background = dataset_class('background')
  File "C:\Users\sjm6kor\Desktop\one_shot\few-shot-master\experiments\datasets.py", line 113, in __init__
    self.unique_characters = sorted(self.df['class_name'].unique())
  File "C:\Users\sjm6kor\AppData\Roaming\Python\Python36\site-packages\pandas\core\frame.py", line 2995, in __getitem__
    indexer = self.columns.get_loc(key)
  File "C:\Users\sjm6kor\AppData\Roaming\Python\Python36\site-packages\pandas\core\indexes\base.py", line 2899, in get_loc
    return self._engine.get_loc(self._maybe_cast_indexer(key))
  File "pandas/_libs/index.pyx", line 107, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/index.pyx", line 131, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 1607, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1614, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'class_name'
```
Thanks for the wonderful repository and the explanation.
Could you please help me understand the queries below?
theta_common.model
and save the model.theta_common.model
to get the fine_tuned.model
and save the model.fine_tuned.model
should be able to predict a total of 11 classes.
Thank you,
KK
Hi,
Thank you very much for sharing this repository; it makes it quick to try things out.
But so far I'm struggling to avoid the out-of-memory (OOM) error below.
Is there any way to reduce memory use?
```
$ python proto_nets.py --dataset omniglot
omniglot_nt=1_kt=60_qt=5_nv=1_kv=5_qv=1
Indexing background...
100%|██████████| 19280/19280 [00:00<00:00, 281958.22it/s]
Indexing evaluation...
100%|██████████| 13180/13180 [00:00<00:00, 302261.60it/s]
Training Prototypical network on omniglot...
Begin training...
Epoch 1:   2%|██▎   | 2/100 [00:06<06:29,  3.98s/it, loss=57.9, categorical_accuracy=0.35]
Traceback (most recent call last):
  File "proto_nets.py", line 129, in <module>
    'distance': args.distance},
  File "/home/me/lab/few-shot/ew-shot/few_shot/train.py", line 113, in fit
    loss, y_pred = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs)
  File "/home/me/lab/few-shot/few_shot/proto.py", line 67, in proto_net_episode
    loss.backward()
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 316.50 MiB (GPU 0; 10.91 GiB total capacity; 9.70 GiB already allocated; 95.38 MiB free; 245.36 MiB cached)
```
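Memory per training step grows with the number of images in an episode. Assuming each episode holds n support plus q query images for each of k classes (consistent with the `nt=1_kt=60_qt=5` run name in the log), the default training episode is 60 x (1 + 5) = 360 images, against 5 x (1 + 1) = 10 for evaluation. Lowering `--k-train` or `--q-train` shrinks the training batch proportionally. The helper below is a hypothetical back-of-the-envelope calculation, not a memory profiler.

```python
def episode_size(n: int, k: int, q: int) -> int:
    """Images per episode: n support + q query examples for each of k classes."""
    return k * (n + q)


print(episode_size(n=1, k=60, q=5))  # 360  (default training episode)
print(episode_size(n=1, k=5, q=1))   # 10   (default evaluation episode)
print(episode_size(n=1, k=20, q=5))  # 120  (e.g. with --k-train 20)
```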
In line 126 of few_shot/maml.py:
`meta_batch_loss.backward()`
However, the graph for meta_batch_loss does not use the model parameters. Each loss used to compute meta_batch_loss was obtained using fast_weights, so backpropagating through meta_batch_loss would produce gradients with respect to fast_weights, not the model parameters.
It may be corrected by separately calling `autograd.grad` over `model.named_parameters().values()` for each loss, summing the results, and then using hooks to update the model parameters.
What do you think?
I know pytest is a package, but how do I use it?
Hi Oscar,
I ran into this issue while running your code; it points to an error in Python source code files.
I suspect that you are on a different (<3.7) Python version. Can you share your version?
```
(venv) kgarg8@edsger:~/kgarg8-workspace/few-shot$ python -m experiments.proto_nets --dataset omniglot --k-test 5 --n-test 1
omniglot_nt=1_kt=60_qt=5_nv=1_kv=5_qv=1
Indexing background...
0it [00:00, ?it/s]
Traceback (most recent call last):
  File "/home/kgarg8/kgarg8-workspace/few-shot/venv/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 3078, in get_loc
    return self._engine.get_loc(key)
  File "pandas/_libs/index.pyx", line 140, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/index.pyx", line 162, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 1492, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1500, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'class_name'
```
Thank you for sharing the nice repo.
But ProtoNet's accuracy on miniImageNet was about 2% lower than the originally published results in every setup.
My question is: are there any techniques, tricks or tuning methods for improving it, while keeping the method (ProtoNet) as in the paper?
(Note that you said in your blog: "as the authors provided the full set of hyperparameters, hence ....". That's true.)
Thank you very much.
Hi, thank you for your great implementation!
I find your Keras-like fit function elegant, but since the evaluation code is implemented as a callback that runs after each epoch, it is unclear how to load a trained model and check its performance.
Can you share your evaluation code?
Hello, I followed your steps and the final result was very bad. I ran the commands in experiments/experiments.txt and ran experiments/proto_nets.py.
The link to the mini-imagenet google drive requires access. How can I get this? If I can't, and download the data from, say, Kaggle, how should I arrange the data so that the scripts will work?
Hi, I am new to Python.
Can anyone help me make this code sample work?
I want to try out "matching_nets.py" for my own use case.
My notebook is a Lenovo Carbon X1 without a GPU.
I am using the Anaconda Command Prompt.
I got the following error:
```
  File "matching_nets.py", line 18, in <module>
    assert torch.cuda.is_available()
AssertionError
```
Here is the package list:
Package Version
atomicwrites 1.2.1
attrs 18.2.0
backcall 0.1.0
bleach 3.0.2
certifi 2020.4.5.1
cffi 1.14.0
cloudpickle 1.3.0
colorama 0.4.3
cycler 0.10.0
cytoolz 0.10.1
dask 2.14.0
decorator 4.4.2
defusedxml 0.5.0
entrypoints 0.2.3
graphviz 0.10.1
imageio 2.8.0
ipykernel 5.1.0
ipython 7.1.1
ipython-genutils 0.2.0
ipywidgets 7.4.2
jedi 0.13.1
Jinja2 2.10
jsonschema 2.6.0
jupyter 1.0.0
jupyter-client 5.2.3
jupyter-console 6.0.0
jupyter-core 4.4.0
kiwisolver 1.2.0
MarkupSafe 1.1.1
matplotlib 3.2.1
mistune 0.8.4
mkl-fft 1.1.0
mkl-service 2.3.0
more-itertools 4.3.0
nbconvert 5.4.0
nbformat 4.4.0
networkx 2.4
notebook 5.7.0
numpy 1.18.2
olefile 0.46
pandas 0.23.4
pandocfilters 1.4.2
parso 0.3.1
pexpect 4.6.0
pickleshare 0.7.5
Pillow 7.1.1
pip 20.0.2
pluggy 0.8.0
prometheus-client 0.4.2
prompt-toolkit 2.0.7
ptyprocess 0.6.0
py 1.7.0
pycparser 2.20
Pygments 2.2.0
pyparsing 2.4.6
pytest 3.9.3
python-dateutil 2.8.1
pytz 2018.7
PyWavelets 1.1.1
pywinpty 0.5.7
PyYAML 5.3.1
pyzmq 17.1.2
qtconsole 4.4.2
scikit-image 0.16.2
scipy 1.4.1
Send2Trash 1.5.0
setuptools 46.1.3.post20200325
six 1.14.0
terminado 0.8.1
testpath 0.4.2
toolz 0.10.0
torch 1.4.0
torch-nightly 1.2.0.dev20190723
torchvision 0.5.0
tornado 6.0.4
tqdm 4.28.1
traitlets 4.3.2
wcwidth 0.1.7
webencodings 0.5.1
wheel 0.34.2
widgetsnbextension 3.4.2
wincertstore 0.2
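The `AssertionError` comes from the script's `assert torch.cuda.is_available()` line, so on a machine without a CUDA GPU the script stops before training. A hypothetical workaround is to replace that assertion with a device choice that falls back to the CPU, as in this sketch (`pick_device` is illustrative, not part of the repository); training will then run, but much more slowly. When loading a checkpoint that was saved on a GPU, `torch.load(..., map_location='cpu')` maps its tensors to the CPU.

```python
import torch


def pick_device(require_cuda: bool = False) -> torch.device:
    """Use the GPU when available; otherwise fall back to the CPU (or fail if require_cuda)."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if require_cuda:
        raise RuntimeError('CUDA was required but is not available on this machine.')
    return torch.device('cpu')


device = pick_device()
# The quoted MAML code moves tensors with .to(device, dtype=torch.double); the same call works on CPU.
model = torch.nn.Linear(4, 2).to(device, dtype=torch.double)
x = torch.zeros(3, 4).to(device, dtype=torch.double)
print(device, model(x).shape)
```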