
spandan-madan / pytorch_fine_tuning_tutorial


A short tutorial on performing fine tuning or transfer learning in PyTorch.

Python 100.00%
deep-learning image-classification pytorch-tutorials tutorial

pytorch_fine_tuning_tutorial's Introduction



PyTorch Tutorial for Fine Tuning/Transfer Learning a ResNet for Image Classification

If you want to do image classification by fine tuning a pretrained model, this tutorial will help you out. It shows how to perform fine tuning or transfer learning in PyTorch with your own data. It is based on a number of official PyTorch tutorials and examples. I felt that this was not entirely trivial to do in PyTorch, so I decided to release the code I originally wrote for my research as a tutorial.

I highly encourage you to run this on a new dataset (read main_fine_tuning.py to see the format in which to store your data). If you want a sample dataset to start with, you can download a simple 2-class dataset from here: https://download.pytorch.org/tutorial/hymenoptera_data.zip

All Torch- and PyTorch-specific details are explained in the file main_fine_tuning.py.

Hope this tutorial helps you out! :)

Credits - This tutorial is built mainly on top of 2 PyTorch tutorials: http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html and https://github.com/pytorch/examples/tree/master/imagenet.

pytorch_fine_tuning_tutorial's People

Contributors

spandan-madan


pytorch_fine_tuning_tutorial's Issues

AttributeError: 'ResNet' object has no attribute 'save_state_dict'

torch.Size([3, 3, 224, 224])
trying epoch loss
val Loss: 0.0533 Acc: 0.7712
Training complete in 96m 49s
Best val Acc: 0.790850
returning and looping back
Traceback (most recent call last):
  File "main_fine_tuning.py", line 267, in <module>
    model_ft.save_state_dict('fine_tuned_best_model.pt')
  File "/scratch/sjn-p3/anaconda/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 518, in __getattr__
    type(self).__name__, name))
AttributeError: 'ResNet' object has no attribute 'save_state_dict'
[jalal@goku Pytorch_fine_tuning_Tutorial]$ 

>>> import torch
>>> torch.__version__
'0.4.1.post2'
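The traceback points at a method that does not exist: nn.Module provides state_dict(), not save_state_dict(). A minimal sketch of the usual save/restore pattern follows, with a small nn.Linear standing in for the fine-tuned ResNet (an assumption made only to keep the example self-contained):

```python
import torch
import torch.nn as nn

# Stand-in for the fine-tuned network (assumption); in the issue this is the
# ResNet returned by train_model and stored in model_ft.
model_ft = nn.Linear(4, 2)

# nn.Module has no save_state_dict(); take the weights with state_dict()
# and write them to disk with torch.save.
torch.save(model_ft.state_dict(), "fine_tuned_best_model.pt")

# To restore: rebuild the same architecture, then load the saved weights into it.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load("fine_tuned_best_model.pt"))
```

Saving the state_dict rather than the whole module keeps the file independent of the exact class definitions pickled at save time.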

error with GPU

Hello,

It is a great post. I used my custom 2-class dataset and it worked well with resnet50 on the CPU. However, with 1 or 2 GPUs, I got the following error. Please advise.

Setting:
Ubuntu 14.04, python 3.5.4
PyTorch 0.2 via conda (same error with 0.1.12_2 via pip)
CUDA 8.0.61
NVIDIA GeForce 1080Ti x2


(15 , 2 ,.,.) =
-1.5081 -1.4036 -1.4036 ... -1.6476 -1.6999 -1.6476
-1.6127 -1.5081 -1.5604 ... -1.5604 -1.5081 -1.6476
-1.6127 -1.6476 -1.6476 ... -1.5779 -1.5430 -1.5081
... ⋱ ...
-1.6476 -1.6302 -1.7173 ... -0.9156 -1.1421 -1.2641
-1.6650 -1.6127 -1.6476 ... -0.5495 -0.9853 -1.2467
-1.4907 -1.6127 -1.4733 ... -1.6127 -1.2293 -1.5430
[torch.FloatTensor of size 16x3x224x224]

1
0
1
1
1
1
1
1
0
1
1
0
0
0
1
0
[torch.LongTensor of size 16]

Traceback (most recent call last):
  File "main.py", line 260, in <module>
    num_epochs=100)
  File "main.py", line 177, in train_model
    outputs = model(inputs)
  File "/home/owner/anaconda3/envs/pytorch_ssd/lib/python3.5/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/owner/anaconda3/envs/pytorch_ssd/lib/python3.5/site-packages/torchvision/models/resnet.py", line 139, in forward
    x = self.conv1(x)
  File "/home/owner/anaconda3/envs/pytorch_ssd/lib/python3.5/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/owner/anaconda3/envs/pytorch_ssd/lib/python3.5/site-packages/torch/nn/modules/conv.py", line 237, in forward
    self.padding, self.dilation, self.groups)
  File "/home/owner/anaconda3/envs/pytorch_ssd/lib/python3.5/site-packages/torch/nn/functional.py", line 40, in conv2d
    return f(input, weight, bias)
TypeError: argument 0 is not a Variable
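In PyTorch 0.2 the layers only accepted autograd Variables, so this TypeError most likely means a plain tensor (for example, one moved with .cuda() but never wrapped in Variable) reached conv1. Since PyTorch 0.4, tensors and Variables are merged, and the idiomatic way to run on the GPU or the CPU is to move both the model and every batch to one torch.device. A sketch under that assumption (PyTorch 0.4 or later), with a single Conv2d standing in for the ResNet-50:

```python
import torch
import torch.nn as nn

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in model (assumption); the issue uses a torchvision ResNet-50.
model = nn.Conv2d(3, 8, kernel_size=3, padding=1).to(device)

# A fake batch shaped like the one printed above: 16 RGB images of 224x224
# and 16 binary labels.
inputs = torch.randn(16, 3, 224, 224)
labels = torch.randint(0, 2, (16,))

# Move the batch to the model's device before the forward pass; no Variable needed.
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
```

For two GPUs, the model can additionally be wrapped in nn.DataParallel before calling .to(device); the batch handling stays the same.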

ImportError: dlopen: cannot load any more object with static TLS

When I run the 'python main_fine_tuning.py' command on an Ubuntu server, I get the following error.
Can you suggest a solution?

Traceback (most recent call last):
  File "main_fine_tuning.py", line 11, in <module>
    import matplotlib.pyplot as plt
  File "/home/hao/anaconda2/lib/python2.7/site-packages/matplotlib/pyplot.py", line 116, in <module>
    _backend_mod, new_figure_manager, draw_if_interactive, _show = pylab_setup()
  File "/home/hao/anaconda2/lib/python2.7/site-packages/matplotlib/backends/__init__.py", line 60, in pylab_setup
    [backend_name], 0)
  File "/home/hao/anaconda2/lib/python2.7/site-packages/matplotlib/backends/backend_qt5agg.py", line 16, in <module>
    from .backend_qt5 import (
  File "/home/hao/anaconda2/lib/python2.7/site-packages/matplotlib/backends/backend_qt5.py", line 18, in <module>
    import matplotlib.backends.qt_editor.figureoptions as figureoptions
  File "/home/hao/anaconda2/lib/python2.7/site-packages/matplotlib/backends/qt_editor/figureoptions.py", line 20, in <module>
    import matplotlib.backends.qt_editor.formlayout as formlayout
  File "/home/hao/anaconda2/lib/python2.7/site-packages/matplotlib/backends/qt_editor/formlayout.py", line 56, in <module>
    from matplotlib.backends.qt_compat import QtGui, QtWidgets, QtCore
  File "/home/hao/anaconda2/lib/python2.7/site-packages/matplotlib/backends/qt_compat.py", line 137, in <module>
    from PyQt5 import QtCore, QtGui, QtWidgets
ImportError: dlopen: cannot load any more object with static TLS
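The traceback shows pyplot trying to load the Qt5 backend (backend_qt5agg, then PyQt5), and that import is what fails. On a server without a display, a common workaround is to select the non-interactive Agg backend before pyplot is imported, so PyQt5 is never loaded. This is a hedged suggestion rather than a confirmed fix for this report; setting the environment variable MPLBACKEND=Agg has the same effect without editing the script. The plotted values and the file name plot.png are illustrative only:

```python
import matplotlib

# Select the non-interactive Agg backend *before* importing pyplot, so matplotlib
# does not try to load the Qt5 backend (and PyQt5) on a headless server.
matplotlib.use("Agg")

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])
# Agg has no window to show; write the figure to a file instead of calling plt.show().
fig.savefig("plot.png")
plt.close(fig)
```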

cannot enter the eval stage

Hello, I am using the program to fine tune ResNet-50. I only changed "Resnet 18" in the original program to "Resnet 50". The problem is that after the first training stage, the program stops on its own without printing any error message. Can anyone help? Thanks a lot!

ImportError: DLL load failed: The paging file is too small for this operation to complete.

After running the main_fine_tuning.py file, I got this traceback:

Epoch 0/99
LR is set to 0.001
Traceback (most recent call last):
  File "<string>", line 1, in <module>
Traceback (most recent call last):
  File "main_fine_tuning.py", line 265, in <module>
  File "C:\Users\dk12a7\Anaconda3\lib\multiprocessing\spawn.py", line 105, in spawn_main
    num_epochs=100)
  File "main_fine_tuning.py", line 162, in train_model
    for data in dset_loaders[phase]:
  File "C:\Users\dk12a7\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 501, in __iter__
    return _DataLoaderIter(self)
  File "C:\Users\dk12a7\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 289, in __init__
    w.start()
  File "C:\Users\dk12a7\Anaconda3\lib\multiprocessing\process.py", line 112, in start
    self._popen = self._Popen(self)
  File "C:\Users\dk12a7\Anaconda3\lib\multiprocessing\context.py", line 223, in _Popen
    exitcode = _main(fd)
  File "C:\Users\dk12a7\Anaconda3\lib\multiprocessing\spawn.py", line 114, in _main
    prepare(preparation_data)
  File "C:\Users\dk12a7\Anaconda3\lib\multiprocessing\spawn.py", line 225, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "C:\Users\dk12a7\Anaconda3\lib\multiprocessing\spawn.py", line 277, in _fixup_main_from_path
    run_name="__mp_main__")
  File "C:\Users\dk12a7\Anaconda3\lib\runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "C:\Users\dk12a7\Anaconda3\lib\runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "C:\Users\dk12a7\Anaconda3\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "C:\Users\dk12a7\Desktop\code classification\Pytorch_fine_tuning_Tutorial\main_fine_tuning.py", line 4, in <module>
    import torch
  File "C:\Users\dk12a7\Anaconda3\lib\site-packages\torch\__init__.py", line 80, in <module>
    from torch._C import *
ImportError: DLL load failed: The paging file is too small for this operation to complete.
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\Users\dk12a7\Anaconda3\lib\multiprocessing\context.py", line 322, in _Popen
    return Popen(process_obj)
  File "C:\Users\dk12a7\Anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 65, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\Users\dk12a7\Anaconda3\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
BrokenPipeError: [Errno 32] Broken pipe

I tried setting BATCH_SIZE = 1, but the problem still occurs. Do you have a solution for this?
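The traceback shows the DataLoader starting worker processes (w.start()) and each spawned child failing while it re-imports main_fine_tuning.py and torch, so the batch size is unlikely to matter. On Windows, multiprocessing uses the spawn start method, which re-executes the main script in every worker. The usual remedies are to put the training code behind an if __name__ == "__main__": guard and, if the workers still cannot load torch's DLLs, to lower num_workers (0 loads data in the main process); enlarging the Windows paging file is another commonly suggested option. A sketch with a toy TensorDataset standing in for the tutorial's image folders (an assumption):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def main():
    # Toy data (assumption): 8 random 3x32x32 "images" with binary labels,
    # standing in for the datasets built in main_fine_tuning.py.
    dataset = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 2, (8,)))
    # num_workers=0 loads batches in the main process, so no worker processes
    # are spawned and torch is not imported again in each of them.
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    n_batches = 0
    for inputs, labels in loader:
        n_batches += 1  # the training step would go here
    return n_batches


# With spawn, every worker re-imports this file; the guard keeps the workers
# from re-running the training code at import time.
if __name__ == "__main__":
    main()
```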

vgg16

Hello, this is a nice tutorial. I would like to use VGG16, but it does not work on this data when I apply it.

Upgrading to PyTorch 0.4

Can someone make a fork, update and raise a pull request? I'm a little busy these days and won't be able to. Thanks a lot in advance :)
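For anyone attempting that port: the 0.4 release merged Variable into Tensor and replaced a few idioms that 0.3-era training loops commonly relied on (volatile=True Variables for evaluation, loss.data[0] for reading a scalar). A sketch of the 0.4-style equivalents on a toy nn.Linear model, which is a stand-in and not the tutorial's code:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

# Toy model and batch (assumption), standing in for the ResNet and image batches.
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))

# Since 0.4, Variable(t) simply returns a Tensor, so wrapping can be dropped.
wrapped = Variable(inputs)

# torch.no_grad() replaces volatile=True Variables during evaluation.
with torch.no_grad():
    outputs = model(inputs)
    loss = criterion(outputs, labels)

# loss.item() replaces loss.data[0] for reading a Python number.
running_loss = loss.item() * inputs.size(0)
_, preds = torch.max(outputs, 1)
running_corrects = torch.sum(preds == labels).item()
```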

TypeError: argument 0 is not a Variable

Got the following error.

Traceback (most recent call last):
  File "main_fine_tuning.py", line 260, in <module>
    num_epochs=100)
  File "main_fine_tuning.py", line 177, in train_model
    outputs = model(inputs)
  File "/home/user/pytorch_python2/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/user/pytorch_python2/local/lib/python2.7/site-packages/torchvision/models/resnet.py", line 139, in forward
    x = self.conv1(x)
  File "/home/user/pytorch_python2/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/userj/pytorch_python2/local/lib/python2.7/site-packages/torch/nn/modules/conv.py", line 254, in forward
    self.padding, self.dilation, self.groups)
  File "/home/user/pytorch_python2/local/lib/python2.7/site-packages/torch/nn/functional.py", line 52, in conv2d
    return f(input, weight, bias)
TypeError: argument 0 is not a Variable
