This project forked from vaquierm/mnist_classification

🖼 A supervised classification model, which predicts the labels of a Modified MNIST dataset. 1st place in Kaggle competition

Home Page: https://www.kaggle.com/c/modified-mnist/overview

MNIST_Classification

This repository implements a supervised classification model that predicts the labels of the Modified MNIST dataset. The model was used in a Kaggle competition.

The Dataset

The Modified MNIST dataset consists of images with a dark background, each containing three handwritten digits from 0 through 9. The digits were taken from the MNIST dataset. The label for each image is the digit with the greatest numerical value. The Modified MNIST dataset is available from the Kaggle competition page.
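As a small illustration of the labeling rule, the sketch below computes the label from the three digits in an image. The function name and the example digits are made up for illustration and are not part of the repository.

```python
def modified_mnist_label(digits):
    """Return the label for one image: the largest of its three digits."""
    if len(digits) != 3:
        raise ValueError("each Modified MNIST image contains exactly three digits")
    if any(not 0 <= d <= 9 for d in digits):
        raise ValueError("digits must be between 0 and 9")
    return max(digits)


# An image showing the digits 3, 7 and 2 is labeled 7.
example_label = modified_mnist_label([3, 7, 2])
```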

The Optimal Strategy

The model with the highest classification accuracy used a Residual Network. The architecture for this model is a modified version of one that was developed by John Olafenwa. These modifications included:

  • increasing the kernel size of the last average pooling layer, so that it acts as a global average pooling layer
  • increasing the depth of the network
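To see why enlarging the kernel of the last average pooling layer makes it behave like global average pooling, consider the pure-Python sketch below. The function and the toy 4x4 feature map are illustrative assumptions, not the repository's Keras code.

```python
def average_pool(feature_map, kernel):
    """Non-overlapping average pooling (stride equal to kernel) over a 2D list."""
    rows, cols = len(feature_map), len(feature_map[0])
    pooled = []
    for r in range(0, rows - kernel + 1, kernel):
        pooled_row = []
        for c in range(0, cols - kernel + 1, kernel):
            window = [
                feature_map[r + i][c + j]
                for i in range(kernel)
                for j in range(kernel)
            ]
            pooled_row.append(sum(window) / len(window))
        pooled.append(pooled_row)
    return pooled


# Toy 4x4 feature map (made-up values).
feature_map = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
]

local = average_pool(feature_map, 2)      # 2x2 output: one mean per window
global_avg = average_pool(feature_map, 4)  # 1x1 output: mean of the whole map
```

When the kernel covers the whole feature map, the layer outputs a single value per channel, which is what a global average pooling layer computes.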

In order to augment the data, an ImageDataGenerator from the Keras library was used. Images from the training set were randomly rotated between -10 and 10 degrees, translated by up to 10% of the image width and height, as well as zoomed in or out by up to 10%.
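The augmentation settings described above map onto ImageDataGenerator arguments roughly as follows. This is a sketch rather than the repository's exact code: the names AUGMENTATION and build_augmenter are hypothetical, and the import path may differ depending on the installed Keras or TensorFlow version.

```python
# Augmentation ranges taken from the description above.
AUGMENTATION = {
    "rotation_range": 10,       # random rotation, in degrees, within [-10, 10]
    "width_shift_range": 0.1,   # horizontal shift up to 10% of the image width
    "height_shift_range": 0.1,  # vertical shift up to 10% of the image height
    "zoom_range": 0.1,          # zoom factor sampled from [0.9, 1.1]
}


def build_augmenter():
    """Create an ImageDataGenerator with the settings above."""
    # Imported lazily so the settings can be inspected without TensorFlow installed.
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    return ImageDataGenerator(**AUGMENTATION)
```

A generator built this way could then feed a training loop through its flow method, for example build_augmenter().flow(x_train, y_train, batch_size=32), where x_train, y_train and the batch size are placeholders.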

Five models were trained, each on a different split of the training data. Ensembling the predictions of the five models by majority vote achieved a Kaggle leaderboard accuracy of 99.133%. The confusion matrix of this strategy is as follows:

The accuracy and loss of the training and test datasets over 50 epochs are as follows:

How to Run the Program

  1. Download the training data and the test data from Kaggle, and place them in the data/ folder
  2. Open the src/config.py file and do the following:
    • While they shouldn't require modification, double-check that all file paths are correct.
    • Select the model you would like to run. This could be CNN, for a convolutional neural network, or ResNet, for a residual network. Update the MODEL variable accordingly. Note that the optimal strategy uses ResNet.
    • If you would like to retrain the model, set retrain_models = True.
    • If you would like to perform transfer learning, set transfer_learning = True.
    • Indicate the fold number you would like to run, by adjusting FOLD_NUMBER accordingly.
  3. Run the unprocessed_predictions.py script.
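For orientation, the settings named in step 2 might look something like the sketch below. Only the names MODEL, retrain_models, transfer_learning and FOLD_NUMBER come from the steps above. The value types, the zero-based fold numbering, SUPPORTED_MODELS and validate_config are assumptions, so check the actual src/config.py before relying on them.

```python
# Hypothetical sketch of the settings described above; the real src/config.py may differ.
SUPPORTED_MODELS = ("CNN", "ResNet")  # assumed to be plain strings

MODEL = "ResNet"           # the optimal strategy uses ResNet
retrain_models = True      # train from scratch instead of loading a saved model
transfer_learning = False  # continue training from a previously saved model
FOLD_NUMBER = 0            # which split of the training data to use (assumed zero-based)


def validate_config():
    """Fail early on an unsupported model name or an invalid fold number."""
    if MODEL not in SUPPORTED_MODELS:
        raise ValueError(f"MODEL must be one of {SUPPORTED_MODELS}, got {MODEL!r}")
    if not isinstance(FOLD_NUMBER, int) or FOLD_NUMBER < 0:
        raise ValueError("FOLD_NUMBER must be a non-negative integer")


validate_config()
```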

How to Ensemble Predictions

  1. Place all predictions to be ensembled in the results/ensemble/ folder.
  2. Run the ensemble.py script.
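The majority vote behind these steps can be sketched as follows. This is not the repository's ensemble.py. The column names Id and Label, the output file name and the tie-breaking rule are assumptions made for illustration.

```python
import csv
from collections import Counter
from pathlib import Path


def majority_vote(predictions):
    """Return the most common label per sample.

    `predictions` is a list of per-model label lists of equal length.
    Ties go to the label that appears first in model order.
    """
    if not predictions:
        raise ValueError("need at least one set of predictions")
    if len({len(p) for p in predictions}) != 1:
        raise ValueError("all prediction lists must have the same length")
    return [Counter(votes).most_common(1)[0][0] for votes in zip(*predictions)]


def read_labels(path, label_column="Label"):
    """Read the labels from one prediction file (column name is assumed)."""
    with open(path, newline="") as handle:
        return [int(row[label_column]) for row in csv.DictReader(handle)]


def ensemble_folder(folder="results/ensemble",
                    output="results/ensemble_predictions.csv"):
    """Vote over every .csv file in `folder` and write the result to `output`."""
    files = sorted(Path(folder).glob("*.csv"))
    voted = majority_vote([read_labels(f) for f in files])
    with open(output, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["Id", "Label"])
        writer.writerows(enumerate(voted))
    return voted
```

With an odd number of models, such as the five used here, a sample can still end in a tie when the votes spread over several labels, so the tie-breaking rule is worth choosing deliberately.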

Directory Structure

.
├── data
│
├── models
│
├── results
│   └── ensemble
│
└── src
    ├── config.py
    │
    ├── unprocessed_predictions.py
    ├── isolated_prediction.py
    ├── triplet_predictions.py
    │
    ├── ensemble.py
    │
    ├── data_analysis.ipynb
    ├── results_analysis.ipynb
    │
    ├── data_processing
    │   ├── data_loader.py
    │   └── number_extraction.py
    │
    ├── models
    │   ├── max_mnist_predictor.py
    │   └── models.py
    │
    └── utils
        └── fileio.py

The data/ folder holds the training and testing data, in the form of .csv files.

Any results are placed automatically in the results/ folder. These results include confusion matrices, loss and accuracy graphs, as well as .csv files with predictions that can be submitted to Kaggle.

  • results/ensemble/ contains all predictions, in the form of .csv files, to be ensembled.

The models/ folder contains models that have been trained. Newly trained models are automatically stored here, where they can then be used to make predictions or perform transfer learning.

Files in src/:

  • config.py defines which models are to be run, and allows for specific configurations.
  • models/models.py contains the implementations of both a convolutional neural network and a residual network.
  • unprocessed_predictions.py has the scripts to train a model, using the unprocessed Modified MNIST dataset, and produce predictions in the form of Kaggle results.
  • isolated_prediction.py has the scripts to train a model, using the original MNIST dataset, and make predictions on the individual digits of each image.
  • triplet_predictions.py has the scripts to train a model and make predictions, using a modified version of the Modified MNIST dataset, where each image is the concatenation of the three isolated digits.
  • utils/fileio.py defines the functionalities needed for reading from and writing to files.
  • ensemble.py includes the implementation of an ensemble method, which takes all predictions that are stored in the results/ensemble/ folder and outputs a prediction in the form of a .csv file, which is placed in the results/ folder.
  • data_processing/ has the scripts necessary for different data processing strategies.
  • data_analysis.ipynb is a Jupyter notebook that analyzes the dataset for this classification task.
  • results_analysis.ipynb is a Jupyter notebook that analyzes the types of errors made by the model.

Contributors

vaquierm, arcaulfield, jawaialler
