
oscarknagg / few-shot

1.2K stars · 13 watchers · 242 forks · 16.54 MB

Repository for few-shot learning machine learning projects

License: MIT License

Python 100.00%
machine-learning pytorch few-shot-learning research meta-learning maml omniglot miniimagenet

few-shot's People

Contributors

oscarknagg

few-shot's Issues

Question about the MatchingNetwork class code

Do I need to write the forward function for the matching network myself? Why is it just a pass in the code? Or is it intended that we call model.f() and model.g() directly?
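For reference, the Matching Networks paper classifies a query by an attention-weighted sum of the support labels, where the attention is a softmax over cosine similarities between the query embedding f(x̂) and the support embeddings g(x_i). A minimal NumPy sketch of that final step, assuming both embeddings have already been computed (the full-context-embedding LSTMs are left out, and `matching_net_predict` is a hypothetical name, not a function in this repository):

```python
import numpy as np

def matching_net_predict(support_emb, support_labels, query_emb, k_way):
    """Hypothetical helper: classify queries by attention over the support set.

    support_emb:    (n_support, d) embeddings, e.g. produced by g()
    support_labels: (n_support,) integer labels in [0, k_way)
    query_emb:      (n_query, d) embeddings, e.g. produced by f()
    Returns an (n_query, k_way) array of class probabilities.
    """
    # Cosine similarity between every query and every support embedding
    s = support_emb / np.linalg.norm(support_emb, axis=1, keepdims=True)
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    sims = q @ s.T                                   # (n_query, n_support)
    # Softmax over the support set gives the attention weights a(x_hat, x_i)
    sims = sims - sims.max(axis=1, keepdims=True)    # numerical stability
    attn = np.exp(sims)
    attn /= attn.sum(axis=1, keepdims=True)
    # Attention-weighted sum of the one-hot support labels
    one_hot = np.eye(k_way)[support_labels]          # (n_support, k_way)
    return attn @ one_hot
```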

Why is the number of classes not the same in training and testing?

I used to believe that in k-way n-shot few-shot learning, k and n (the number of classes and the number of samples per class, respectively) must be the same in the training and test phases. But you use different numbers in the two phases (60 classes for training and 5 for testing):

parser.add_argument('--dataset')
parser.add_argument('--distance', default='l2')
parser.add_argument('--n-train', default=1, type=int)
parser.add_argument('--n-test', default=1, type=int)
parser.add_argument('--k-train', default=60, type=int)
parser.add_argument('--k-test', default=5, type=int)
parser.add_argument('--q-train', default=5, type=int)
parser.add_argument('--q-test', default=1, type=int)

Are we allowed to do so?
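The k, n and q arguments only control how each episode is sampled, so nothing forces them to match between training and evaluation; the Prototypical Networks paper, for instance, trains on Omniglot with 60-way episodes and 5 queries per class while evaluating 5-way and 20-way. A hedged pure-Python sketch of an episode sampler (the function `sample_episode` and its data format are hypothetical, not the repository's sampler):

```python
import random

def sample_episode(class_to_items, k, n, q, rng=random):
    """Hypothetical sampler: draw one k-way, n-shot episode with q queries per class.

    class_to_items maps a class name to a list of samples.
    Returns (support, query) as lists of (sample, episode_label) pairs, where
    episode labels are re-indexed to 0..k-1 for this episode only.
    """
    classes = rng.sample(sorted(class_to_items), k)
    support, query = [], []
    for label, cls in enumerate(classes):
        # Draw n + q distinct samples so support and query never overlap
        items = rng.sample(class_to_items[cls], n + q)
        support += [(x, label) for x in items[:n]]
        query += [(x, label) for x in items[n:]]
    return support, query
```

The same sampler serves both phases: a training episode with k=60, n=1, q=5 yields 60 support and 300 query samples, while an evaluation episode with k=5, n=1, q=1 yields 5 of each.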

How to train with a new dataset

Thank you very much for sharing your code.

I want to train models on a dataset other than miniImageNet and Omniglot. Could you explain how to arrange the dataset (main folder, sub-folders) and how to train models on a new dataset?

Thanks a lot.
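The tracebacks elsewhere on this page show that the dataset classes build a pandas DataFrame with a 'class_name' column, so a new dataset mainly needs an index of image paths and class labels. A minimal sketch, assuming a layout of one sub-folder per class (root/<class_name>/<image>); the helper `index_image_folder` and its exact fields are assumptions, and the repository's own dataset classes may expect a different structure:

```python
import os

IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

def index_image_folder(root):
    """Hypothetical indexer for a layout like root/<class_name>/<image file>.

    Returns a list of dicts with per-image fields a few-shot sampler needs:
    class_name, filepath, a running id and an integer class_id.
    """
    rows = []
    for dirpath, _, filenames in os.walk(root):
        for fname in sorted(filenames):
            if fname.lower().endswith(IMAGE_EXTENSIONS):
                rows.append({
                    'class_name': os.path.basename(dirpath),  # folder name = class
                    'filepath': os.path.join(dirpath, fname),
                })
    # Map class names to contiguous integer ids
    class_names = sorted({r['class_name'] for r in rows})
    name_to_id = {name: i for i, name in enumerate(class_names)}
    for i, r in enumerate(rows):
        r['id'] = i
        r['class_id'] = name_to_id[r['class_name']]
    return rows
```

The resulting list could be turned into a DataFrame with pd.DataFrame(rows) and plugged into a dataset class modelled on the existing ones.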

Metrics about ProtoNet's accuracy

On the 5-way 1-shot task of Omniglot there are two metrics in the CSV results file: categorical_accuracy and val_1-shot_5-way_acc.

I want to confirm which metric is used to evaluate the performance of the model on the dataset: categorical_accuracy or val_1-shot_5-way_acc?
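Reading the metric names (an interpretation, not confirmed by the source): categorical_accuracy appears to be logged on the training episodes, whose shape is set by the k-train, n-train and q-train arguments, while val_1-shot_5-way_acc appears to come from 1-shot 5-way evaluation episodes, which is the setting that matches the benchmark task. Both are plain classification accuracies; a minimal sketch of that computation (illustrative, not the repository's implementation):

```python
def categorical_accuracy(y_true, y_pred):
    """Fraction of samples whose highest-scoring class equals the true label.

    y_true: list of integer labels; y_pred: list of per-class score lists.
    """
    correct = sum(
        1 for label, scores in zip(y_true, y_pred)
        # argmax over the class scores for this sample
        if max(range(len(scores)), key=scores.__getitem__) == label
    )
    return correct / len(y_true)
```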

Why use dummy data to calculate the loss and update the meta model in MAML?

Thanks for sharing this few-shot learning repository, which has helped me a lot.

I am confused about why dummy data is used to calculate the loss and update the meta model at line 110 of few_shot/maml.py. Here is the code:

# Dummy pass in order to create `loss` variable
# Replace dummy gradients with mean task gradients using hooks
logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
loss.backward()
optimiser.step()

As I understand it, the test (query) data should be used to update the model, so using dummy data to train the model seems unreasonable. I would appreciate it if someone could answer this question.
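As the code comments say, the dummy pass exists so that there is a `loss` to call `backward()` on: the registered hooks replace each parameter's gradient with the mean task gradient, so `optimiser.step()` ends up applying the averaged query-set gradients and the gradients computed from the zeros input are discarded. The net effect of one first-order meta-update can be sketched in NumPy on toy quadratic tasks; the task losses, `alpha` and `meta_lr` are illustrative assumptions, and the support and query losses are taken to be identical for simplicity:

```python
import numpy as np

def task_grad(w, c):
    # Gradient of the toy task loss L_c(w) = 0.5 * ||w - c||^2
    return w - c

def fomaml_step(w, task_centres, alpha=0.4, meta_lr=0.5):
    """One first-order MAML meta-update on toy quadratic tasks."""
    query_grads = []
    for c in task_centres:
        fast_w = w - alpha * task_grad(w, c)       # inner step on the support set
        query_grads.append(task_grad(fast_w, c))   # query gradient at the fast weights
    mean_grad = np.mean(query_grads, axis=0)       # what the hooks inject as .grad
    return w - meta_lr * mean_grad                 # what optimiser.step() applies (plain SGD here)
```

Repeating the step drives the meta-parameters to the point that is best on average after adaptation, here the mean of the task centres.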

KeyError: 'class_name'

When I train the model, a KeyError: 'class_name' occurs:
Traceback (most recent call last):
File "/home/cc/anaconda3/envs/HowToTrainYourMAMLPytorch-master/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 2889, in get_loc
return self._engine.get_loc(casted_key)
File "pandas/_libs/index.pyx", line 70, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/index.pyx", line 97, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/hashtable_class_helper.pxi", line 1675, in pandas._libs.hashtable.PyObjectHashTable.get_item
File "pandas/_libs/hashtable_class_helper.pxi", line 1683, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'class_name'
Please tell me how to resolve this. Thank you very much.
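One likely cause, judging from the traceback rather than from anything confirmed here: the dataset class builds a pandas DataFrame from the images it finds, and if nothing is found under DATA_PATH the DataFrame has no columns at all, so `df['class_name']` raises `KeyError: 'class_name'` (the "Indexing background... 0it" output in a later report of the same error fits this). A hedged pre-flight check; the function names and subset folder names are hypothetical:

```python
import os

def count_images(root, extensions=('.png', '.jpg', '.jpeg')):
    """Count image files anywhere under root (0 if root does not exist)."""
    total = 0
    for _, _, filenames in os.walk(root):   # os.walk yields nothing for a missing dir
        total += sum(f.lower().endswith(extensions) for f in filenames)
    return total

def check_subset_dirs(data_path, subsets):
    """Raise a readable error instead of a pandas KeyError later on."""
    for subset in subsets:
        path = os.path.join(data_path, subset)
        n = count_images(path)
        if n == 0:
            raise FileNotFoundError(
                f"No images found under {path!r}; check DATA_PATH and that "
                f"the dataset preparation script has been run."
            )
        print(f"{path}: {n} images")
```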

The output size of MatchingNetwork.encoder

Hi, thank you for a well-packed blog post and implementation on meta-learning!

I have a question about the implementation of the matching network.

I see that you don't resize the Omniglot images, so each example has a size of 105x105.
This results in a runtime error in the following line, because the input size of the LSTM is 64, but the size of the embedding (after self.encoder) is 2304. The size becomes 64 if the original image is resized to 28x28.

support, _, _ = model.g(support.unsqueeze(1))

This implementation also resizes Omniglot images to 28x28.

Do you not get this runtime error with the current version of the code? Should I resize the images in this part?

def __getitem__(self, item):
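Assuming the encoder is the usual four-block convolutional stack (3x3 convolution with padding 1, then 2x2 max pooling, 64 filters per block), which is consistent with the sizes quoted above, the flattened embedding size follows from halving the image side four times. A small sketch of that arithmetic (`conv4_embedding_size` is a hypothetical helper):

```python
def conv4_embedding_size(image_size, filters=64, blocks=4):
    """Flattened output size of a stack of conv blocks that each keep the
    spatial size (3x3 conv, padding 1) and then halve it with 2x2 max pooling.
    """
    side = image_size
    for _ in range(blocks):
        side //= 2          # floor division, as in non-overlapping max pooling
    return filters * side * side
```

For 105x105 inputs the side goes 105 → 52 → 26 → 13 → 6, giving 64 × 6 × 6 = 2304; for 28x28 inputs it goes 28 → 14 → 7 → 3 → 1, giving 64 × 1 × 1 = 64, which matches the LSTM input size.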

two questions

Dear Oscar Knagg,
Firstly, thank you for your helpful code. However, when I downloaded your scripts and ran them again on my computer, an error was raised at

few_shot/core.py", line 80, in __iter__
query = df[(df['class_id'] == k) & (~df['id'].isin(support_k[k]['id']))].sample(self.q)
ValueError: a must be greater than 0
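This error is most likely raised when `.sample(self.q)` is called on an empty selection: once the support images are excluded, the sampled class has no images left to draw query samples from (or the DataFrame is empty altogether). That interpretation is an assumption based on the traceback. A hypothetical helper that lists classes with too few samples for the requested episode shape:

```python
from collections import Counter

def classes_too_small(class_ids, n, q):
    """Return the class ids with fewer than n + q samples.

    An n-shot episode with q queries per class needs n + q distinct samples
    from each sampled class; otherwise the query pool can end up empty.
    """
    counts = Counter(class_ids)
    return sorted(c for c, count in counts.items() if count < n + q)
```

Running it on the class_id column before training would show whether the chosen n and q are feasible for every class.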

Besides, in the MAML code I see:

order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the query set)
or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated weights on the query with respect to the original weights).

but the corresponding code is:
```python
if order == 1:
    if train:
        sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                              for k in task_gradients[0].keys()}
        hooks = []
        for name, param in model.named_parameters():
            hooks.append(
                param.register_hook(replace_grad(sum_task_gradients, name))
            )

        model.train()
        optimiser.zero_grad()
        # Dummy pass in order to create `loss` variable
        # Replace dummy gradients with mean task gradients using hooks
        logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
        loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
        loss.backward()
        optimiser.step()

        for h in hooks:
            h.remove()

    return torch.stack(task_losses).mean(), torch.cat(task_predictions)

elif order == 2:
    model.train()
    optimiser.zero_grad()
    meta_batch_loss = torch.stack(task_losses).mean()

    if train:
        meta_batch_loss.backward()
        optimiser.step()
```
Isn't the "order == 2" part doing the "update meta-learner weights with gradients of the updated weights on the query set" job?

Also, I can't understand your code in the "order == 1" part. Could you explain it in more detail?

Any response would be appreciated! Thanks for your time!
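For one task and one inner step, the fast weights are θ' = θ − α∇L_support(θ). Second-order MAML differentiates the query loss through this update, giving (I − α∇²L_support(θ))∇L_query(θ'), while first-order MAML drops the Hessian term and uses ∇L_query(θ') directly; that is why the order == 1 branch only needs gradients at the fast weights, which it then averages and injects through hooks. A scalar sketch with hand-written derivatives, checked against a finite difference (the toy losses are illustrative assumptions, not the repository's model):

```python
import math

# Toy scalar losses (illustrative assumptions, not the repository's model)
def support_loss(w):  return math.cos(w) + 0.5 * w * w   # for reference
def support_grad(w):  return -math.sin(w) + w
def support_hess(w):  return -math.cos(w) + 1.0
def query_loss(w):    return (w - 2.0) ** 2
def query_grad(w):    return 2.0 * (w - 2.0)

def fast_weight(w, alpha):
    return w - alpha * support_grad(w)            # one inner SGD step

def meta_grad(w, alpha, order):
    g = query_grad(fast_weight(w, alpha))         # gradient at the fast weight
    if order == 1:
        return g                                  # first order: ignore d(w_fast)/dw
    return (1.0 - alpha * support_hess(w)) * g    # second order: full chain rule

def finite_diff(w, alpha, eps=1e-6):
    # Numerical derivative of the query loss with respect to the ORIGINAL weight
    def f(x):
        return query_loss(fast_weight(x, alpha))
    return (f(w + eps) - f(w - eps)) / (2 * eps)
```

Only the second-order value agrees with the finite difference; the first-order value is an approximation that avoids backpropagating through the inner update.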

setting of protonet

Why does the ProtoNet setting for 5-way 1-shot use k-train = 20 and 1 shot?

Protonets results on miniImageNet

Hi,

Can you please share the hyperparameters used to reproduce the results on miniImageNet? I tried using the parameters from experiments.txt but couldn't reproduce these numbers.

Thanks.

Possible solution to LSTM concatenation problem

Hi Oscar,
firstly, thank you for sharing your code with all of us. I noticed you encountered the same trouble I had while trying to implement the f Full Context Embedding following the description in the original paper.

# h_hat, c = self.lstm_cell(queries, (torch.cat([h, readout], dim=1), c))

I built an LSTM model from scratch that should allow you to fix this problem.
You can find the model here: https://github.com/rtorrisi/LSTM-Custom-InOut-Hidden-Size
Let me know your thoughts.

Kind regards,
Riccardo
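The commented-out line fails because concatenating h and the readout produces a state of size 2d, while an LSTM cell with hidden size d expects d. A NumPy sketch of one attention step and the resulting shapes, with two possible workarounds: adding h and the readout, or projecting the concatenation back to size d with a learned matrix (random here, purely for illustration). Neither is claimed to be what the paper or this repository does:

```python
import numpy as np

def attention_readout(h, support):
    """One FCE-style attention step: softmax(h . g(x_i)) weighted sum of support."""
    logits = h @ support.T                               # (n_query, n_support)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    attn = np.exp(logits)
    attn /= attn.sum(axis=1, keepdims=True)
    return attn @ support                                # (n_query, d)

rng = np.random.default_rng(0)
d, n_query, n_support = 64, 3, 5
h = rng.normal(size=(n_query, d))
support = rng.normal(size=(n_support, d))
readout = attention_readout(h, support)

concat = np.concatenate([h, readout], axis=1)           # (n_query, 2d): too big for hidden size d
summed = h + readout                                    # (n_query, d): fits
W_proj = rng.normal(size=(2 * d, d)) / np.sqrt(2 * d)   # hypothetical learned projection
projected = concat @ W_proj                             # (n_query, d): fits
```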

Wrong number of classes when rotating

There is a bug related to class names. Using your code in scripts/prepare_omniglot.py, I got class names like 'Alphabet_of_the_Magi.0.character03'. But I think 'Alphabet_of_the_Magi.0.character03' and 'Alphabet_of_the_Magi.90.character03' belong to the same class, because they are augmented copies of the same character with the same label.
So I suggest changing the following line
https://github.com/oscarknagg/few-shot/blob/master/few_shot/datasets.py#L80
to
alphabet = root.split('/')[-2].split('.')[0]
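A small sketch of both naming choices, assuming directory paths of the form .../<alphabet>.<rotation>/<character> as implied by the class names above (`omniglot_class_name` is a hypothetical helper). Whether rotations should be merged depends on the protocol: the few-shot Omniglot setup of Vinyals et al., also followed by the Prototypical Networks paper, treats each 90-degree rotation as a new class (1,200 characters × 4 rotations = 4,800 training classes).

```python
def omniglot_class_name(root, merge_rotations=False):
    """Build a class name from a '/'-separated directory path like
    '.../Alphabet_of_the_Magi.90/character03'.
    """
    alphabet_dir, character = root.split('/')[-2:]
    if merge_rotations:
        alphabet_dir = alphabet_dir.split('.')[0]   # drop the '.90' rotation suffix
    return f'{alphabet_dir}.{character}'
```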

EvaluateFewShot in MAML

In the code, the model is passed to the evaluation phase using the set_model() function:

    def set_model(self, model):
        for callback in self.callbacks:
            callback.set_model(model)

Does it work correctly if we only copy the weights/state_dict of the model and optimizer (for example, when we save the model weights and then fine-tune on new tasks)? Is any information lost when we save only the model state_dict?
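A pure-Python sketch of the reference semantics involved (the classes are hypothetical stand-ins, not the repository's callbacks): set_model stores a reference to the model object, so in-place updates made during training are visible to every callback, whereas a copied snapshot of the weights does not follow later updates.

```python
import copy

class Model:
    def __init__(self):
        self.weights = {'w': 0.0}

class Callback:
    def set_model(self, model):
        self.model = model            # stores a reference, not a copy

model = Model()
callbacks = [Callback(), Callback()]
for cb in callbacks:
    cb.set_model(model)

snapshot = copy.deepcopy(model.weights)   # e.g. a saved copy of the weights

model.weights['w'] += 1.0                 # in-place update, like optimiser.step()
```

After the update every callback sees w == 1.0, while the snapshot still holds 0.0.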

few_shot.utils is not found

!python /content/few-shot/scripts/prepare_omniglot.py

Traceback (most recent call last):
File "/content/few-shot/scripts/prepare_omniglot.py", line 20, in <module>
from few_shot.utils import mkdir, rmdir
ModuleNotFoundError: No module named 'few_shot.utils'

I am getting this error. Can anyone help me with this?

Query on number of inner loop iterations in MAML

Hi, thanks for the great repo!

In 2nd-order MAML training, what value do you use for the number of inner-loop iterations on the miniImageNet dataset? The paper uses 1 and 5 iterations for Omniglot and miniImageNet respectively in the meta-training stage. Here, in the default arguments (https://github.com/oscarknagg/few-shot/blob/master/experiments/maml.py), inner-train-steps is set to 1. I was curious to know whether 1 iteration during training would still result in good performance on miniImageNet.
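The number of inner-loop iterations is simply how many gradient steps are taken on the support set to produce the fast weights before the query loss is evaluated. A toy sketch (quadratic support loss; `inner_lr` and the loss are assumptions) showing how extra steps move the fast weights closer to the support-set optimum:

```python
import numpy as np

def adapt(w, support_target, inner_lr=0.4, inner_steps=1):
    """Run `inner_steps` SGD steps on the toy support loss 0.5 * ||w - target||^2."""
    fast_w = w.copy()
    for _ in range(inner_steps):
        grad = fast_w - support_target      # gradient of the toy loss
        fast_w = fast_w - inner_lr * grad
    return fast_w
```

Each step here shrinks the distance to the target by a factor of 1 − inner_lr = 0.6, so five steps leave about 8% of the initial gap versus 60% after one step; on a real network the trade-off also involves compute and memory for backpropagating through more steps.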

BUG: model CANNOT be shared in a derived class

Hi,

Thank you for sharing the repo. In the callback classes derived from the base class, I noticed that self.model is used in 'ModelCheckpoint'.

But the ONLY place where self.model is set is set_model in the base class. (Note that fitting functions such as proto_net_episode use a different variable, which is defined here.)

This bug causes the model saved here not to be a trained model. (There's an error here, too, but it's not called in ProtoNet.)
My suggested solution is not to make the model a member variable of the base class (i.e. self.model).

Anyway, integrating operations through derived Callback classes is still a commendable design. Thank you.

Queries about the MAML implementation

At line 84 of few_shot/maml.py:
loss.backward(retain_graph=True)
I don't understand the purpose of this statement. I wonder whether it aims to calculate the gradient of task_val_loss with respect to the initial parameters, but the result is not used in the subsequent code. Could you explain what this statement does? Thanks a lot!

Error in importing config.py

I cloned your repository and then ran pip install -r requirements.txt. I downloaded miniImageNet, put it in the directory ./data/miniImageNet/images and changed DATA_PATH in config.py to "./data/".
But when I run the command python3 scripts/prepare_mini_imagenet.py from the root directory, I get the error:

Traceback (most recent call last):
  File "scripts/prepare_mini_imagenet.py", line 16, in <module>
    from config import DATA_PATH
ModuleNotFoundError: No module named 'config'

@oscarknagg could you tell me where the error is?
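When a file is run as `python3 scripts/prepare_mini_imagenet.py`, Python puts the script's own directory (scripts/) at the front of sys.path, not the current working directory, so top-level modules in the repository root such as config.py and the few_shot package cannot be imported; the same applies to the few_shot.utils error above. Common fixes are to run the script as a module from the root (`python3 -m scripts.prepare_mini_imagenet`), to set `PYTHONPATH=.`, or to add the root to sys.path at the top of the script. A hedged sketch of the last option, assuming the script lives one directory below the repository root (the helper names are hypothetical):

```python
import sys
from pathlib import Path

def repo_root_for(script_path):
    """Repository root, assuming the script sits in <root>/scripts/."""
    return Path(script_path).resolve().parent.parent

def add_repo_root_to_path(script_path):
    root = str(repo_root_for(script_path))
    if root not in sys.path:
        sys.path.insert(0, root)     # must run before `from config import ...`
    return root

# At the top of scripts/prepare_mini_imagenet.py one would call:
#     add_repo_root_to_path(__file__)
#     from config import DATA_PATH
```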

loss.backward(retain_graph=True) at line 131 in maml.py: is it required?

It seems to me that loss.backward(retain_graph=True) computes the gradients, but the code computes them again a few lines below using gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph), so it looks like some work is being duplicated here.

Again, I am not sure if I am correct; any response would be appreciated!

Thanks for your time

I am facing the error below, even though I have installed all the packages:

File "matching_nets.py", line 76, in <module>
background = dataset_class('background')
File "C:\Users\sjm6kor\Desktop\one_shot\few-shot-master\experiments\datasets.py", line 113, in __init__
self.unique_characters = sorted(self.df['class_name'].unique())
File "C:\Users\sjm6kor\AppData\Roaming\Python\Python36\site-packages\pandas\core\frame.py", line 2995, in __getitem__
indexer = self.columns.get_loc(key)
File "C:\Users\sjm6kor\AppData\Roaming\Python\Python36\site-packages\pandas\core\indexes\base.py", line 2899, in get_loc
return self._engine.get_loc(self._maybe_cast_indexer(key))
File "pandas/_libs/index.pyx", line 107, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/index.pyx", line 131, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/hashtable_class_helper.pxi", line 1607, in pandas._libs.hashtable.PyObjectHashTable.get_item
File "pandas/_libs/hashtable_class_helper.pxi", line 1614, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'class_name'

Queries on MAML Implementation

Thanks for the wonderful repository and the explanation.

Could you please help me understand the following queries?

  1. Let's consider a dataset of 10 classes on which we train our MAML model.
  2. This produces a model, theta_common.model, which we save.
  3. Consider an unseen class (class_11): we train/test theta_common.model on it to get fine_tuned.model, which we save.
  4. fine_tuned.model should then be able to predict a total of 11 classes.
  5. Is my understanding of MAML correct?

Thank you,
KK
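In episodic few-shot training, labels inside each episode are usually re-indexed to 0..k-1 in the order the classes were sampled, so the classifier outputs k_way scores for the classes in the current support set rather than scores over every class seen so far. A pure-Python sketch of that relabeling (the helper is hypothetical; the repository's create_nshot_task_label, seen in the MAML code above, appears to build such episode-local labels for the query set):

```python
def relabel_episode(global_labels):
    """Map the global class ids of one episode to episode-local labels 0..k-1,
    in order of first appearance. Returns (local_labels, mapping).
    """
    local = {}
    for g in global_labels:
        local.setdefault(g, len(local))   # next unused local label
    return [local[g] for g in global_labels], local
```

Under this scheme an unseen class such as class_11 would simply occupy one of the k local slots in the episode used for fine-tuning.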

Saving memory to avoid CUDA OOM with GTX-1080Ti

Hi,

Thank you very much for sharing this repository; it makes it easy to try things out quickly.
But so far I'm struggling to avoid the OOM error below.

Is there any way to reduce memory use?

  • Using the Omniglot dataset.
  • Tried DataLoader num_workers=1, but the error still occurs.
$ python proto_nets.py --dataset omniglot
omniglot_nt=1_kt=60_qt=5_nv=1_kv=5_qv=1
Indexing background...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19280/19280 [00:00<00:00, 281958.22it/s]
Indexing evaluation...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13180/13180 [00:00<00:00, 302261.60it/s]
Training Prototypical network on omniglot...
Begin training...
Epoch 1:   2%|██▎          | 2/100 [00:06<06:29,  3.98s/it, loss=57.9, categorical_accuracy=0.35]
Traceback (most recent call last):
  File "proto_nets.py", line 129, in <module>
    'distance': args.distance},
  File "/home/me/lab/few-shot/ew-shot/few_shot/train.py", line 113, in fit
    loss, y_pred = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs)
  File "/home/me/lab/few-shot/few_shot/proto.py", line 67, in proto_net_episode
    loss.backward()
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/me/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 316.50 MiB (GPU 0; 10.91 GiB total capacity; 9.70 GiB already allocated; 95.38 MiB free; 245.36 MiB cached)
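A rough, hedged estimate of where the memory goes: each training episode contains k·(n + q) images, and one 64-channel feature map at input resolution already costs batch × 64 × H × W × bytes per element. With the k-train = 60, n-train = 1, q-train = 5 shown in the run name above, that is 360 images per episode, so lowering k-train or q-train shrinks every activation proportionally. The helpers below count a single activation tensor only and ignore weights, gradients and allocator overhead:

```python
def episode_size(k, n, q):
    """Images per episode: n support + q query samples for each of k classes."""
    return k * (n + q)

def activation_mib(batch, channels, height, width, bytes_per_elem=4):
    """Memory of one tensor of shape (batch, channels, height, width) in MiB."""
    return batch * channels * height * width * bytes_per_elem / 2**20
```

For 28x28 float32 inputs that one tensor is about 69 MiB per episode; for unresized 105x105 images it is about 969 MiB.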

Implementation of 2nd order gradient in maml

In line 126 of few_shot/maml.py:

meta_batch_loss.backward()

However, the graph for meta_batch_loss does not use the model parameters directly: each loss used to compute meta_batch_loss was obtained using fast_weights. So when calling backward on meta_batch_loss, gradients would be computed with respect to fast_weights, not the model parameters.

This could be corrected by separately calling autograd.grad over the model parameters for each loss, summing the results, and then using hooks to update the model parameters.

What do you think?

KeyError: 'class_name' in python3.7/.../base.py file

Hi Oscar,

I ran into this issue while running your code; the traceback points to an error inside the pandas source files.
I suspect that you are on a different (<3.7) Python version. Can you share your version?

(venv) kgarg8@edsger:~/kgarg8-workspace/few-shot$ python -m experiments.proto_nets --dataset omniglot --k-test 5 --n-test 1
omniglot_nt=1_kt=60_qt=5_nv=1_kv=5_qv=1
Indexing background...
0it [00:00, ?it/s]
Traceback (most recent call last):
File "/home/kgarg8/kgarg8-workspace/few-shot/venv/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 3078, in get_loc
return self._engine.get_loc(key)
File "pandas/_libs/index.pyx", line 140, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/index.pyx", line 162, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/hashtable_class_helper.pxi", line 1492, in pandas._libs.hashtable.PyObjectHashTable.get_item
File "pandas/_libs/hashtable_class_helper.pxi", line 1500, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'class_name'

About ProtoNet's accuracy on miniImageNet

Thank you for sharing the nice repo.

But ProtoNet's accuracy on miniImageNet was about 2% lower than the originally published results in every setup.
My question is: are there any techniques, tricks or tuning methods for improving it, while keeping the method (ProtoNet) as described in the paper?

(Notice that you said in your blog: as the authors provided the full set of hyperparameters, hence .... That's true)
Thank you very much.

Best practice for evaluation code

Hi, thank you for your great implementation!

I found that your Keras-like fit function is elegant, but as the evaluation code is implemented as a callback function that runs after each epoch, it is unclear how to load the trained model and check its performance.

Can you share your evaluation code?
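One common way to evaluate a trained few-shot model outside the callback machinery is to load the saved weights, run a fixed number of evaluation episodes and report the mean accuracy with a 95% confidence interval (1.96 × standard deviation / √episodes). A framework-agnostic sketch of the reporting step; `evaluate_episode` is a hypothetical stand-in for running the model on one sampled episode:

```python
import math
import random

def summarise_accuracies(accs):
    """Mean accuracy and 95% confidence half-width over evaluation episodes."""
    n = len(accs)
    mean = sum(accs) / n
    var = sum((a - mean) ** 2 for a in accs) / (n - 1)   # sample variance
    return mean, 1.96 * math.sqrt(var / n)

def evaluate(evaluate_episode, num_episodes=1000, seed=0):
    """Call a user-supplied `evaluate_episode(rng) -> accuracy` repeatedly."""
    rng = random.Random(seed)   # fixed seed so the episode draw is reproducible
    return summarise_accuracies([evaluate_episode(rng) for _ in range(num_episodes)])
```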

Result question

Hello, I followed your steps and the final result was very bad. I ran the commands in experiments/experiments.txt and ran experiments/proto_nets.py.

mini-imagenet access

The link to the mini-imagenet google drive requires access. How can I get this? If I can't, and download the data from, say, Kaggle, how should I arrange the data so that the scripts will work?

Failed to run matching_nets.py

Hi, I am new to Python.
Can anyone help me make this code sample work?

I want to try out "matching_nets.py" for my own use case.
My notebook is a Lenovo Carbon X1 without a GPU.
My environment is the Anaconda Command Prompt.

I got the following error:
File "matching_nets.py", line 18, in <module>
assert torch.cuda.is_available()
AssertionError

The following is the package list:

Package Version
atomicwrites 1.2.1
attrs 18.2.0
backcall 0.1.0
bleach 3.0.2
certifi 2020.4.5.1
cffi 1.14.0
cloudpickle 1.3.0
colorama 0.4.3
cycler 0.10.0
cytoolz 0.10.1
dask 2.14.0
decorator 4.4.2
defusedxml 0.5.0
entrypoints 0.2.3
graphviz 0.10.1
imageio 2.8.0
ipykernel 5.1.0
ipython 7.1.1
ipython-genutils 0.2.0
ipywidgets 7.4.2
jedi 0.13.1
Jinja2 2.10
jsonschema 2.6.0
jupyter 1.0.0
jupyter-client 5.2.3
jupyter-console 6.0.0
jupyter-core 4.4.0
kiwisolver 1.2.0
MarkupSafe 1.1.1
matplotlib 3.2.1
mistune 0.8.4
mkl-fft 1.1.0
mkl-service 2.3.0
more-itertools 4.3.0
nbconvert 5.4.0
nbformat 4.4.0
networkx 2.4
notebook 5.7.0
numpy 1.18.2
olefile 0.46
pandas 0.23.4
pandocfilters 1.4.2
parso 0.3.1
pexpect 4.6.0
pickleshare 0.7.5
Pillow 7.1.1
pip 20.0.2
pluggy 0.8.0
prometheus-client 0.4.2
prompt-toolkit 2.0.7
ptyprocess 0.6.0
py 1.7.0
pycparser 2.20
Pygments 2.2.0
pyparsing 2.4.6
pytest 3.9.3
python-dateutil 2.8.1
pytz 2018.7
PyWavelets 1.1.1
pywinpty 0.5.7
PyYAML 5.3.1
pyzmq 17.1.2
qtconsole 4.4.2
scikit-image 0.16.2
scipy 1.4.1
Send2Trash 1.5.0
setuptools 46.1.3.post20200325
six 1.14.0
terminado 0.8.1
testpath 0.4.2
toolz 0.10.0
torch 1.4.0
torch-nightly 1.2.0.dev20190723
torchvision 0.5.0
tornado 6.0.4
tqdm 4.28.1
traitlets 4.3.2
wcwidth 0.1.7
webencodings 0.5.1
wheel 0.34.2
widgetsnbextension 3.4.2
wincertstore 0.2
