
A generalized model fitting pipeline that houses models, trainers, and datasets in datajoint and returns as well as stores trained models.

Home Page: https://sinzlab.github.io/nnfabrik/


nnfabrik: a generalized model fitting pipeline


nnfabrik is a model fitting pipeline, mainly developed for neural networks, in which training results (i.e., scores and trained models), as well as any data related to the models, trainers, and datasets used for training, are stored in DataJoint tables.

Why use it?

Training neural network models commonly involves the following steps:

  • load dataset
  • initialize a model
  • train the model using the dataset

While those steps cover the training procedure itself, a huge portion of the time spent finding the best model for your application goes into hyperparameter selection/optimization. Importantly, each of the above-mentioned steps may require its own specifications, which affect the resulting model: for instance, whether to standardize the data, whether to use 2 layers or 20, or whether to use Adam or SGD as the optimizer. This is where nnfabrik becomes very handy, by keeping track of the models trained for every unique combination of hyperparameters.

โš™๏ธ Installation

You can use one of the following ways to install nnfabrik:

1. Using pip

pip install nnfabrik

2. Via GitHub:

pip install git+https://github.com/sinzlab/nnfabrik.git

💻 Usage

As mentioned above, nnfabrik helps keep track of the different combinations of hyperparameters used for model creation and training. To achieve this, nnfabrik needs the components required to train a model:

  • dataset function: a function that returns the data used for training
  • model function: a function that returns the model to be trained
  • trainer function: a function that, given a dataset and a model, trains the model and returns the resulting model

However, to ensure a generalized solution, nnfabrik makes some minor assumptions about the inputs and outputs of the above-mentioned functions:

Dataset function

  • input: must have an argument called seed. The rest of the arguments are up to the user; we will refer to them as dataset_config.
  • output: up to the user, as long as the returned object is compatible with the model function and trainer function
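A minimal dataset function satisfying these assumptions might look like the sketch below. It is illustrative only: the function name, the toy regression data, and the use of torch DataLoaders are assumptions, not nnfabrik requirements.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def my_dataset(seed, n=64, batch_size=16, **config):
    # `seed` is the only required argument; `n` and `batch_size`
    # belong to what nnfabrik would treat as dataset_config.
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(n, 3, generator=g)
    y = x.sum(dim=1, keepdim=True)  # toy regression target
    ds = TensorDataset(x, y)
    return {
        "train": DataLoader(ds, batch_size=batch_size),
        "validation": DataLoader(ds, batch_size=batch_size),
        "test": DataLoader(ds, batch_size=batch_size),
    }
```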

Model function

  • input: must have two arguments: dataloaders and seed. The rest of the arguments are up to the user and we will refer to them as model_config.
  • output: a model object of class torch.nn.Module
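A matching model function could be sketched as follows. Again illustrative: my_model and the architecture are hypothetical; the input size is inferred from a training batch, which assumes the dataloaders yield (inputs, targets) pairs.

```python
import torch
from torch import nn

def my_model(dataloaders, seed, hidden=16, **config):
    # `dataloaders` and `seed` are the required arguments;
    # `hidden` belongs to what nnfabrik would treat as model_config.
    torch.manual_seed(seed)
    inputs, _ = next(iter(dataloaders["train"]))  # peek at a batch for the input size
    return nn.Sequential(
        nn.Linear(inputs.shape[1], hidden),
        nn.ReLU(),
        nn.Linear(hidden, 1),
    )
```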

Trainer function

  • input: must have three arguments: model, dataloaders, and seed. The rest of the arguments are up to the user; we will refer to them as trainer_config. Note that nnfabrik also passes some extra keyword arguments to the trainer function; to start, simply ignore them by adding **kwargs to your trainer function signature.
  • output: the trainer returns three objects:
    • a single value representing some sort of score (e.g. validation correlation) attributed to the trained model
    • a collection (list, tuple, or dictionary) of any other quantity
    • the state_dict of the trained model.
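A trainer function obeying this contract might be sketched as below. This is illustrative: my_trainer, the MSE loss, and the use of negative validation loss as the score are assumptions, not nnfabrik requirements.

```python
import torch
from torch import nn

def my_trainer(model, dataloaders, seed, epochs=2, lr=0.05, **kwargs):
    # `model`, `dataloaders`, `seed` are required; `epochs` and `lr`
    # belong to trainer_config; **kwargs absorbs nnfabrik's extras.
    torch.manual_seed(seed)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for _ in range(epochs):
        for x, y in dataloaders["train"]:
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()
    model.eval()
    with torch.no_grad():
        val_loss = sum(loss_fn(model(x), y).item() for x, y in dataloaders["validation"])
    score = -val_loss                       # single scalar score
    output = {"validation_loss": val_loss}  # any other quantities
    return score, output, model.state_dict()
```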

Here you can see an example of these functions to train an MNIST classifier within the nnfabrik pipeline.

Once you have these three functions, all that is left to do is to define the corresponding tables. The tables are structured similarly to the functions: we have a Dataset, a Model, and a Trainer table. Each entry of a table corresponds to a specific instance of the corresponding function. For example, one entry of the Dataset table refers to a specific dataset function and a specific dataset_config.

In addition to the tables that store unique combinations of functions and configuration objects, there are two more tables: Seed and TrainedModel. The Seed table stores seed values, which are automatically passed to the dataset, model, and trainer functions. TrainedModel stores the trained models: each entry refers to the model resulting from a unique combination of dataset, model, trainer, and seed.

We have pretty much covered the most important information about nnfabrik, and it is time to use it (to see some examples, please refer to the example section). Some basics about the DataJoint Python package (which is the backbone of nnfabrik) might come in handy, especially for dealing with tables; you can learn more about DataJoint here.

💡 Example

Here, you can find an example of the whole pipeline which might help to understand how different components work together to perform hyper-parameter search.

📖 Documentation

The documentation can be found here. Please note that it is a work in progress.

๐Ÿ› Report bugs (or request features)

In case you find a bug or would like to see some new features added to nnfabrik, please create an issue or contact any of the contributors.

nnfabrik's People

Contributors

arnenx, chbehrens, christoph-blessing, claudiusgruner, eywalker, fabiansinz, kellirestivo, kklurz, konstantinwilleke, maxfburg, mohammadbashiri, sacadena


nnfabrik's Issues

Proposal: Abstract classes for datasets, models and trainers

Hi all,

as I am starting to work with nnfabrik, I would like to suggest something to discuss here.

I believe that the minimal requirements any trainer, model, or dataset needs to fulfill should be formulated as abstract classes from which everyone has to inherit.
This way we could formalize the general attributes and methods any trainer/model/dataset needs to provide (similar to the description in toy_datasets.py...). This would make the whole framework more modular, since any class interacting with others could rely on the attributes guaranteed by the superclass.

Of course, this is only a suggestion, and I would love to hear your opinion on this topic.
I would also be happy to start on a pull request for this. (Although I will probably need some input from all of you, as there are lots of design decisions to be made.)

KeyError in TrainedModel when no data_info_table is defined

For a trained-model table where nnfabrik references nnfabrik.main and no data_info_table is explicitly defined,

@property
def data_info_table(self):
    return find_object(self.nnfabrik, "DataInfo", "data_info_table")

throws a KeyError as it tries to find the non-existent DataInfo in nnfabrik.main. Is this the wanted behavior, or should there be something like setting data_info_table = None, maybe combined with a warning?

Readme not up-to-date

The readme is not up-to-date; mainly, the links are broken.
But we should also check whether all descriptions are still accurate.

Refactor

Points from nnFabrik discussion:

  • address the naming conventions as discussed in #17
  • keep the old main.py as a legacy module (for a limited time) so that transition to new column names would be smoother.

dataset_fn arguments

  • returns a dictionary of dataloaders where the elements returned by each dataloader are named tuples.
  • keyword args:
    • seed

model_fn arguments

  • keyword args:
    • dataloaders dictionary
    • seed

trainer_fn arguments

  • keyword args:
    • model
    • dataloaders dictionary
    • seed
  • do not unpack dataloaders dictionary when passing it as an argument to trainer_fn
  • in the training_step function the named tuple structure should be taken care of

Naming Conventions for the Dataloader

  • "train"
  • "validation"
  • "test"
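Under this convention, the dictionary returned by a dataset function would be keyed as in this toy sketch (plain lists stand in for real dataloaders):

```python
# Illustrative dict of dataloaders keyed by the agreed names;
# each value would normally be a real dataloader object.
dataloaders = {
    "train": [([0.0], [0.0])],
    "validation": [([1.0], [1.0])],
    "test": [([2.0], [2.0])],
}
```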

Introduce input arguments to make choice of losses (e.g., averaged vs not) configurable by the user

# return np.sqrt(m / k) * criterion(model(inputs, data_key), targets).sum() + model.regularizer(data_key)
return criterion(model(inputs, data_key), targets) + model.regularizer(data_key)
##### This is where everything happens ################################################################################
model.train()
criterion = getattr(measures, loss_function)(per_neuron=False, avg=True)

Discourage saving the model's state dict for pretrained models.

By design, TrainedModel.Storage stores the weights of the complete model. For pretrained models, this can easily lead to an explosion of storage space, which can just as easily be avoided.
The trainers of nnfabrik should encourage returning not the full state dict but only the weights that were learned during training. The nnfabrik builder can then be called with load_model(strict=False).

Extend README/documentation by datatypes

From reading the readme, it isn't obvious what the structure of the parameters passed to the model and trainer functions is.
This should be stated explicitly to ensure code is compatible with nnfabrik standards from the beginning.
From the examples, it seems to simply be a torch.utils.data.DataLoader object.

Fix the bug in DataInfoBase table

The following lines throw an error:

dataset_fn, dataset_config = (self.dataset_table & key).fn_config
data_info = dataset_fn(**dataset_config, return_data_info=True)

Because fn_config returns dataset_fn as a string:

nnfabrik/nnfabrik/main.py

Lines 233 to 236 in 744ae10

def fn_config(self):
    dataset_fn, dataset_config = self.fetch1("dataset_fn", "dataset_config")
    dataset_config = cleanup_numpy_scalar(dataset_config)
    return dataset_fn, dataset_config

Not compatible with Python3.10

With Python 3.10, importing nnfabrik raises an error:

File ...lib/python3.10/site-packages/nnfabrik/utility/dj_helpers.py:12, in <module>
     10 from datetime import date, datetime
     11 from datajoint.utils import to_camel_case
---> 12 from collections import OrderedDict, Iterable, Mapping

ImportError: cannot import name 'Iterable' from 'collections' (...lib/python3.10/collections/__init__.py)

It seems the Iterable and Mapping ABCs were removed from collections in Python 3.10 (they live in collections.abc, and importing them from collections had been deprecated since Python 3.3), so the package needs a slight update.
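A minimal fix would be to import the ABCs from collections.abc, which works on all supported Python 3 versions:

```python
# OrderedDict still lives in `collections`; the abstract base classes
# Iterable and Mapping must come from `collections.abc` on Python 3.10+.
from collections import OrderedDict
from collections.abc import Iterable, Mapping
```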

Centralized caching of checkpoints during training

Interruptions to the training process are a common problem for me and for everyone who runs experiments that take a little longer (e.g. >3 hours in my case). There are plenty of possible causes for such an interruption: unexpected shutdown of the server, someone killing your container, an error in the final evaluation, etc. This is especially annoying and wasteful towards the end of training.

My current solution is to locally save the current best checkpoint (consisting of model state, optimizer state, epoch count, and current accuracy) with the experiment hash as the filename. When I (re-)start a training run, I check whether any such checkpoint exists and load it if it does.
The main problem with this approach is that it happens locally: if the training restarts on another server, the checkpoint will not be found.

An ideal solution to this problem would be to save checkpoints in another table and update them as training progresses. However, I am not sure if this is trivial since the training is executed within a make function.

Further features I would like to include in this:

  • Option to save not only the best checkpoint, but also the last N checkpoints or the best N. (This could be useful for stochastic weight averaging)
  • Models should be saved on Minio.
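A hedged sketch of the local-checkpoint approach described above (the function names and the pickle-based format are illustrative; a real implementation would likely store torch state dicts and write to shared storage such as Minio):

```python
import hashlib
import os
import pickle

def checkpoint_path(config, directory="checkpoints"):
    # Hash the experiment config to get a stable filename (illustrative).
    h = hashlib.sha1(repr(sorted(config.items())).encode()).hexdigest()[:12]
    return os.path.join(directory, f"{h}.ckpt")

def save_checkpoint(path, state):
    # `state` would hold model state, optimizer state, epoch, accuracy, ...
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(state, f)

def load_checkpoint(path):
    # Returns None when no checkpoint exists, so training starts fresh.
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)
```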

Syntax Error with Python 3.7

The newly merged pull request (#33) introduces some refactoring of error handling, e.g.:

except NameError, TypeError as e:
            warnings.warn(str(e) + '\nTable entry rejected')
            return

Unfortunately, this is not valid syntax in Python 3.7 (the parenthesis-free multi-exception except clause is Python 2 syntax). The error can be avoided by using parentheses:

except (NameError, TypeError) as e:
            warnings.warn(str(e) + '\nTable entry rejected')
            return

Consistency in naming convention

Inconsistency across tables: the Model() table has configurator, config_hash, and config_obj, which could be more descriptive, as in the other tables.
Inconsistency within one table: the Trainer() table has both trainer and training; I suggest using trainer only.

Make Hypersearch module deal with more than one seed

Currently, the Bayesian and Random classes of the hypersearch module have no way to deal with seeds. If there are multiple seeds in the Seed table, the results for a given parameter setting should ideally be averaged. A keyword argument to select a fixed seed could also be an option.

Minimal example of SinzLab use-case

Let us keep nnfabrik as generic as possible but also encourage unity of approach among neuroscience labs.

What needs to be done?

  • new example of dataset, model, and trainer functions
  • a new notebook of showing how to use the functions
  • change of readme to point to this use-case which we would encourage computational neuroscience labs to follow

Error when more than two variables are returned by the dataloader

The following line makes an assumption about how many variables are returned by the dataloader and results in an error if behavioral variables are returned as a third variable.

in_name, out_name = next(iter(list(dataloaders.values())[0]))._fields

If we only care about the input and output shapes, we can do:

in_name, out_name = next(iter(list(dataloaders.values())[0]))._fields[:2]

Note that this assumes the first two variables are the inputs and outputs, respectively (which I think is reasonable).
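The slicing fix can be illustrated with a small self-contained example (the Batch named tuple is hypothetical, standing in for whatever the dataloader yields):

```python
from collections import namedtuple

# A three-field batch, as produced when behavioral variables are included.
Batch = namedtuple("Batch", ["inputs", "targets", "behavior"])
b = Batch(inputs=[1.0], targets=[2.0], behavior=[0.5])

# Unpacking all of b._fields into two names would raise a ValueError;
# slicing keeps only the first two field names.
in_name, out_name = b._fields[:2]
```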

recommend mysql settings

Hi,

We are in the process of setting up a DataJoint server (using https://github.com/datajoint/mysql-docker) for the Mackelab and are trying out nnfabrik. If I understood correctly, by default the Fabrikant table is shared by all users on the nnfabrik_core schema, whereas the other tables can be on any other schema accessible by the user.

What MySQL permissions do you generally give users on the nnfabrik_core schema? It seems a bit risky to give users full access, so I'm guessing there might be a subset that still allows users to execute e.g. the example notebooks but does not allow them to delete information inserted by other users into the Fabrikant table (and thus the tables depending on that information).

Thank you very much!

Documentation of tables and functions + Best Practice

The user does not, as of now, have a good idea of what nnfabrik expects; the specifications are not listed explicitly. So the functions, specifically in builder.py, have to be extensively documented, as do the tables in main.py.
The logic of the inputs and returns of models/datasets/trainers should be documented extensively.

Supply the user with best practices. This should be part of the documentation.
Part of the best practice:

  • All models should accept inputs and data_key, as well as **kwargs, so that every model works with every dataset. But the model has to warn the user if kwargs are ignored (otherwise it's prone to error).

Cleaning Up before going public

  • Make utils.metrics.py more general and give it better names; rename the module to evaluation.
  • Add a callback and a key in TrainedModel.make, to be passed on to the trainer. This will result in an nnfabrik trainer requiring five default arguments. Add to the documentation.

Throw error if current user is not found

Right now, the get_current_user method in the Fabrikant class returns None if it doesn't find an entry in the Fabrikant table corresponding to the current DataJoint username. This can lead to unhelpful error messages in certain situations. For example, trying to insert a dataset into the Dataset table without a matching entry present in the Fabrikant table raises this error:

MissingAttributeError: Field 'dataset_fabrikant' doesn't have a default value

I propose raising a more descriptive error in the get_current_user method when no matching entry is found.
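The proposed behavior could look roughly like this sketch (the dict-based table stands in for the real Fabrikant table; this is not nnfabrik's actual implementation):

```python
def get_current_user(fabrikant_table, dj_username):
    # Hypothetical lookup: `fabrikant_table` is any mapping from
    # DataJoint usernames to Fabrikant entries.
    entry = fabrikant_table.get(dj_username)
    if entry is None:
        raise ValueError(
            f"No Fabrikant entry found for DataJoint user '{dj_username}'. "
            "Insert one into the Fabrikant table before adding datasets."
        )
    return entry
```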

Endpoints and keys in datajoint config are allowed to be undefined

The DataJoint config values in main are allowed to be unset (namely MINIO_ENDPOINT, MINIO_ACCESS_KEY, and MINIO_SECRET_KEY).

In my case this led to a runtime error thrown by the Bayesian hypersearch because storing wasn't possible.

Is there an obvious use case where you don't need the storage defined?
If there isn't, I would suggest replacing the function call with direct access to the os.environ object, allowing an exception to be raised. If the current behavior is desired, I'd add a try/except around that, with a user warning in the except.
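The fail-fast variant suggested above might be sketched as follows (minio_config is a hypothetical helper; only the environment variable names come from the issue):

```python
import os

def minio_config():
    # os.environ[...] raises KeyError when a variable is unset,
    # failing fast instead of silently storing None in the config.
    return {
        "endpoint": os.environ["MINIO_ENDPOINT"],
        "access_key": os.environ["MINIO_ACCESS_KEY"],
        "secret_key": os.environ["MINIO_SECRET_KEY"],
    }
```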

Keyword argument processing and parameter optimization in early_stopping_trainer

Due to a rush to make nnfabrik available for the neural prediction challenge, not all change requests in the last two PRs to nnfabrik have been properly implemented.

These involve:

  • #16 (comment)
    use flexible keyword arguments when using the dataloader, which might return more entries than inputs and targets (eye movements, etc.)

  • #16 (comment)
    Same logic also for the objective and early_stop_functions

  • Consistent Google Style Docstrings and expansion of documentation

  • #18 (comment)
    Modify the transferLearningCore so that only trainable parameters are returned for model.parameters()
