oke-aditya / pytorch_cnn_trainer
A Simple but Powerful CNN Trainer For PyTorch
Home Page: https://oke-aditya.github.io/pytorch_cnn_trainer/
License: GNU General Public License v3.0
Currently, we return metrics in train_step and val_step.
We could use them to build a nested dictionary that holds a list of values for each metric:
history = {"train": {"top1_acc": [], "top5_acc": [], "loss": []}, "val": {"top1_acc": [], "top5_acc": [], "loss": []}}
We could also provide a plotting function for this history using matplotlib.
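Here is a minimal sketch of that history. It assumes that train_step and val_step each return a dict keyed by top1_acc, top5_acc and loss. The helper names make_history, update_history and plot_history are hypothetical and are not part of the library:

```python
def make_history(metrics=("top1_acc", "top5_acc", "loss")):
    # One list per metric for each phase, matching the proposed layout.
    return {phase: {name: [] for name in metrics} for phase in ("train", "val")}


def update_history(history, phase, metrics):
    # Append this epoch's values, e.g. the dict returned by train_step / val_step.
    for name, value in metrics.items():
        history[phase][name].append(float(value))
    return history


def plot_history(history, metric="loss"):
    import matplotlib.pyplot as plt  # imported lazily: only needed when plotting

    fig, ax = plt.subplots()
    for phase, values in history.items():
        ax.plot(range(1, len(values[metric]) + 1), values[metric], label=phase)
    ax.set_xlabel("epoch")
    ax.set_ylabel(metric)
    ax.legend()
    return fig
```

In fit(), you would call update_history once per epoch for each phase and return the history alongside the model.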
Describe the bug
I'm training a model with 2 classes, and an error appears at line 61 in pytorch_cnn_trainer/utils.py. By default, maxk is always 5 because topk is fixed to (1, 5) in the train_step and val_step functions. The output variable has shape 32x4, so this results in the error RuntimeError: invalid argument 5: k not in range for dimension
59 maxk = max(topk)
60 batch_size = target.size(0)
61 _, pred = output.topk(maxk, 1, True, True)
62 pred = pred.t()
63 correct = pred.eq(target.view(1, -1).expand_as(pred))
To Reproduce
Steps to reproduce the behavior:
Train a model using engine.fit with 2 classes.
Expected behavior
Training should work with an arbitrary number of output classes.
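One possible fix is to clamp every requested k to the number of classes before calling topk(). The sketch below assumes that the helper in utils.py is an accuracy(output, target, topk) function built around the lines quoted above. Only lines 59 to 63 are shown, so the name and the rest of the body are assumptions. A different fix would be for train_step and val_step to pass topk=(1, min(5, num_classes)).

```python
import torch


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in %), with each k clamped to the number of classes."""
    num_classes = output.size(1)
    # Asking topk() for more classes than exist raises the RuntimeError above,
    # so cap every k at num_classes (top-5 with 2 classes becomes top-2).
    ks = [min(k, num_classes) for k in topk]
    maxk = max(ks)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in ks:
        # reshape (not view) because the transposed slice is non-contiguous
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
```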
PyTorch 1.6 is just around the corner.
It makes adding mixed precision training very simple, and the code can be tested on Kaggle.
Support this feature with an fp16: bool = False parameter in the engine.
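Here is a minimal sketch of an fp16 flag that uses native AMP (torch.cuda.amp.autocast and torch.cuda.amp.GradScaler, both available since PyTorch 1.6). The train_step signature shown here is an assumption and does not copy the engine's actual one. When fp16=False, the disabled autocast and scaler pass everything straight through, so the same loop serves both modes.

```python
import torch


def train_step(model, loader, criterion, optimizer, device="cpu", scaler=None, fp16=False):
    """One training epoch; mixed precision is used only when fp16=True."""
    model.train()
    if scaler is None:
        # Convenience default; fit() should create one scaler and pass it in
        # so the loss scale persists across epochs.
        scaler = torch.cuda.amp.GradScaler(enabled=fp16)
    running_loss = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        # While enabled, autocast runs eligible CUDA ops in float16.
        with torch.cuda.amp.autocast(enabled=fp16):
            loss = criterion(model(inputs), targets)
        scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)  # unscales grads; skips the step if they contain inf/NaN
        scaler.update()
        running_loss += loss.item()
    return running_loss / len(loader)
```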
For better documentation, testing and packaging, we need to shift to the Python template created in this repo.
With #1, we can simply run tests on these sanity checkers instead of full training runs!
Write a test for it and add it to the tests folder. An almost complete sample is in the examples file test_torchvision.py.
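A cheap sanity check that could replace a full training run in the tests folder verifies that one optimization step changes every trainable parameter. This is a pytest-style sketch. The names check_one_step_updates and test_one_step_updates are hypothetical, and the tiny model is a stand-in for a real one.

```python
import torch
import torch.nn as nn


def check_one_step_updates(model, criterion, optimizer, inputs, targets):
    """Return True if a single optimizer step changes every trainable parameter."""
    before = [p.detach().clone() for p in model.parameters() if p.requires_grad]
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    after = [p.detach() for p in model.parameters() if p.requires_grad]
    return all(not torch.equal(b, a) for b, a in zip(before, after))


def test_one_step_updates():
    torch.manual_seed(0)
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    inputs = torch.randn(4, 3, 8, 8)
    targets = torch.tensor([0, 1, 0, 1])
    assert check_one_step_updates(model, nn.CrossEntropyLoss(), optimizer, inputs, targets)
```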
I guess this feature should be bigger.
It would be really nice to see how the CNN learns every epoch. We could show some of its mistakes as well as some correct classifications.
This could give a step-by-step understanding of how the CNN learns.
Some ideas that could be used.
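As a starting point, the following helper collects a few correctly and incorrectly classified samples after each epoch so they can be plotted later. It is a hypothetical helper (collect_predictions and max_per_bucket are made-up names), and it assumes the loader yields (inputs, targets) batches.

```python
import torch


@torch.no_grad()
def collect_predictions(model, loader, device="cpu", max_per_bucket=8):
    """Gather a few correctly and incorrectly classified samples for inspection."""
    model.eval()
    correct, mistakes = [], []
    for inputs, targets in loader:
        preds = model(inputs.to(device)).argmax(dim=1).cpu()
        for x, y, p in zip(inputs, targets, preds):
            bucket = correct if p.item() == y.item() else mistakes
            if len(bucket) < max_per_bucket:
                # Keep (input, true label, predicted label) for plotting later.
                bucket.append((x, y.item(), p.item()))
    return correct, mistakes
```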
As per NVIDIA's recommendations, we should turn off the profilers.
Linked to #20
Describe the solution you'd like
A simple utility function to turn them off.
Describe alternatives you've considered
Completely optional step
Additional context
Should help in faster training.
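The sketch below assumes that "profilers" refers to PyTorch's autograd debugging APIs. The PyTorch performance tuning guide recommends disabling these during regular training. Only anomaly detection has a global switch. The autograd profilers record only inside a `with` block, so the practical rule for them is to not wrap training in them. disable_debug_apis is a hypothetical name.

```python
import torch


def disable_debug_apis():
    """Hypothetical helper: switch off autograd debugging before training."""
    # set_detect_anomaly applies its mode when constructed (it can also be used
    # as a context manager), so this call turns anomaly detection off globally.
    torch.autograd.set_detect_anomaly(False)
    # torch.autograd.profiler.profile and emit_nvtx only record inside a `with`
    # block, so there is no global switch for them: just avoid wrapping the
    # training loop in them outside of profiling runs.
    return torch.is_anomaly_enabled()
```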
Should take some time.
We can add an argument when creating models, num_labels: int = 1. If the user specifies more than one label, we can simply add an extra dense layer on top and return its output from the model as well.
We also need to slightly alter the engine to support this. We will get two outputs, o1 and o2, so we would need to compute metrics for both and provide the other functions for both too.
We could either reuse the same train_step and val_step or write separate ones.
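Here is a minimal sketch of the num_labels idea: a shared backbone with one dense head per label, plus a summed loss so that one backward() trains every head. MultiHeadCNN and multi_head_loss are hypothetical. The tiny backbone stands in for the library's real model factory, and every head is assumed to have the same number of classes.

```python
import torch
import torch.nn as nn


class MultiHeadCNN(nn.Module):
    """Toy CNN with one linear head per label; num_labels=1 behaves like a normal classifier."""

    def __init__(self, num_classes, num_labels=1):
        super().__init__()
        # Shared feature extractor (a stand-in for a torchvision backbone).
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # One dense layer per label on top of the shared features.
        self.heads = nn.ModuleList([nn.Linear(8, num_classes) for _ in range(num_labels)])

    def forward(self, x):
        features = self.backbone(x)
        outputs = tuple(head(features) for head in self.heads)
        # Keep the single-output API unchanged when num_labels == 1.
        return outputs[0] if len(outputs) == 1 else outputs


def multi_head_loss(outputs, targets, criterion):
    """Sum the per-head losses so a single backward() trains all heads."""
    return sum(criterion(out, tgt) for out, tgt in zip(outputs, targets))
```

Metrics such as the accuracy helper would then be computed once per element of the returned tuple.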
Try to integrate this LR finder repo. It has just one file: https://github.com/davidtvs/pytorch-lr-finder/blob/master/torch_lr_finder/lr_finder.py
We simply need to edit this file in a few places to support native AMP with torch 1.6.
We also need to edit the places that handle torch < 1.1, since the PyTorch CNN Trainer repo supports torch 1.6 and above.
https://github.com/davidtvs/pytorch-lr-finder
Looks simple, can be done and tested too.
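The core idea of an LR range test is to train for one mini-batch at each learning rate and increase the learning rate exponentially from start_lr to end_lr, recording the loss at each step. The sketch below is a simplified, stand-alone illustration of that schedule and of one common heuristic (pick the LR where the loss drops fastest). It is not the code in lr_finder.py, and the names exponential_lrs and suggest_lr are hypothetical.

```python
import math


def exponential_lrs(start_lr, end_lr, num_iter):
    """Learning rates spaced geometrically from start_lr to end_lr (both included)."""
    if num_iter < 2:
        raise ValueError("num_iter must be at least 2")
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_iter - 1)) for i in range(num_iter)]


def suggest_lr(lrs, losses):
    """Return the LR at which the loss falls fastest per unit of log(lr)."""
    best_lr, best_slope = None, math.inf
    for i in range(1, len(lrs)):
        slope = (losses[i] - losses[i - 1]) / (math.log(lrs[i]) - math.log(lrs[i - 1]))
        if slope < best_slope:
            best_lr, best_slope = lrs[i], slope
    return best_lr
```

In a real run, each learning rate would be written into optimizer.param_groups before a single training step, and the resulting loss appended to losses.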
The target attribute has been defined but never used.
pytorch_cnn_trainer/pytorch_cnn_trainer/dataset.py
Lines 119 to 124 in ec60c87
df will not always have an image_id column. The column name needs to be specified by the user, just like target.
We need to add an extension such as .png or .jpg to img_path, since img_name only contains the ID of the image.
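Here is a sketch of a CSVDataset that addresses these three points. It actually uses target, it takes the image column name from the user, and it appends a file extension to the ID. The parameter names image_col, extension and loader are assumptions and are not the library's current signature. The loader argument exists only so that paths can be checked without real image files.

```python
import os

from torch.utils.data import Dataset


def pil_loader(path):
    from PIL import Image  # lazy import: only needed when actually reading images

    with Image.open(path) as img:
        return img.convert("RGB")


class CSVDataset(Dataset):
    def __init__(self, df, data_dir, target, image_col="image_id", extension=".jpg",
                 transform=None, loader=pil_loader):
        self.df = df.reset_index(drop=True)
        self.data_dir = data_dir
        self.target = target  # label column, now actually used in __getitem__
        self.image_col = image_col  # user-specified instead of a hard-coded "image_id"
        self.extension = extension  # e.g. ".png" or ".jpg"
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_name = f"{row[self.image_col]}{self.extension}"
        image = self.loader(os.path.join(self.data_dir, img_name))
        if self.transform is not None:
            image = self.transform(image)
        return image, row[self.target]
```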
Sometimes we only get train.csv and test.csv, so it is best for CSVDataset to split train.csv into train_set and valid_set by default, as create_folder_dataset does.
The user can split it into train_set and valid_set after creating CSVDataset, as below, but it would be good to have this out of the box.
import torch
split = 0.8
complete_dataset = CSVDataset(df, data_dir, target, transform)
# random_split needs integer lengths that add up to the dataset size
train_len = int(len(complete_dataset) * split)
valid_len = len(complete_dataset) - train_len
train_set, valid_set = torch.utils.data.random_split(
    complete_dataset, [train_len, valid_len]
)
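To get the split out of the box, the same logic can be wrapped in a small helper that CSVDataset (or a create_csv_dataset factory) calls. split_dataset is a hypothetical name. The seeded generator uses the generator argument of random_split, which is available in PyTorch 1.6, and makes the split reproducible.

```python
import torch
from torch.utils.data import random_split


def split_dataset(dataset, split=0.8, seed=42):
    """Hypothetical helper: split a dataset into train/valid subsets."""
    train_len = int(len(dataset) * split)
    valid_len = len(dataset) - train_len  # the remainder, so the lengths sum to len(dataset)
    # A seeded generator makes the split reproducible across runs.
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_len, valid_len], generator=generator)
```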
A lot of new features were added, and we need docs and examples for them:
train_step and fit()
train_step() and fit()
We can move them into the docs with some extra explanation as well!
Simple: recreate this API for PyTorch Lightning.
It will solve the distributed training issues as well.
We need to remove amp, as it would be handled by PyTorch Lightning.
We can make use of its metrics package as well.
I would rather like to finish this myself. Do let me know if anyone else is interested.
Add a gradient penalty feature for mixed precision training.
This is a simple one; we already have it for training without AMP.
Code for this can be found here.
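Here is a sketch of one step with a gradient-norm penalty under AMP, modeled on the gradient-penalty recipe in PyTorch's AMP examples. The key points are to differentiate the scaled loss with create_graph=True, unscale those gradients manually before building the penalty, and then backpropagate the combined loss through the scaler. train_step_gp is a hypothetical name, and the penalty weight is fixed at 1 for brevity.

```python
import torch


def train_step_gp(model, inputs, targets, criterion, optimizer, scaler, fp16=False):
    """One optimization step with a gradient-norm penalty, AMP-aware."""
    optimizer.zero_grad()
    params = [p for p in model.parameters() if p.requires_grad]
    with torch.cuda.amp.autocast(enabled=fp16):
        loss = criterion(model(inputs), targets)
    # Differentiate the scaled loss, keeping the graph so the penalty
    # itself can be backpropagated.
    scaled_grads = torch.autograd.grad(scaler.scale(loss), params, create_graph=True)
    # Unscale manually; get_scale() returns 1.0 when the scaler is disabled.
    inv_scale = 1.0 / scaler.get_scale()
    grads = [g * inv_scale for g in scaled_grads]
    with torch.cuda.amp.autocast(enabled=fp16):
        grad_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
        loss = loss + grad_norm
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```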
We can simply accept **kwargs as an argument and reduce how many arguments have to be passed around.
This is a simple issue which I will leave for open-source contributors.
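Here is a sketch of the **kwargs idea. fit() keeps only the arguments it needs and forwards everything else to train_step, so adding a new option to train_step does not require changing fit(). Both functions are simplified stand-ins, and their signatures are assumptions.

```python
def train_step(model, loader, criterion, optimizer, device="cpu", fp16=False, grad_penalty=False):
    # Stand-in for the real train_step: report which options reached it.
    return {"device": device, "fp16": fp16, "grad_penalty": grad_penalty}


def fit(epochs, model, train_loader, criterion, optimizer, **kwargs):
    # Options fit() does not use itself (device, fp16, ...) are forwarded as-is.
    return [train_step(model, train_loader, criterion, optimizer, **kwargs) for _ in range(epochs)]
```

One trade-off is that fit's signature no longer documents the available options, and a misspelled keyword only surfaces as a TypeError raised from train_step.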
This is natively supported by PyTorch in torch.optim.swa_utils (here).
They have also given a small example:
>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
...                                                        T_max=300)
>>> swa_start = 160
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
...     anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
>>> for i in range(300):
...     for input, target in loader:
...         optimizer.zero_grad()
...         loss_fn(model(input), target).backward()
...         optimizer.step()
...     if i > swa_start:
...         swa_model.update_parameters(model)
...         swa_scheduler.step()
...     else:
...         scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
We need to support this together with the AMP feature as well.
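Combining the PyTorch SWA example with the fp16 flag mostly means routing the optimizer step through a GradScaler. The SWA bookkeeping (AveragedModel, SWALR, update_bn) stays the same. This is a sketch: fit_swa and its parameters are hypothetical, and it assumes the loader yields (inputs, targets) batches.

```python
import torch
from torch.optim.swa_utils import SWALR, AveragedModel, update_bn


def fit_swa(model, loader, criterion, optimizer, epochs, swa_start,
            swa_lr=0.05, fp16=False, device="cpu"):
    swa_model = AveragedModel(model)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    swa_scheduler = SWALR(optimizer, swa_lr=swa_lr)
    scaler = torch.cuda.amp.GradScaler(enabled=fp16)  # pass-through when fp16=False
    for epoch in range(epochs):
        model.train()
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=fp16):
                loss = criterion(model(inputs), targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        if epoch > swa_start:
            swa_model.update_parameters(model)  # running average of the weights
            swa_scheduler.step()
        else:
            scheduler.step()
    # Recompute BatchNorm statistics for the averaged weights.
    update_bn(loader, swa_model, device=device)
    return swa_model
```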