Comments (7)
I have implemented one that gets 49.1/66.9, still slightly worse than the paper's results.
You can check the code if it helps.
from prototypical-networks.
How do I run the code with the miniImageNet dataset? I just replaced the line `default_dataset = 'miniImagenet'`, but it doesn't work.
So, do you know how to run it with miniImageNet now? I think there is no code related to the miniImageNet dataset.
I have the same problem. The re-implementation results are much lower than the reported results on the miniImageNet dataset.
Hi guys,
Does anybody know how many training epochs and episodes per epoch were used to reproduce the paper's results?
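A simple modification that reproduces the results is to scale the output of the Euclidean distance. That is:
import torch
import torch.nn as nn

feature_dims = 1600  # flattened embedding size: 1600 for miniImageNet, 64 for Omniglot
learnable_scale = nn.Parameter(torch.FloatTensor(1).fill_(1.0), requires_grad=True)  # register it on the model so the optimizer updates it
dist = learnable_scale * euclidean_dist(x, y) / feature_dims  # euclidean_dist: the repo's squared-distance helper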
In this way, I am able to get
1-shot: 50.87%
5-shot: 68.21%
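Why the scale matters can be seen with a small sketch (assumptions: plain Python instead of PyTorch, made-up squared distances, and hypothetical helper names `softmax`, `scaled_logits` and `conv4_flat_dim`). Dividing the squared distances by the embedding size (1600) shrinks the gaps between the logits, so the softmax over classes becomes close to uniform; the learnable scale then acts like an inverse softmax temperature that training can tune. The 1600 itself comes from the standard Conv-4 encoder, assuming each of the four blocks keeps the spatial size in its 3x3 convolution and then halves it with 2x2 max pooling: an 84x84 miniImageNet image ends up as a 64x5x5 map (1600 values) and a 28x28 Omniglot image as 64x1x1 (64 values).

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_logits(sq_dists, scale, feature_dims):
    """Logits = -scale * d / feature_dims (negated, since a smaller
    distance to a prototype should give a larger logit)."""
    return [-scale * d / feature_dims for d in sq_dists]


def conv4_flat_dim(image_size, channels=64, blocks=4):
    """Flattened Conv-4 embedding size, assuming each block halves the
    spatial side with floor division (2x2 max pooling)."""
    side = image_size
    for _ in range(blocks):
        side //= 2
    return channels * side * side


# Hypothetical squared distances from one query to 5 class prototypes.
sq_dists = [800.0, 1200.0, 1300.0, 1400.0, 1500.0]
feature_dims = conv4_flat_dim(84)  # 84x84 miniImageNet input

unscaled = softmax(scaled_logits(sq_dists, scale=1.0, feature_dims=feature_dims))
sharper = softmax(scaled_logits(sq_dists, scale=20.0, feature_dims=feature_dims))
print([round(p, 3) for p in unscaled])
print([round(p, 3) for p in sharper])
```

With `scale=1.0` the nearest prototype gets only about a quarter of the probability mass; a larger scale makes the same distances produce a much more confident prediction, which is what the learnable parameter lets the model adjust.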
@bilylee Hi, I tried this, but got only a small improvement in the 5-shot setting (67.1%).
Here is a code snippet:
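import torch
import torch.nn as nn
import torch.nn.functional as F

# conv_block, euclidean_metric and count_acc are helper functions from the
# codebase this snippet was taken from (not shown here).

class Convnet(nn.Module):
    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )
        self.out_channels = 1600  # flattened embedding size for 84x84 inputs
        # Learnable scale applied to the distance-based logits
        self.scale = nn.Parameter(torch.FloatTensor(1).fill_(1.0), requires_grad=True)

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)  # flatten to (batch, out_channels)

    def loss(self, data, num_way, num_support, num_query):
        p = num_support * num_way
        data_shot, data_query = data[:p], data[p:]
        # Support set is ordered shot-major, so averaging over shots gives one prototype per class
        proto = self.forward(data_shot)
        proto = proto.reshape(num_support, num_way, -1).mean(dim=0)
        # Query labels follow the same class order: 0..num_way-1 repeated num_query times
        # (created on the input's device instead of hard-coding torch.cuda.LongTensor)
        label = torch.arange(num_way, device=data.device).repeat(num_query)
        logits = self.scale * euclidean_metric(self.forward(data_query), proto) / self.out_channels
        loss = F.cross_entropy(logits, label)
        acc = count_acc(logits, label)
        return loss, acc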
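The `loss` method above relies on a particular episode layout: the first `num_support * num_way` rows are support examples in shot-major order (every class for shot 0, then every class for shot 1, and so on), which is why `reshape(num_support, num_way, -1).mean(dim=0)` yields one prototype per class, and the queries use the same class order, which is why the labels are `arange(num_way)` repeated `num_query` times. The plain-Python sketch below reproduces that layout and the nearest-prototype rule. It is an illustration under assumptions: the helper names `prototypes`, `neg_sq_dist` and `classify` are hypothetical, the 2-D embeddings are toy values rather than Conv-4 features, and `euclidean_metric` is assumed to return the negative squared Euclidean distance (so larger logits mean closer prototypes).

```python
def prototypes(support, num_way, num_support):
    """Average support embeddings per class.

    `support` is a flat list of vectors in shot-major order: index
    s * num_way + c holds shot s of class c.
    """
    dim = len(support[0])
    protos = []
    for c in range(num_way):
        members = [support[s * num_way + c] for s in range(num_support)]
        protos.append([sum(v[k] for v in members) / num_support for k in range(dim)])
    return protos


def neg_sq_dist(a, b):
    """Negative squared Euclidean distance, used as a logit."""
    return -sum((x - y) ** 2 for x, y in zip(a, b))


def classify(queries, protos, scale=1.0, feature_dims=1):
    """Return (logits, predicted class) per query, applying the
    scale / feature_dims normalisation discussed in the thread."""
    results = []
    for q in queries:
        logits = [scale * neg_sq_dist(q, p) / feature_dims for p in protos]
        pred = max(range(len(protos)), key=logits.__getitem__)
        results.append((logits, pred))
    return results


num_way, num_support, num_query = 3, 2, 2
# Toy 2-D embeddings: class c sits near (c, c).
support = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1],   # shot 0 for classes 0, 1, 2
           [0.2, 0.0], [1.2, 1.0], [2.2, 2.0]]   # shot 1 for classes 0, 1, 2
queries = [[0.1, 0.0], [0.9, 1.0], [2.1, 1.9],   # query 0 for classes 0, 1, 2
           [0.0, 0.2], [1.1, 0.9], [1.9, 2.0]]   # query 1 for classes 0, 1, 2
# Same ordering as arange(num_way).repeat(num_query)
labels = [c for _ in range(num_query) for c in range(num_way)]

protos = prototypes(support, num_way, num_support)
preds = [pred for _, pred in classify(queries, protos)]
acc = sum(p == y for p, y in zip(preds, labels)) / len(labels)
print(preds, acc)
```

If the data loader emitted examples class-major instead (all shots of class 0 first), the same `reshape` would silently average across different classes, so the sampler ordering and the reshape have to agree.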
Related Issues (20)
- ModuleNotFoundError: No module named 'protonets.utils' HOT 3
- why the training speed of an epoch is accelerating? HOT 3
- Train accuracy is lower than val accuracy HOT 1
- Thank you for your code, but I cannot download the dataset because I cannot run the .sh script. Can you tell me the address of the datasets? HOT 1
- Some questions
- updated implementation HOT 1
- How to train a new dataset? HOT 5
- What is the accuracy of this repo on the Mini-ImageNet dataset? HOT 3
- Research project about PN
- Error while training
- Installation of setup under windows HOT 4
- Question for Loss
- Alternative implementation
- About calculation of loss function HOT 1
- dataset questions HOT 1
- Rotate four angles. Are these photos in four categories HOT 1
- Zero-shot learning meta-data
- An error in paper
- accuracy when training with 1 query set
- What if it's an event