Comments (2)
Hi @Peter-SungwooCho, my guess is that the model may not reach the expected accuracy at inference because different hyperparameters (such as `hidden_dim`) were used during training and when loading the saved model. This is only a guess, but it would help if you could share the code you use to load the saved model (keeping in mind that the problem only occurs when loading the saved model and there is no overfitting).
from pytorch-cifar.
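If the hyperparameters differ between training and loading, the parameter names or shapes stored in the checkpoint will usually not match those of the freshly built model. A minimal sketch for checking this, assuming both state dicts have first been reduced to `{name: shape}` mappings. The helper `compare_state_dict_shapes` and the toy shapes below are hypothetical illustrations, not part of pytorch-cifar:

```python
def compare_state_dict_shapes(expected, loaded):
    """Compare two {parameter_name: shape} mappings (hypothetical helper).

    Returns (missing, unexpected, mismatched): names present only in
    `expected`, names present only in `loaded`, and shared names whose
    shapes differ. Each list is sorted alphabetically.
    """
    missing = sorted(set(expected) - set(loaded))
    unexpected = sorted(set(loaded) - set(expected))
    mismatched = sorted(
        name
        for name in set(expected) & set(loaded)
        if tuple(expected[name]) != tuple(loaded[name])
    )
    return missing, unexpected, mismatched


if __name__ == "__main__":
    # Toy shapes (hypothetical): a model built with hidden_dim=128
    # versus a checkpoint trained with hidden_dim=256.
    model_shapes = {"fc1.weight": (128, 512), "fc1.bias": (128,), "fc2.weight": (10, 128)}
    ckpt_shapes = {"fc1.weight": (256, 512), "fc1.bias": (256,), "fc2.weight": (10, 256)}
    print(compare_state_dict_shapes(model_shapes, ckpt_shapes))
```

With PyTorch, the inputs could be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}` and `{k: tuple(v.shape) for k, v in checkpoint['net'].items()}`. Three empty lists suggest the architecture matches; anything else points at a configuration or key-naming difference.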
Hello, I'm facing the same problem. This is my code:
```python
import torch
import torchvision
import torchvision.transforms as transforms

from models import ResNet18      # pytorch-cifar's models package
from utils import progress_bar   # pytorch-cifar's utils.py

device = torch.device("cuda")
full_model = ResNet18().to(device)
# print(full_model)

checkpoint = torch.load("./checkpoint/ckpt.pth")
print(checkpoint)
# strict=False silently ignores missing or unexpected keys
full_model.load_state_dict(checkpoint['net'], strict=False)
full_model.to(device)
full_model.eval()

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=transform_test, download=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

correct = 0
total = 0
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = full_model(inputs)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        progress_bar(batch_idx, len(test_loader), 'Acc: %.3f%% (%d/%d)'
                     % (100. * correct / total, correct, total))
```
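One cause worth ruling out (an assumption, not a confirmed diagnosis): if `ckpt.pth` was saved from a model wrapped in `torch.nn.DataParallel`, which pytorch-cifar's `main.py` appears to do when training on CUDA, every key in `checkpoint['net']` starts with `module.`. A plain `ResNet18()` expects keys without that prefix, so with `strict=False` those keys are skipped silently and the network keeps its random initialization, which would give roughly chance-level accuracy. A minimal sketch of stripping the prefix; `strip_prefix` is a hypothetical helper and the example keys are made up:

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with `prefix` removed from each key that starts with it.

    Hypothetical helper, not part of pytorch-cifar. Keys without the prefix
    are kept unchanged, and only one leading occurrence is removed per key,
    so it is also safe on checkpoints saved without DataParallel.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


if __name__ == "__main__":
    # Made-up keys and placeholder values standing in for tensors.
    saved = OrderedDict([("module.conv1.weight", "w"), ("module.linear.bias", "b")])
    print(list(strip_prefix(saved)))
```

With PyTorch, the cleaned dict could then be loaded with the default `strict=True`, e.g. `full_model.load_state_dict(strip_prefix(checkpoint['net']))`, so that any remaining mismatch raises an error instead of being ignored. If `strict=False` is kept, the value returned by `load_state_dict` lists `missing_keys` and `unexpected_keys`, which shows whether anything was actually loaded.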
Related Issues
- A bug in dla.py
- pytorch-cifar reproduction issue
- MobileNetV2 training does not converge
- The test set is being used as validation set
- Question about the accuracy
- pre-train weights
- request: epochs numbers to converge in readme.md
- ResNet18 performs much better than expected!
- how to train my own dataset and classes?
- [Question] efficientnet performance
- Error when loading the weights on CPU but trained on GPU
- Errors when testing on CPU
- checkpoint
- ValueError: not enough values to unpack (expected 2, got 0)
- load state_dict
- Overfitting on ResNet18
- A problem in ShuffleNet
- How long will it take you to train the cifar10 model?
- Overfitting