This project forked from sinzlab/nnfabrik

A generalized model fitting pipeline that houses models, trainers, and datasets in DataJoint, and that both returns and stores trained models. The main project uses some Python 3.8 features, which makes it inconvenient to use on Google Colab. This fork omits all Python 3.8 features.

Home Page: https://sinzlab.github.io/nnfabrik/

nnfabrik: a generalized model fitting pipeline

nnfabrik is a model fitting pipeline, developed mainly for neural networks, in which training results (i.e. scores and trained models), as well as any data related to the models, trainers, and datasets used for training, are stored in DataJoint tables.

Why use it?

Training neural network models commonly involves the following steps:

  • load dataset
  • initialize a model
  • train the model using the dataset

While that would fulfill the training procedure, a huge portion of the time spent finding the best model for your application is dedicated to hyperparameter selection and optimization. Importantly, each of the steps above may require its own specifications, which affect the resulting model: for instance, whether to standardize the data, whether to use 2 layers or 20 layers, or whether to use Adam or SGD as the optimizer. This is where nnfabrik comes in handy: it keeps track of the models trained for every unique combination of hyperparameters.

⚙️ Installation

You can use one of the following ways to install nnfabrik:

1. Using pip

pip install nnfabrik

2. Via GitHub:

pip install git+https://github.com/sinzlab/nnfabrik.git

💻 Usage

As mentioned above, nnfabrik helps keep track of the different combinations of hyperparameters used for model creation and training. To do this, nnfabrik needs the components necessary to train a model. These components are:

  • dataset function: a function that returns the data used for training
  • model function: a function that returns the model to be trained
  • trainer function: a function that, given a dataset and a model, trains the model and returns the resulting model

However, to ensure a generalized solution, nnfabrik makes some minor assumptions about the inputs and outputs of these functions. The assumptions are as follows:

Dataset function

  • input: must have an argument called seed. The rest of the arguments are up to the user and we will refer to them as dataset_config.
  • output: this is up to the user as long as the returned object is compatible with the model function and trainer function
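A minimal sketch of a dataset function that satisfies these assumptions. The name `toy_dataset_fn`, the arguments `n_samples`, `batch_size`, and `standardize`, and the returned dictionary layout with `"train"` and `"validation"` keys are all hypothetical choices made for illustration; the only requirement taken from the text above is the `seed` argument.

```python
import random


def toy_dataset_fn(seed, n_samples=64, batch_size=16, standardize=True):
    """Hypothetical dataset function: `seed` is required, the rest is dataset_config."""
    rng = random.Random(seed)  # every random choice depends on the seed only

    # Synthetic regression data: y is roughly 2 * x.
    xs = [rng.gauss(5.0, 2.0) for _ in range(n_samples)]
    ys = [2.0 * x + rng.gauss(0.0, 0.1) for x in xs]

    if standardize:
        mean = sum(xs) / len(xs)
        std = (sum((x - mean) ** 2 for x in xs) / len(xs)) ** 0.5
        xs = [(x - mean) / std for x in xs]

    samples = list(zip(xs, ys))
    split = int(0.75 * n_samples)

    def batches(items):
        return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

    # The output format is up to the user; a dict of batch lists is assumed here.
    return {"train": batches(samples[:split]), "validation": batches(samples[split:])}


dataloaders = toy_dataset_fn(seed=1)
```

Because all randomness is derived from `seed`, calling the function twice with the same seed and the same dataset_config yields the same data.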

Model function

  • input: must have two arguments: dataloaders and seed. The rest of the arguments are up to the user and we will refer to them as model_config.
  • output: a model object of class torch.nn.Module
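A minimal sketch of a model function, assuming PyTorch is installed. The name `toy_model_fn`, the arguments `hidden_units` and `n_classes`, and the assumption that `dataloaders["train"]` yields `(inputs, targets)` batches are hypothetical; only the `dataloaders` and `seed` arguments and the `torch.nn.Module` return type come from the text above.

```python
import torch
from torch import nn


def toy_model_fn(dataloaders, seed, hidden_units=32, n_classes=10):
    """Hypothetical model function: `dataloaders` and `seed` are required,
    the remaining arguments form the model_config."""
    torch.manual_seed(seed)  # makes the weight initialization reproducible

    # Infer the input size from one batch (assumes (inputs, targets) batches).
    inputs, _ = next(iter(dataloaders["train"]))
    in_features = inputs[0].numel()

    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features, hidden_units),
        nn.ReLU(),
        nn.Linear(hidden_units, n_classes),
    )


# A single dummy batch of four 28x28 images stands in for real dataloaders.
dummy_loaders = {
    "train": [(torch.zeros(4, 1, 28, 28), torch.zeros(4, dtype=torch.long))]
}
model = toy_model_fn(dummy_loaders, seed=0)
```

Passing the dataloaders to the model function lets the model adapt its input layer to the data instead of hard-coding the input size in model_config.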

Trainer function

  • input: must have three arguments: model, dataloaders and seed. The rest of the arguments are up to the user and we will refer to them as trainer_config. Note that nnfabrik also passes some extra keyword arguments to the trainer function, but to start with you can simply ignore them by adding **kwargs to your trainer function's arguments.
  • output: the trainer returns three objects:
    • a single value representing some sort of score (e.g. validation correlation) attributed to the trained model
    • a collection (list, tuple, or dictionary) of any other quantity
    • the state_dict of the trained model.
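A minimal sketch of a trainer function, assuming PyTorch is installed. The name `toy_trainer_fn`, the arguments `lr` and `epochs`, the `"train"`/`"validation"` keys, and the choice of negative validation MSE as the score are hypothetical; the required arguments, the `**kwargs` catch-all, and the three returned objects follow the description above.

```python
import torch
from torch import nn


def toy_trainer_fn(model, dataloaders, seed, lr=0.1, epochs=5, **kwargs):
    """Hypothetical trainer: `model`, `dataloaders`, `seed` are required,
    `lr` and `epochs` form the trainer_config, extra kwargs are ignored."""
    torch.manual_seed(seed)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    train_losses = []
    for _ in range(epochs):
        model.train()
        total = 0.0
        for inputs, targets in dataloaders["train"]:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
            total += loss.item()
        train_losses.append(total / len(dataloaders["train"]))

    model.eval()
    with torch.no_grad():
        val = [loss_fn(model(x), y).item() for x, y in dataloaders["validation"]]

    score = -sum(val) / len(val)            # 1) a single score value
    output = {"train_losses": train_losses}  # 2) a collection of other quantities
    return score, output, model.state_dict()  # 3) the state_dict


# Tiny synthetic regression problem; each "dataloader" is a list with one batch.
torch.manual_seed(0)
x = torch.randn(32, 3)
y = x.sum(dim=1, keepdim=True)
dataloaders = {"train": [(x[:24], y[:24])], "validation": [(x[24:], y[24:])]}

score, output, state_dict = toy_trainer_fn(nn.Linear(3, 1), dataloaders, seed=0)
```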

Here you can see an example of these functions to train an MNIST classifier within the nnfabrik pipeline.

Once you have these three functions, all that is left to do is to define the corresponding tables. The tables are structured similarly to the functions. That is, we have a Dataset, a Model, and a Trainer table. Each entry of a table corresponds to a specific instance of the corresponding function. For example, one entry of the Dataset table refers to a specific dataset function and a specific dataset_config.

In addition to the tables that store unique combinations of functions and configuration objects, there are two more tables: Seed and TrainedModel. The Seed table stores the seed values, which are automatically passed to the dataset, model, and trainer functions. The TrainedModel table stores the trained models. Each entry of the TrainedModel table refers to the model resulting from a unique combination of dataset, model, trainer, and seed.
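The bookkeeping idea behind these tables can be illustrated in plain Python. This is only a conceptual sketch: it does not use nnfabrik or DataJoint, and the helper `config_hash`, the example configs, and the dictionary `trained_models` are hypothetical stand-ins for the real tables and their keys.

```python
import hashlib
import json
from itertools import product


def config_hash(config):
    """Hypothetical helper: a stable identifier for a configuration dict."""
    return hashlib.md5(json.dumps(config, sort_keys=True).encode()).hexdigest()


# Stand-ins for entries of the Dataset, Model, Trainer, and Seed tables.
dataset_configs = [{"standardize": True}, {"standardize": False}]
model_configs = [{"layers": 2}, {"layers": 20}]
trainer_configs = [{"optimizer": "Adam"}, {"optimizer": "SGD"}]
seeds = [1, 2]

# Stand-in for TrainedModel: one entry per unique combination.
trained_models = {}
for d, m, t, s in product(dataset_configs, model_configs, trainer_configs, seeds):
    key = (config_hash(d), config_hash(m), config_hash(t), s)
    trained_models[key] = {
        "dataset_config": d,
        "model_config": m,
        "trainer_config": t,
        "seed": s,
    }

print(len(trained_models))  # 2 * 2 * 2 * 2 combinations
```

Keying each entry by the full combination means that rerunning the same configuration with the same seed maps to the same entry rather than creating a duplicate.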

We have now covered the most important information about nnfabrik, and it is time to use it (to see some examples, please refer to the example section). Some basics of the DataJoint Python package, which is the backbone of nnfabrik, might come in handy (especially for dealing with tables), and you can learn more about DataJoint here.

💡 Example

Here, you can find an example of the whole pipeline, which may help you understand how the different components work together to perform a hyperparameter search.

📖 Documentation

The documentation can be found here. Please note that it is a work in progress.

🐛 Report bugs (or request features)

If you find a bug or would like to see new features added to nnfabrik, please create an issue or contact any of the contributors.

Contributors

konstantinwilleke, mohammadbashiri, eywalker, arnenx, maxfburg, kklurz, christoph-blessing, fabiansinz, chbehrens, kellirestivo, claudiusgruner, xup5, sacadena
