orobix / prototypical-networks-for-few-shot-learning-pytorch

Implementation of Prototypical Networks for Few Shot Learning (https://arxiv.org/abs/1703.05175) in Pytorch

License: MIT License

Python 100.00%
cnn prototypical-networks python pytorch

prototypical-networks-for-few-shot-learning-pytorch's People

Contributors

belerico, dnlcrl, rcy17


prototypical-networks-for-few-shot-learning-pytorch's Issues

How do I make a prediction?

Thank you for your work!

I've trained a model for a few epochs, and now I'd like to make predictions with it. I load it:

model = ProtoNet().cuda()
model.load_state_dict(torch.load('./output/best_model.pth'))

I load 15 labeled data points, for a total of 3 labels:

# x.size() -> (15, 64, 64)
# y.size() -> (15)
x, y = load_data()

I also load a single data point I want to predict:

to_predict = torch.Tensor(1, 64, 64)

I now would like to few-shot train on 5 examples per class and then predict a class for my to_predict. How do I go about that?
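For what it's worth, here is a minimal sketch of one way to do this, assuming the trained ProtoNet encoder loaded above and the x, y, to_predict tensors from this post. With prototypical networks the "few-shot training" step is just building one prototype (mean embedding) per class from the support examples; there is no gradient update at prediction time. This is only an illustration, not the repo's official inference path:

import torch

model.eval()
with torch.no_grad():
    support_emb = model(x.unsqueeze(1).cuda()).cpu()          # (15, emb_dim)
    query_emb = model(to_predict.unsqueeze(0).cuda()).cpu()   # (1, emb_dim)

classes = torch.unique(y)                                     # the 3 labels
prototypes = torch.stack([support_emb[y == c].mean(0) for c in classes])

dists = torch.cdist(query_emb, prototypes)                    # (1, n_classes)
predicted_label = classes[dists.argmin(dim=1)]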

Questions about dataset

Hi
I have a question about image loading in omniglot_dataset.py

Why use 1 - array, and why use transpose and reshape, as below?

https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/master/src/omniglot_dataset.py#L179-L180

Thanks

Sampling without replacement

In the original paper in "Algorithm 1", they mention that each batch is sampled "without replacement":

... RANDOMSAMPLE(S, N) denotes a set of N elements chosen uniformly at
random from set S, without replacement.

whereas your sampler class clearly samples with replacement, as you even pass the number of iterations as an argument to the class constructor. May I ask why?
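For reference, a minimal sketch of episode sampling without replacement, assuming a 1-D label tensor labels and the usual n_classes / n_support / n_query episode parameters. This is just an illustration of the "Algorithm 1" behaviour, not the repo's PrototypicalBatchSampler:

import torch

def sample_episode(labels, n_classes, n_support, n_query):
    """One episode without replacement: n_classes distinct classes,
    and n_support + n_query distinct examples per chosen class."""
    unique = torch.unique(labels)
    chosen = unique[torch.randperm(len(unique))[:n_classes]]    # classes, no repeats
    idxs = []
    for c in chosen:
        class_idxs = (labels == c).nonzero(as_tuple=True)[0]
        perm = torch.randperm(len(class_idxs))[:n_support + n_query]
        idxs.append(class_idxs[perm])                           # examples, no repeats
    return torch.cat(idxs)                                      # dataset indices for this episode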

Error in dataset

paths, self.y = zip(*[self.get_path_label(pl) for pl in range(len(self))])

What is len(self)?
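len(self) simply calls the dataset's __len__, i.e. the total number of samples, so that list comprehension builds a (path, label) pair for every item. A hedged sketch of how such a method is typically written (whether the repo's __len__ uses exactly self.all_items is an assumption on my part):

import torch.utils.data

class SomeDataset(torch.utils.data.Dataset):
    def __len__(self):
        # len(self) == number of samples the dataset exposes
        return len(self.all_items)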

dimension error

Running python train.py --cuda -nsTr 1 -nsVa 1 gives me a runtime error in euclidean_dist ("dimension out of range (expected to be in range of [-1, 0], but got 1)")

It seems the code is trying to compute a distance between query_samples (a matrix of size 360x64) and prototypes (a vector of length 60), so I can see why it's producing this error. But why are query_samples and prototypes these sizes to begin with? Shouldn't they have the same value in size(1)?

Prediction new problem

Hi guys! I am at this point and I just do not understand it; I already have the trained model. For testing, what I understand should be done is to compute the embedding of my sample and see which centroid is closest in order to classify it... If I have trained with 5000 classes, I do not understand why in the test phase it is necessary to pass a support set and a query set.

Taking the implementation of @ale316, predict(support_x, support_y, query_x, query_y=None), we would pass support sets from the training set and queries whose y we do not know (hence query_y=None) and... how does it take into account the 5000 classes if the support only includes, say, 30?

My idea is to compute all the training embeddings and then generate the centroids; after that, generate the embeddings for the test samples and assign the class of the centroid at the smallest Euclidean distance... Am I correct? But you don't do that.

Sorry for the spam here, but for posterity:

I wrote a function that should return predictions given:

  • a tensor support_x of size (n_support, 1024)
  • a tensor support_y of size (n_support,)
  • a tensor query_x of size (n_query, 1024)
import torch
import torch.nn.functional as F
from prototypical_loss import euclidean_dist  # helper defined in src/prototypical_loss.py

def predict(support_x, support_y, query_x, query_y=None):
    support_x = support_x.to('cpu')
    support_y = support_y.to('cpu')
    query_x = query_x.to('cpu')

    classes = torch.unique(support_y)
    n_classes = len(classes)
    n_query = len(query_x)

    # get a list of tensors of support_y for each class
    support_idxs = list(map(lambda c: support_y.eq(c).nonzero().squeeze(1), classes))

    # take the mean of tensors for each class to create a centroid
    prototypes = torch.stack([support_x[idx_list].mean(0) for idx_list in support_idxs])

    # finds the euclidean distances between each query_x and each centroid
    dists = euclidean_dist(query_x, prototypes)

    # run it through softmax
    log_p_y = F.log_softmax(-dists, dim=1)

    # lists the idx (label) of the closest centroid for each query_x
    _, y_hat = log_p_y.max(1)
    labels = [classes[i] for i in y_hat.squeeze()]

    return labels
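A hedged usage example with made-up data, just to show the expected shapes from the bullet list above (3 classes with 5 shots each, 1024-dim features):

support_x = torch.randn(15, 1024)                        # 3 classes x 5 shots
support_y = torch.tensor([0, 1, 2]).repeat_interleave(5)
query_x = torch.randn(4, 1024)

print(predict(support_x, support_y, query_x))            # e.g. [tensor(1), tensor(0), tensor(2), tensor(1)]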

Now, I still have some doubts:

  • Is this correct? Edit: yes it is
  • Why do we even softmax the distances, instead of just taking the min?

Originally posted by @MarioProjects in #8 (comment)

About Miniimagenet

Do you mind sharing your code for the miniImageNet dataloader and training? It would be appreciated.

about the result

I ran your code with the default settings, but got very poor accuracy (about 53%). What could be the reason for this?

error at accuracy

Dear author

Thank you for your carefully written code.
I reused some of your code, and I found an error.

Please check line 84 in prototypical_loss.py.
I think y_hat should be squeezed with squeeze().

y_hat and target_inds.squeeze() look like:

y_hat = torch.tensor([[0],[1],[2],[0],[4]])
target_inds.squeeze() = torch.tensor([0, 1, 2, 3, 4])

In this case,

y_hat.eq(target_inds.squeeze()).float()

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.]])

In this case, accuracy is 0.2

It should be tensor([1., 1., 1., 0., 1.]).
In this case, accuracy is 0.8
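A small runnable snippet of the broadcasting issue described above, using the tensors from this post (not the repo's code):

import torch

y_hat = torch.tensor([[0], [1], [2], [0], [4]])
target_inds = torch.tensor([[0], [1], [2], [3], [4]])

# (5, 1) vs (5,) broadcasts to a (5, 5) matrix, so the mean is 0.2
print(y_hat.eq(target_inds.squeeze()).float().mean())

# squeezing y_hat first gives the intended element-wise comparison, mean 0.8
print(y_hat.squeeze().eq(target_inds.squeeze()).float().mean())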

error of omniglot_dataset.py

When I run the program I get the following error. How can I solve it?

opt = get_parser().parse_args()
mode = "train"
OmniglotDataset(mode=mode, root=opt.dataset_root)

== Downloading https://raw.githubusercontent.com/jakesnell/prototypical-networks/master/data/omniglot/splits/vinyals/test.txt
Traceback (most recent call last):
  File "", line 1, in <module>
    dataset = OmniglotDataset(mode=mode, root=opt.dataset_root)
  File "C:\Users\lenovo\Desktop\Prototypical-Networks-for-Few-shot-Learning-PyTorch-master\src\omniglot_dataset.py", line 50, in __init__
    self.download()
  File "C:\Users\lenovo\Desktop\Prototypical-Networks-for-Few-shot-Learning-PyTorch-master\src\omniglot_dataset.py", line 113, in download
    with open(file_path, 'wb') as f:
OSError: [Errno 22] Invalid argument: '..\dataset\splits\vinyals\https://raw.githubusercontent.com/jakesnell/prototypical-networks/master/data/omniglot/splits/vinyals/test.txt'
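The path in the error suggests the whole download URL was joined onto the local splits directory, which is not a valid Windows file name. A hedged sketch of how one could build the local path from just the file name instead (raw_url and splits_dir are illustrative names, not the repo's variables):

import os
from urllib.parse import urlparse

raw_url = ('https://raw.githubusercontent.com/jakesnell/prototypical-networks/'
           'master/data/omniglot/splits/vinyals/test.txt')
splits_dir = os.path.join('..', 'dataset', 'splits', 'vinyals')

# keep only the file name from the URL when building the local path
file_path = os.path.join(splits_dir, os.path.basename(urlparse(raw_url).path))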

About the preprocessing code

Thanks for sharing the code, but I think the preprocessing code is too complex to understand; it could be written much more simply.

Loss Backpropagation

During the learning stage, the loss isn't backpropagated to the model and I am obtaining the same accuracy and loss even after training for a huge number of epochs.
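As a sanity check, here is a minimal, self-contained sketch of the standard PyTorch update sequence one would expect inside a training loop; the model, optimizer and loss below are dummy stand-ins, not a claim about what this repo does or does not call:

import torch
import torch.nn as nn

model = nn.Linear(10, 4)                       # dummy encoder just to make the snippet runnable
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
x, y = torch.randn(8, 10), torch.randint(0, 4, (8,))

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()                                # gradients flow back into model.parameters()
optimizer.step()                               # parameters are actually updated here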

How to train on a new dataset?

Thank you very much for sharing your code.

I want to train models with a dataset other than miniImageNet and Omniglot. Could you guide me on how to arrange the dataset (main folder, sub-folders) and how to train models on a new dataset?

Thanks a lot.
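For what it's worth, a common class-per-folder layout plus a minimal PyTorch Dataset sketch. This is only one way to organize a new dataset, not necessarily the structure this repo's loaders expect; the .y attribute mirrors the dataset.y that the repo's sampler initialization reads:

# my_dataset/
#   class_a/ img_001.png img_002.png ...
#   class_b/ img_001.png ...
import os
from PIL import Image
from torch.utils.data import Dataset

class FolderDataset(Dataset):
    def __init__(self, root, transform=None):
        self.transform = transform
        self.samples = []
        for label, cls in enumerate(sorted(os.listdir(root))):
            for fname in sorted(os.listdir(os.path.join(root, cls))):
                self.samples.append((os.path.join(root, cls, fname), label))
        self.y = [label for _, label in self.samples]    # per-sample labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert('L')              # greyscale, like Omniglot
        if self.transform is not None:
            img = self.transform(img)
        return img, label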

query in loading images

Hi.

Thanks for the repo. I have a couple of queries about the load_img function.

def load_img(path, idx):
    path, rot = path.split(os.sep + 'rot')
    if path in IMG_CACHE:
        x = IMG_CACHE[path]
    else:
        x = Image.open(path)
        IMG_CACHE[path] = x
    x = x.rotate(float(rot))
    x = x.resize((28, 28))

    shape = 1, x.size[0], x.size[1]
    x = np.array(x, np.float32, copy=False)
    x = 1.0 - torch.from_numpy(x)
    x = x.transpose(0, 1).contiguous().view(shape)

    return x

  1. Is it necessary to do x = 1.0 - torch.from_numpy(x) ? I understand this is to have 1s in the region of interest (the character) but does it really help?

  2. Why do you take the transpose (rotates the image again) at the end?

Thanks!

Testing Error

When testing, the network calculates loss using the number of supports specified for the training regime. This appears to be an error, but perhaps I've missed something?

_, acc = loss_fn(model_output, target=y, n_support=opt.num_support_tr)

To me, this line should be

_, acc = loss_fn(model_output, target=y, n_support=opt.num_support_val)

How can I get your test acc?

I use your 'python train.py --cuda' to evaluate your code.
But I can only get 65% test acc.
Can you tell me what happened?
Thank you very much

FileNotFoundError


FileNotFoundError                         Traceback (most recent call last)
in <module>
    252
    253 if __name__ == '__main__':
--> 254     main()

in main()
    208     init_seed(options)
    209
--> 210     tr_dataloader = init_dataloader(options, 'train')
    211     val_dataloader = init_dataloader(options, 'val')
    212     # trainval_dataloader = init_dataloader(options, 'trainval')

in init_dataloader(opt, mode)
     47
     48 def init_dataloader(opt, mode):
---> 49     dataset = init_dataset(opt, mode)
     50     sampler = init_sampler(opt, dataset.y, mode)
     51     dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)

in init_dataset(opt, mode)
     23
     24 def init_dataset(opt, mode):
---> 25     dataset = OmniglotDataset(mode=mode, root=opt.dataset_root)
     26     n_classes = len(np.unique(dataset.y))
     27     if n_classes < opt.classes_per_it_tr or n_classes < opt.classes_per_it_val:

E:\学习\jupyter\prototypical net\Prototypical-Networks-for-Few-shot-Learning-PyTorch-master\src\omniglot_dataset.py in __init__(self, mode, root, transform, target_transform, download)
     53             raise RuntimeError(
     54                 'Dataset not found. You can use download=True to download it')
---> 55         self.classes = get_current_classes(os.path.join(
     56             self.root, self.splits_folder, mode + '.txt'))
     57         self.all_items = find_items(os.path.join(

E:\学习\jupyter\prototypical net\Prototypical-Networks-for-Few-shot-Learning-PyTorch-master\src\omniglot_dataset.py in get_current_classes(fname)
    159
    160 def get_current_classes(fname):
--> 161     with open(fname) as f:
    162         classes = f.read().replace('/', os.sep).splitlines()
    163     return classes

FileNotFoundError: [Errno 2] No such file or directory: '..\dataset\splits\vinyals\train.txt'

Question about batching function

Hello,

I have been trying to implement prototypical networks on a different dataset (audio), and I have been having some difficulty with training; my loss seems to be decreasing very slowly.

I had a question about your batching function. Suppose there are 32 query points in a batch and 3-5 support points per class.

Does each query point get a label (1-32), and is this arbitrary? Or are data points given labels based on the whole training set?
Put another way, does a given query point (necessarily) get the same label in different mini-batches?
If you could give me an example, that would be super helpful.

Thanks,
Gautam
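For illustration, a hedged sketch of how episodic labels are typically remapped: the classes present in one episode get indices 0..n_classes-1, so a given class does not necessarily keep the same index across mini-batches. This mirrors the torch.unique-based grouping used in prototypical losses, but it is not the repo's exact code:

import torch

def episode_targets(labels):
    """Map the episode's global class labels to per-episode indices 0..n_classes-1."""
    classes = torch.unique(labels)                        # sorted classes present in this episode
    return (labels.unsqueeze(1) == classes).float().argmax(dim=1)

# the same global class can get a different per-episode index in two batches
print(episode_targets(torch.tensor([7, 7, 42, 42])))      # tensor([0, 0, 1, 1])  -> class 7 has index 0
print(episode_targets(torch.tensor([3, 3, 7, 7])))        # tensor([0, 0, 1, 1])  -> class 7 now has index 1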

Something wrong when I set num_query_val to 1

Hello, thank you for the code provided.
When I set num_query_val to 1, why do val_acc and test_acc not change? They are always about num_query_val / classes_per_it_val. Can you tell me what went wrong?
Thank you! Looking forward to your reply.
