Coder Social home page Coder Social logo

savan77 / transfer-learning Goto Github PK

View Code? Open in Web Editor NEW
0.0 2.0 1.0 257 KB

Study of methods to counter class-imbalance in deep learning models.

Jupyter Notebook 97.63% Python 2.37%
pytorch pytorch-cnn classification transfer-learning fine-tuning imbalance oversample

transfer-learning's Introduction

Study of Class Imbalance in Deep Learning

Data Preparation

All the data exploration and preparation can be found in data_preparation.ipynb. There is a heavy class-imbalance and it needs to be considered during training.

Model

We finetuned ResNet50 with different number of classes.

Training

Training configurations can be changed in parameters.py.

python train.py

Evaluation

Code for model evaluation can be found in model_evaluation.ipynb

The model was trained for 4 epochs on images with top-20 classes and metrics were calculated on test-set.

Class Top-1 Top-5
Shirts 0.98 1.0
Tshirts 0.91 0.99
Casual Shoes 0.75 0.99
Flip Flops 0.86 0.98
Sandals 0.86 0.99
Formal Shoes 0.93 0.98
Handbags 0.98 0.99
Kurtas 0.95 0.99
Belts 0.99 0.99
Sports Shoes 0.90 0.99
Heels 0.63 0.99
Wallets 0.97 1.00
Tops 0.77 0.99
Briefs 0.96 1.0
Dresses 0.95 1.0
Socks 0.96 0.99
Watches 1.0 1.0
Sunglasses 1.0 1.0
Average 0.90 0.99

Finetuning

Next, we finetuned for 10 epochs on above model on remaining classes and evaluated it on test set. Below table shows the top-1 and top-5 accuracy.

Class Top-1 Top-5
Track Pants 0.84 0.97
Sweatshirts 0.68 0.91
Shorts 0.86 0.98
Clutches 0.98 1.0
Backpacks 0.97 1.0
Caps 0.92 0.98
Trousers 0.91 0.98
Jeans 0.86 0.99
Bra 0.99 1.0
Lounge Pants 0.21 0.78
Duffel Bag 0.78 0.97
Night suits 0.5 0.88
Pendant 0.85 0.97
Tracksuits 0.95 1.0
Tunics 0.37 1.0
Nightdress 0.29 0.85
Ties 0.98 0.98
Leggings 0.14 0.87
Jackets 0.80 0.95
Messenger Bag 0.58 1.0
Mufflers 0.13 1.0
Kurta Sets 1.0 1.0
Accessory Gift Set 1.0 1.0
Bracelet 0.80 1.0
Kurtis 0.75 1.0
Sweaters 0.50 0.94
Ring 0.87 1.0
Scarves 0.63 0.91
Earrings 1.0 1.0
Capris 0.73 0.98
Skirts 0.66 0.92
Headband 1.0 1.0
Sports Sandals 0.9 1.0
Tights 0.0 unk
Travel Accessory 0.2 0.9
Gloves 1.0 1.0
Mobile Pouch 0.33 1.0
Wristbands 1.0 1.0
Stockings 0.29 0.88
Footballs 1.0 1.0
Jumpsuit 1.0 1.0
Free Gifts 0.08 0.16
Necklace and Chains 0.92 1.0
Cufflinks 1.0 1.0
Jewellery Set 0.69 1.0
Laptop Bag 0.62 0.97
Camisoles 0.26 0.91
Rucksacks 0.37 1.0
Swimwear 0.12 0.62
Stoles 0.5 1.0
Basketballs 0.63 1.0
Rain Jacket 1.0 1.0
Churidar 0.75 1.0
Umbrellas 1.0 1.0
Bangle 0.21 1.0
Dupatta 0.5 1.0
Water Bottle 1.0 1.0
Makeup Remover 0.66 0.66
Rain Trousers 1.0 1.0
Waist Pouch 0.0 unk
Average 0.77 0.85

โ€‹

Extension

It was observed that different model gave drastically different accuracy for different class. Above table is for a model which gave highest overall accuracy.

Moreover, here, ResNet50 was used with different fully connected layer. We can try with more fully connected layers or other type of layers to improve the accuracy.

But, the biggest factor here is class imbalance. Accuracy for top-20 classes on test set is really good. On the other side, accuracy for remaining classes on the test set is very low compare to the top-20 classes.

There are two common ways to handle the class imbalance. One is to oversample instances with minority class. The second way is to pass weight to the loss function. The above models were trained with the second scheme as it performed better.

We can still notice the class-imbalance problem in the remaining classes. One possible approach to try is to use oversampling along with loss weights. Here, we will oversample classes with really low number of instances OR classes with really low accuracy.

Alternatively, we can also augment images for minority classes.

The productDisplayName also contains helpful information to classify the image. We can possibly train CNN and RNN jointly to classify the image. The CNN will generate encoding for the image and RNN will generate the encoding for the text followed by a fully connected layer to classify the instance.

Alternatively, we can have above model as it is and trained a seperate RNN model to classify class from productDisplayName. Whenever, our CNN model predicts classes with very low confidence, we can pass the productDisplayName to classify the same instance. The class which has highest confidence from both networks can be considered as the output.

transfer-learning's People

Contributors

savan77 avatar

Watchers

 avatar  avatar

Forkers

vaisili

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.