kfold_cv's Introduction

I've written some functions which can help you divide your data set into training and validation sets for n-fold cross-validation.

You should have all of your data points in a matrix X where each row is a separate data point. You should also have a column vector y containing the category (or class label) for each of the corresponding data points.

You will also need to define a column vector 'categories' which just lists the class label values you are using. I require this so that the code doesn't make any assumptions about the values you are using for your class labels. You could use '0' as one of the categories if you want, for example, and the values don't have to be contiguous.
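To make the expected layout concrete, here is a minimal sketch with made-up data. The point values and the label values 0 and 7 are assumptions chosen only for illustration; the loop simply counts how many data points carry each listed label, using built-in functions rather than any of the functions described here.

```matlab
% Toy data: five 2-D data points, one per row of X (values are made up).
X = [1.0 2.0;
     1.5 1.8;
     5.0 8.0;
     6.0 9.0;
     1.2 0.9];

% One class label per row of X. The labels 0 and 7 are arbitrary and
% deliberately non-contiguous.
y = [0; 0; 7; 7; 0];

% The label values in use, listed explicitly as a column vector.
categories = [0; 7];

% Count the rows of X that carry each listed label.
counts = zeros(length(categories), 1);
for i = 1 : length(categories)
    counts(i) = sum(y == categories(i));
end

disp(counts');
```

Because the labels are looked up through 'categories', nothing in the loop depends on the labels starting at 1 or being consecutive.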

Here are links to each of the functions, with a short description of what each does. There is also a simple example usage at the end.

getVecsPerCat.m - Counts the number of vectors belonging to each category.

computeFoldSizes.m - Pre-computes the size of each of the n folds for each category. The number of folds might not divide evenly into the number of vectors, so this function handles distributing the remainder across the folds.

randSortAndGroup.m - Sorts the vectors by category, and randomizes the order of the vectors within each category.

getFoldVectors.m - For the specified round of cross-validation, selects X_train, y_train (the vectors to use for training, with their associated categories) and X_val, y_val (the vectors to use for validation, with their associated categories).
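The fold bookkeeping described above can be sketched for a single category as follows. This is a hypothetical illustration, not the source of computeFoldSizes, randSortAndGroup, or getFoldVectors: the variable names, the choice to give the extra vectors to the first folds, and the values 23 and 5 are all assumptions.

```matlab
% Hypothetical sketch for ONE category; not the actual function sources.
numVecs  = 23;   % number of vectors in this category (made-up value)
numFolds = 5;    % number of folds (made-up value)

% Every fold gets the base size; the first 'remainder' folds get one
% extra vector so that the sizes add up to numVecs.
baseSize  = floor(numVecs / numFolds);
remainder = mod(numVecs, numFolds);
foldSizes = baseSize * ones(numFolds, 1);
foldSizes(1 : remainder) = foldSizes(1 : remainder) + 1;

% Randomize the order of this category's vectors (row indices 1..numVecs).
order = randperm(numVecs);

% First and last position of each fold within the shuffled order.
foldEnd   = cumsum(foldSizes);
foldStart = foldEnd - foldSizes + 1;

% For one round, that round's fold is the validation set and all other
% vectors in the category are used for training.
roundNumber = 2;
valIdx   = order(foldStart(roundNumber) : foldEnd(roundNumber));
trainIdx = setdiff(order, valIdx);
```

With more than one category, the same steps would be repeated per category and the selected rows concatenated, so that each fold keeps roughly the same class proportions as the full data set.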

After calling getFoldVectors, it's up to you to perform the actual training and compute your validation accuracy on the validation vectors. Below is some sample code for using the above functions, but note that it omits the actual training and validation steps.

% List out the category values in use.
categories = [0; 1];

% Get the number of vectors belonging to each category.
vecsPerCat = getVecsPerCat(X, y, categories);

% Number of folds to use for cross-validation.
numFolds = 10;

% Compute the fold sizes for each category.
foldSizes = computeFoldSizes(vecsPerCat, numFolds);

% Randomly sort the vectors in X, then organize them by category.
[X_sorted, y_sorted] = randSortAndGroup(X, y, categories);

% For each round of cross-validation...
for roundNumber = 1 : numFolds

    % Select the vectors to use for training and cross-validation.
    [X_train, y_train, X_val, y_val] = getFoldVectors(X_sorted, y_sorted, categories, vecsPerCat, foldSizes, roundNumber);

    % Train the classifier on the training set, X_train and y_train.
    % .....................

    % Measure the classification accuracy on the validation set.
    % .....................

end
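As one possible way to fill in the two elided steps, the sketch below uses a nearest-centroid classifier. The classifier is a stand-in chosen for brevity, not something this code prescribes, and the small X_train, y_train, X_val, and y_val matrices are made-up substitutes for the outputs of getFoldVectors so that the block runs on its own.

```matlab
% Made-up stand-ins for the outputs of getFoldVectors for one round.
categories = [0; 1];
X_train = [0.0 0.1; 0.2 0.0; 0.1 0.3; 2.0 2.1; 2.2 1.9; 1.8 2.0];
y_train = [0; 0; 0; 1; 1; 1];
X_val   = [0.1 0.1; 2.1 2.0; 1.9 2.2];
y_val   = [0; 1; 1];

% "Training": compute the mean vector (centroid) of each category.
centroids = zeros(length(categories), size(X_train, 2));
for i = 1 : length(categories)
    centroids(i, :) = mean(X_train(y_train == categories(i), :), 1);
end

% "Validation": assign each validation vector to the category whose
% centroid is nearest in squared Euclidean distance.
y_pred = zeros(size(y_val));
for j = 1 : size(X_val, 1)
    diffs = centroids - repmat(X_val(j, :), length(categories), 1);
    [~, nearest] = min(sum(diffs .^ 2, 2));
    y_pred(j) = categories(nearest);
end

% Fraction of validation vectors classified correctly in this round.
accuracy = mean(y_pred == y_val);
```

Inside the cross-validation loop, the accuracy from each round would typically be stored in a vector indexed by roundNumber and averaged once the loop finishes.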

kfold_cv's People

Contributors

chrisjmccormick