
oke-aditya / pytorch_cnn_trainer

26 stars · 2 watchers · 5 forks · 574 KB

A Simple but Powerful CNN Trainer For PyTorch

Home Page: https://oke-aditya.github.io/pytorch_cnn_trainer/

License: GNU General Public License v3.0

Languages: Python 100.00%
Topics: hacktoberfest, transfer-learning, torchvision


pytorch_cnn_trainer's People

Contributors

hassiahk, oke-aditya, premalrupnur

Stargazers

26 stargazers

Watchers

2 watchers

pytorch_cnn_trainer's Issues

Add history-like object as in Keras for plotting

🚀 Feature

Currently, we return metrics from train_step and val_step.
Maybe we can use them to build a dictionary of dictionaries, each holding a list:

history = {"train": {"top1_acc": [], "top5_acc": [], "loss": []}, "val": {"top1_acc": [], "top5_acc": [], "loss": []}}

We can also provide a plotting function for this history using matplotlib, as sketched below.
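A minimal sketch of how this could look; update_history and plot_history are hypothetical helper names, and the metric keys are assumed to match what train_step / val_step return:

import matplotlib.pyplot as plt

history = {"train": {"top1_acc": [], "top5_acc": [], "loss": []},
           "val": {"top1_acc": [], "top5_acc": [], "loss": []}}

def update_history(history, phase, metrics):
    # metrics is assumed to be a dict like {"top1_acc": ..., "top5_acc": ..., "loss": ...}
    for name, value in metrics.items():
        history[phase][name].append(value)

def plot_history(history, metric="loss"):
    # Plot one metric for both phases against the epoch number.
    for phase in ("train", "val"):
        plt.plot(history[phase][metric], label=f"{phase} {metric}")
    plt.xlabel("epoch")
    plt.ylabel(metric)
    plt.legend()
    plt.show()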

Mixed precision training using PyTorch 1.6

🚀 Feature

PyTorch 1.6 is around the corner.
It is very simple to add mixed precision training with it, and the code can be tested on Kaggle.
Support this feature with an fp16: bool = False parameter in the engine.
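A rough sketch of what the fp16 path could look like with the native torch.cuda.amp API; the function and variable names here are illustrative, not the engine's actual code:

import torch

def train_one_epoch_amp(model, criterion, optimizer, loader, device, fp16=False):
    # In a real engine the GradScaler would be created once and reused across epochs.
    scaler = torch.cuda.amp.GradScaler(enabled=fp16)
    model.train()
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        # autocast runs the forward pass in mixed precision when fp16 is True
        with torch.cuda.amp.autocast(enabled=fp16):
            outputs = model(images)
            loss = criterion(outputs, targets)
        # The scaler handles loss scaling and skips steps on inf/NaN gradients.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()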

Add support for multi-label classification

🚀 Feature

This should take some time.
We can add an argument while creating models, num_labels: int = 1. If the user asks for more than one label, we can simply create an extra dense layer on top and return its output with the model's, as in the sketch below.

We also need to slightly alter the engine to support this. We will get two outputs, o1 and o2, and would need to compute metrics for both and provide the other functions too.

We can either reuse the same train_step and val_step or add separate ones.
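A rough sketch of the extra-head idea; the class and argument names are illustrative only:

import torch.nn as nn

class MultiHeadModel(nn.Module):
    # Illustrative sketch: a shared backbone with one dense head per label.
    def __init__(self, backbone, in_features, num_classes_per_label):
        super().__init__()
        self.backbone = backbone  # feature extractor returning (batch, in_features)
        self.heads = nn.ModuleList(
            [nn.Linear(in_features, n) for n in num_classes_per_label]
        )

    def forward(self, x):
        features = self.backbone(x)
        # One output tensor per label, e.g. (o1, o2) when num_labels=2.
        return tuple(head(features) for head in self.heads)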

Add Stochastic Weight Averaging (SWA)

🚀 Feature

This is natively supported by PyTorch in torch.optim.swa_utils.

They have also given a small example:

        loader, optimizer, model, loss_fn = ...
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
        swa_start = 160
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
                anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
        for i in range(300):
            for input, target in loader:
                optimizer.zero_grad()
                loss_fn(model(input), target).backward()
                optimizer.step()
            if i > swa_start:
                swa_model.update_parameters(model)
                swa_scheduler.step()
            else:
                scheduler.step()

        # Update bn statistics for the swa_model at the end
        torch.optim.swa_utils.update_bn(loader, swa_model)

We need to make this work together with the AMP feature as well.

Train and Valid split for CSVDataset

🚀 Feature

Sometimes we only get train.csv and test.csv, so it would be best for CSVDataset to split train.csv into a train_set and a valid_set by default, as is already done in create_folder_dataset.

The user can split it into train_set and valid_set after creating the CSVDataset, as shown below, but it would be good to have this out of the box.

split = 0.8
complete_dataset = CSVDataset(df, data_dir, target, transform)
# random_split needs integer lengths that sum to len(complete_dataset)
train_split = int(len(complete_dataset) * split)
valid_split = len(complete_dataset) - train_split

train_set, valid_set = torch.utils.data.random_split(
        complete_dataset, [train_split, valid_split]
)

Refactor Tests

🚀 Feature

With #1 we can simply run tests using these sanity checkers instead of full training!
Write a test for it and add it to the tests folder. An almost complete sample is in the examples folder, in the test_torchvision.py file.

Error when calculating top5_acc with 2 classes

๐Ÿ› Bug

Describe the bug
I'm training a model with 2 classes. The error appears at line 61 in pytorch_cnn_trainer/utils.py. By default, maxk is always 5 because topk is fixed to (1, 5) in the train_step and val_step functions. The output variable has shape 32x4, which results in the error RuntimeError: invalid argument 5: k not in range for dimension.

     59     maxk = max(topk)
     60     batch_size = target.size(0)
     61     _, pred = output.topk(maxk, 1, True, True)
     62     pred = pred.t()
     63     correct = pred.eq(target.view(1, -1).expand_as(pred))

To Reproduce
Steps to reproduce the behavior:
Just train a model using engine.fit with 2 classes.

Expected behavior
The training process should work with an arbitrary number of output classes.


Desktop (please complete the following information):

  • OS: ubuntu 20.04

Additional context
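A possible fix, sketched here rather than taken from an actual patch, is to clamp the requested k to the number of classes before calling topk. The surrounding function body is reconstructed from the quoted lines and the standard reference accuracy helper, so the exact name and signature are assumptions:

def accuracy(output, target, topk=(1, 5)):
    # Never ask topk for more entries than the model has output classes.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        k = min(k, maxk)  # e.g. top-5 collapses to top-2 for a 2-class model
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res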

Move examples to docs

  • I have written a few examples for docs.

We can move them to the docs with some extra explanations as well!

Follow Best Practices while training.

🚀 Feature

  • We must follow this and this.

  • Some thoughts:

  1. Provide utilities for turning off debuggers.
  2. Follow the standards and provide good defaults for num_workers and pin_memory (see the sketch after this list).
  3. Use JIT wherever possible.
  4. If we can somehow show the user that they can increase the batch size, that would be great.
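A minimal sketch of what such defaults could look like when the trainer builds its DataLoaders; the helper name and the exact values are suggestions, not the project's current settings:

import os
from torch.utils.data import DataLoader

def make_loader(dataset, batch_size=32, shuffle=True):
    # Suggested defaults: several worker processes for loading and pinned
    # host memory so batches can be copied to the GPU asynchronously.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=True,
    )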

Pass **kwargs at places

🚀 Feature

We can simply accept **kwargs as an argument and reduce the explicit passing of arguments, as in the sketch below.

A simple issue which I will leave for open-source contributors.
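A hedged sketch of the idea as applied to a model factory; create_model here is illustrative and not necessarily the trainer's current signature:

import torchvision

def create_model(model_name: str, pretrained: bool = True, **kwargs):
    # Extra keyword arguments are forwarded untouched to the torchvision
    # constructor instead of being re-declared one by one in our signature.
    builder = getattr(torchvision.models, model_name)
    return builder(pretrained=pretrained, **kwargs)

# e.g. create_model("resnet18", pretrained=False, num_classes=10)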

Feature: Explainable CNNs

🚀 Feature

This feature is probably a bigger one.

It would be really nice to see how the CNN learns every epoch. We can show some of its mistakes and some of its correct classifications too.

This can lead to a step-by-step understanding of how the CNN learnt.

Some ideas that can be used:

  • Use Captum.ai for more explainable-AI functionality powered by PyTorch. This should be simple, out-of-the-box support (see the sketch after this list).
  • Captum supports GradCAM, Guided Backprop, etc. It would be nice to see these results at intermediate steps too.
  • Use logging: simply log which images are classified correctly and which are wrong. This can be a TensorBoard callback or a simple file-storage option.
  • PyTorch model profiling is a new tool added in 1.6. We can make use of it to understand where the CNN bottlenecks and how well it performs.
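A small sketch of what out-of-the-box Captum support could look like, assuming a trained model and a single input batch; the helper name is hypothetical:

from captum.attr import IntegratedGradients

def explain_batch(model, images, target_classes):
    # Integrated Gradients attributions: a per-pixel score of how much each
    # input value pushed the model towards the requested target class.
    model.eval()
    ig = IntegratedGradients(model)
    return ig.attribute(images, target=target_classes)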

CSVDataset attributes

๐Ÿ› Bug

The target attribute is passed in but never stored or used.

def __init__(self, df, data_dir, target, transform):
    super().__init__()
    self.df = df
    self.data_dir = data_dir
    self.transform = transform

df will not always have a column named image_id. It needs to be specified by the user, just like target.

img_name = self.df.image_id[idx]

We need to add an extension like .png or .jpg to img_path, since img_name only contains the id of the image. A combined sketch addressing all three points follows the snippets.

img_path = os.path.join(self.data_dir, img_name)
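A sketch of what a fixed dataset could look like; the image_col and extension parameters are suggestions for the API, not the project's actual fix:

import os
from PIL import Image
from torch.utils.data import Dataset

class CSVDataset(Dataset):
    def __init__(self, df, data_dir, target, transform,
                 image_col="image_id", extension=".jpg"):
        super().__init__()
        self.df = df
        self.data_dir = data_dir
        self.target = target        # previously passed in but never stored
        self.transform = transform
        self.image_col = image_col  # user-specified; image_id is only a default
        self.extension = extension  # e.g. ".png" or ".jpg"

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = str(self.df[self.image_col].iloc[idx]) + self.extension
        img_path = os.path.join(self.data_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        label = self.df[self.target].iloc[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label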

PyTorch Lightning trainer

🚀 Feature

Simply recreate this API for PyTorch Lightning.
It will solve distributed-training issues as well.
We need to remove amp, as it would be handled by PyTorch Lightning.

We can make use of the metrics package as well.
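A rough shape of what the Lightning wrapper could look like; this is a sketch, not the planned implementation:

import pytorch_lightning as pl
import torch
import torch.nn.functional as F

class LitTrainer(pl.LightningModule):
    # Wrap an existing torchvision-style model so Lightning handles device
    # placement, AMP and distributed training for us.
    def __init__(self, model, learning_rate=1e-3):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, targets = batch
        loss = F.cross_entropy(self(images), targets)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        loss = F.cross_entropy(self(images), targets)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)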

I would like to finish this myself, but do let me know if anyone else is interested.

Turn off NVIDIA Profilers

🚀 Feature

As per NVIDIA's recommendations, we should turn off profilers while training.

Linked to #20

Describe the solution you'd like
A simple utility function to turn them off, as in the sketch below.

Describe alternatives you've considered
Keeping this a completely optional step.

Additional context
It should help speed up training.
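One possible shape for that utility, sketched under the assumption that we wrap the training call; profilers_off is a made-up name, not an existing function:

import contextlib
import torch

@contextlib.contextmanager
def profilers_off():
    # Keeps the autograd profiler, the NVTX range emitter used by NVIDIA
    # tools, and anomaly detection disabled while the wrapped block runs.
    with torch.autograd.profiler.profile(enabled=False), \
         torch.autograd.profiler.emit_nvtx(enabled=False), \
         torch.autograd.set_detect_anomaly(False):
        yield

# usage: with profilers_off(): engine.fit(...)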

Add Examples for New Features

📓 New <Tutorial/Example>

A lot of new features were added. We need docs and examples for:

  • Mixed precision training with train_step() and fit()
  • Stochastic Weight Averaging (SWA) with train_step() and fit()
  • L2 gradient penalty for training
