yinboc commented on September 6, 2024

I have implemented one that gets 49.1% / 66.9% (1-shot / 5-shot), still slightly worse than the paper's results.

You could check the code if it helps.

from prototypical-networks.

DamonAtSjtu commented on September 6, 2024

How do I run the code with the miniImageNet dataset? I just replaced the line default_dataset = 'miniImagenet', but it doesn't work.

So do you know how to run it with miniImageNet now? I don't think the repository includes any code for the miniImageNet dataset.
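Since the repository does not seem to ship a miniImageNet loader, here is a minimal sketch of the first step: grouping images by class. It assumes the widely used Ravi & Larochelle split files (train.csv, val.csv, test.csv), each with a header row and filename,label columns; load_split is a hypothetical helper, not part of this repository, and the demo rows are made up.

```python
import csv
import io
from collections import defaultdict


def load_split(csv_file):
    """Group image filenames by class label from a CSV with a 'filename,label' header.

    Assumes the Ravi & Larochelle miniImageNet split format (an assumption,
    not something this repository provides).
    """
    class_to_files = defaultdict(list)
    for row in csv.DictReader(csv_file):
        class_to_files[row["label"]].append(row["filename"])
    return dict(class_to_files)


# Demo on an in-memory CSV in the assumed format (made-up rows).
demo_csv = io.StringIO(
    "filename,label\n"
    "n0153282900000005.jpg,n01532829\n"
    "n0153282900000006.jpg,n01532829\n"
    "n0155843700000001.jpg,n01558437\n"
)
demo_split = load_split(demo_csv)
```

For a real split you would pass a handle from open(path, newline=""). The images themselves are usually resized to 84×84 before they reach the four-block encoder.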

Fangwq commented on September 6, 2024

How do I run the code with the miniImageNet dataset? I just replaced the line default_dataset = 'miniImagenet', but it doesn't work.

PatrickZH commented on September 6, 2024

I have the same problem. The re-implementation results on the miniImageNet dataset are much lower than the reported results.

debasmitdas commented on September 6, 2024

Hi guys,
Does anybody know how many training epochs and episodes per epoch were used to reproduce the paper's results?

bilylee commented on September 6, 2024

A simple modification that reproduces the results is to scale the output of the Euclidean distance. That is:

import torch
import torch.nn as nn

feature_dims = 1600  # 1600 for miniImageNet, 64 for Omniglot
learnable_scale = nn.Parameter(torch.FloatTensor(1).fill_(1.0), requires_grad=True)
dist = learnable_scale * euclidean_dist(x, y) / feature_dims  # x: query embeddings, y: prototypes

In this way, I am able to get
1-shot: 50.87%
5-shot: 68.21%
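One hedged way to see why this can help: squared distances between 1600-dimensional embeddings can be large, so their raw gaps push the softmax over classes toward a one-hot distribution, while dividing by the dimension and multiplying by a learnable scale lets training choose the temperature. The sketch below uses made-up distances, not measured ones, and softmax_neg_dist is a hypothetical helper.

```python
import math


def softmax_neg_dist(sq_dists, scale=1.0, dim=1):
    """Softmax over classes of the logits -scale * d / dim (numerically stable)."""
    logits = [-scale * d / dim for d in sq_dists]
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up squared distances from one query to 5 class prototypes.
sq_dists = [800.0, 850.0, 900.0, 950.0, 1000.0]

raw = softmax_neg_dist(sq_dists)                         # unscaled: essentially one-hot
scaled = softmax_neg_dist(sq_dists, dim=1600)            # divided by feature dims: near uniform
sharper = softmax_neg_dist(sq_dists, scale=100.0, dim=1600)  # a larger learned scale re-sharpens it
```

With the raw distances nearly all probability mass sits on the closest prototype, which can saturate the cross-entropy gradients; after dividing by 1600 the distribution starts close to uniform, and the learnable scale can move it anywhere in between.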

d-li14 commented on September 6, 2024

@bilylee Hi, I tried this, but I still get little improvement in the 5-shot setting (67.1%).
Here is my code snippet:

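import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_block(in_channels, out_channels):
    # Stand-in helper (assumed): 3x3 conv -> BN -> ReLU -> 2x2 max-pool, halving H and W
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


def euclidean_metric(a, b):
    # Stand-in helper (assumed): negative squared Euclidean distance, shape (len(a), len(b))
    return -((a.unsqueeze(1) - b.unsqueeze(0)) ** 2).sum(dim=2)


def count_acc(logits, label):
    # Stand-in helper (assumed): fraction of queries whose argmax matches the label
    return (logits.argmax(dim=1) == label).float().mean().item()


class Convnet(nn.Module):

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )
        # Flattened embedding size: 64 channels x 5 x 5 for 84x84 miniImageNet inputs
        self.out_channels = 1600
        # Learnable scale (temperature) applied to the distance logits
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)

    def loss(self, data, num_way, num_support, num_query):
        # Support images come first, then queries; both are ordered shot-major
        # (one image per class for shot 0, then shot 1, and so on)
        p = num_support * num_way
        data_shot, data_query = data[:p], data[p:]

        # Prototypes: mean support embedding per class
        proto = self.forward(data_shot)
        proto = proto.reshape(num_support, num_way, -1).mean(dim=0)

        # Query labels repeat the class order once per query (device-agnostic)
        label = torch.arange(num_way, device=data.device).repeat(num_query)

        # Scaled negative squared distances serve as logits
        logits = self.scale * euclidean_metric(self.forward(data_query), proto) / self.out_channels
        loss = F.cross_entropy(logits, label)
        acc = count_acc(logits, label)

        return loss, acc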
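The loss above relies on a specific episode layout: support images first, then queries, each ordered shot-major so that position i belongs to class i % num_way. A minimal pure-Python sketch of an episode sampler producing that layout; sample_episode and the class_to_items mapping are hypothetical, not part of either repository, and the toy data are made up.

```python
import random


def sample_episode(class_to_items, num_way, num_support, num_query, rng=random):
    """Sample one N-way episode laid out shot-major (hypothetical helper).

    Returns (support, query, labels) where support[s * num_way + c] and
    query[q * num_way + c] both belong to the c-th sampled class.
    """
    classes = rng.sample(sorted(class_to_items), num_way)
    # Draw disjoint support and query items for each sampled class.
    per_class = [rng.sample(class_to_items[c], num_support + num_query) for c in classes]
    support = [per_class[c][s] for s in range(num_support) for c in range(num_way)]
    query = [per_class[c][num_support + q] for q in range(num_query) for c in range(num_way)]
    # Same ordering as torch.arange(num_way).repeat(num_query).
    labels = [c for _ in range(num_query) for c in range(num_way)]
    return support, query, labels


# Toy data: six classes named "a".."f", ten items each (made up).
toy = {name: [f"{name}{i}" for i in range(10)] for name in "abcdef"}
support, query, labels = sample_episode(toy, num_way=5, num_support=2, num_query=3,
                                        rng=random.Random(0))
```

With this ordering, reshape(num_support, num_way, -1).mean(dim=0) averages exactly the support embeddings of each class, and query label c lines up with prototype c.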
