
ml_pipeline

Pipeline for ML training that records (hyper-)parameters and allows restoring a model using those (hyper-)parameters.

Example usage (requires TensorFlow 2.0)*:

export PYTHONPATH=$PYTHONPATH:{PATH_TO_ml_pipeline}
cd xor_example

In this example, we will be training an MLP (multilayer perceptron) on a toy XOR dataset.
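The toy dataset is small enough to write out in full. The sketch below shows what it plausibly contains, assuming it is the four input combinations of the XOR truth table. make_xor_dataset is a hypothetical stand-in, not the create_xor_dataset() helper from xor_dataset_utils.py, which may store or format the data differently. Because XOR is not linearly separable, a model with hidden layers (such as the MLP configured in setup_experiment()) is needed to fit it.

```python
# Hypothetical stand-in for create_xor_dataset(); the real helper lives in
# xor_dataset_utils.py and may store the data differently.
from itertools import product


def make_xor_dataset():
    """Return (inputs, labels) for all four XOR input combinations."""
    inputs = [[a, b] for a, b in product([0, 1], repeat=2)]
    # XOR is 1 when exactly one of the two inputs is 1
    labels = [[a ^ b] for a, b in inputs]
    return inputs, labels


if __name__ == "__main__":
    x, y = make_xor_dataset()
    for xi, yi in zip(x, y):
        print(xi, "->", yi)
```

Each input has two features and each label has one, which matches in_size = 2 and out_size = 1 in setup_experiment().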

To run our first experiment, follow these steps:

  1. Make sure that lines 64-77 of run_xor_experiment.py appear as follows:
if __name__ == "__main__":
    create_xor_dataset()

    # setup experiment for the first time
    XOR_experiment = setup_experiment()
    # restore experiment from a params file
    #XOR_experiment = Experiment()
    #XOR_experiment.load("last_params")

    XOR_experiment.set_exp_name("xor")

    XOR_experiment.train()

    XOR_experiment.save()
  2. Run python run_xor_experiment.py.

Let's dive into what is happening here.

  1. First, the setup_experiment() function creates an Experiment object with Model, Trainer, and Evaluator attributes, which have further sub-attributes and hyper-parameters. See run_xor_experiment.py for further details.
def setup_experiment():
    # setup modules & hyperparameters
    XOR_experiment = Experiment()
    if True:
        XOR_model = TF_MLP_Model()
        if True:
            XOR_model.in_size = 2
            XOR_model.hidden_sizes = [20, 20, 20, 20]
            XOR_model.out_size = 1

            XOR_model.activation = "relu"

        XOR_trainer = TF_Trainer()
        if True:
            ...

        XOR_evaluator = Evaluator()

        XOR_experiment.model = XOR_model
        XOR_experiment.trainer = XOR_trainer
        XOR_experiment.evaluator = XOR_evaluator

    return XOR_experiment
  2. Then, the Experiment.set_exp_name("xor") method stores the "xor" string for logging purposes.

  3. Next, the Experiment.train() method calls Trainer.train(Model), which trains the model using the Trainer's hyper-parameters.

  4. Lastly, the Experiment.save() method saves all of the Experiment's hyper-parameters into a params file. This file is stored under the xor_YYYY_MM_DD_HH_MM_SS folder, where the "xor" prefix was set with the Experiment.set_exp_name method.
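The params file (see last_params further down) is a nested JSON object in which every component records its module (param_path) and class (param_name) next to its hyper-parameters. The sketch below shows one way such a file could be produced. The Params base class, the Optimizer and Trainer classes, and save_params are hypothetical illustrations, not ml_pipeline's actual implementation.

```python
import json


class Params:
    """Hypothetical base class: every configurable object records its own
    module path and class name alongside its hyper-parameters."""

    def to_dict(self):
        d = {
            "param_path": type(self).__module__,
            "param_name": type(self).__name__,
        }
        for key, value in vars(self).items():
            # Nested Params objects (e.g. a trainer's optimizer) serialize recursively
            d[key] = value.to_dict() if isinstance(value, Params) else value
        return d


class Optimizer(Params):  # hypothetical, mirrors the optimizer block in last_params
    def __init__(self):
        self.learning_rate = 0.01
        self.epsilon = 0.1


class Trainer(Params):  # hypothetical, mirrors part of the trainer block
    def __init__(self):
        self.optimizer = Optimizer()
        self.n_epochs = 50
        self.batch_size = 4


def save_params(obj, path):
    """Write obj's nested hyper-parameters to path as indented JSON."""
    with open(path, "w") as f:
        json.dump(obj.to_dict(), f, indent=2)
```

Storing the module path and class name with each component is what makes the file self-describing: it holds enough information to rebuild the same objects later without rerunning setup_experiment().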

After the experiment has finished running, your working directory should look like this:

xor_example
+-- run_xor_experiment.py
+-- xor_dataset_utils.py
+-- last_params
+-- xor_YYYY_MM_DD_HH_MM_SS
|   +-- params
|   +-- train_log
|       +-- events.out.tfevents.*
|   +-- ckpts
|       +-- ckpt-49.index
|       +-- ckpt-49.data-00000-of-00001

The last_params file is a copy of the xor_YYYY_MM_DD_HH_MM_SS/params file, and is created for convenience.
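A minimal sketch of the folder naming and the last_params copy, assuming the timestamp is produced with datetime.strftime and the pattern %Y_%m_%d_%H_%M_%S. run_dir_name and save_run are hypothetical helpers, and params_json stands for the already-serialized params text.

```python
import os
import shutil
from datetime import datetime


def run_dir_name(exp_name, now=None):
    """Return '<exp_name>_YYYY_MM_DD_HH_MM_SS' for the given (or current) time."""
    now = now or datetime.now()
    return f"{exp_name}_{now.strftime('%Y_%m_%d_%H_%M_%S')}"


def save_run(exp_name, params_json, root="."):
    """Write params into a fresh timestamped folder, then copy it to last_params."""
    run_dir = os.path.join(root, run_dir_name(exp_name))
    os.makedirs(run_dir)  # raises FileExistsError if a folder with this timestamp exists
    params_path = os.path.join(run_dir, "params")
    with open(params_path, "w") as f:
        f.write(params_json)
    # last_params is only a copy, so editing it leaves the run's own record untouched
    shutil.copyfile(params_path, os.path.join(root, "last_params"))
    return run_dir
```

Because each run gets a fresh timestamped folder, editing last_params never alters the record of a finished run.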

Running tensorboard --logdir xor_YYYY_MM_DD_HH_MM_SS/train_log/ shows that the loss is decreasing slowly.

[Figure: exp1_loss, TensorBoard loss curve of the first experiment]

To resume training where we left off, but with different hyper-parameters, follow these steps:

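  1. First, edit last_params. Change trainer/load_checkpoint_dir to the first run's "xor_YYYY_MM_DD_HH_MM_SS/" folder, trainer/start_epoch to 50, and trainer/optimizer/epsilon to 1e-7. In the listing below, "old => new" marks each change; enter only the new value, since the file must remain valid JSON.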
last_params
{
  "param_path": "experiment",
  "param_name": "Experiment",
  "model": {
    ...
  },
  "trainer": {
    ...
    "optimizer": {
      "param_path": "trainers.tf_utils.optimizers",
      "param_name": "TF_Adam_Optimizer",
      "learning_rate": 0.01,
      "epsilon": 0.1 => 1e-7
    },
    "load_checkpoint_dir": null => "xor_YYYY_MM_DD_HH_MM_SS/",
    "start_epoch": 0 => 50,
    "n_epochs": 50,
    "batch_size": 4,
    "log_period": 1,
    "save_period": 50
  },
  "evaluator": {
    ...
  }
}
  2. Next, edit run_xor_experiment.py. Change lines 64-77 to:
if __name__ == "__main__":
    create_xor_dataset()

    # setup experiment for the first time
    #XOR_experiment = setup_experiment()
    # restore experiment from a params file
    XOR_experiment = Experiment()
    XOR_experiment.load("last_params")

    XOR_experiment.set_exp_name("xor")

    XOR_experiment.train()

    XOR_experiment.save()
  3. Now, running python run_xor_experiment.py runs an experiment with the new hyper-parameters in last_params. It also creates a new folder containing the params, logs, and checkpoints of the new experiment.
xor_example
+-- run_xor_experiment.py
+-- xor_dataset_utils.py
+-- last_params
+-- xor_YYYY_MM_DD_HH_MM_SS (new)
|   +-- params
|   +-- train_log
|       +-- events.out.tfevents.*
|   +-- ckpts
|       +-- ckpt-99.index
|       +-- ckpt-99.data-00000-of-00001
+-- xor_YYYY_MM_DD_HH_MM_SS (old)
|   +-- params
|   +-- train_log
|       +-- events.out.tfevents.*
|   +-- ckpts
|       +-- ckpt-49.index
|       +-- ckpt-49.data-00000-of-00001
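Restoring works in the opposite direction of saving: each param_path is imported as a module, param_name is looked up in it and instantiated, and the remaining keys are set as attributes. The sketch below follows that idea under stated assumptions: build_from_params and load_params are hypothetical names, every class is assumed to have a no-argument constructor, and Experiment.load may differ in detail.

```python
import importlib
import json


def build_from_params(d):
    """Recursively rebuild an object from a dict with param_path/param_name keys."""
    module = importlib.import_module(d["param_path"])
    obj = getattr(module, d["param_name"])()  # assumes a no-argument constructor
    for key, value in d.items():
        if key in ("param_path", "param_name"):
            continue
        # Nested dicts that carry their own class info become nested objects
        if isinstance(value, dict) and "param_name" in value:
            value = build_from_params(value)
        setattr(obj, key, value)
    return obj


def load_params(path):
    """Read a params file (e.g. last_params) and rebuild the object tree."""
    with open(path) as f:
        return build_from_params(json.load(f))
```

Since the file names the modules to import, load only params files you trust.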

Running TensorBoard on the new folder (tensorboard --logdir xor_YYYY_MM_DD_HH_MM_SS/train_log/, using the new run's timestamp) shows that the model has now converged with the new hyper-parameters.

[Figure: exp2_loss, TensorBoard loss curve of the resumed experiment]
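The checkpoint names in the two folder trees (ckpt-49, then ckpt-99) are consistent with a simple rule: run epochs start_epoch through start_epoch + n_epochs - 1, and save after every save_period-th epoch, naming the checkpoint by its 0-based epoch index. The sketch below encodes that rule. It is inferred from the listings above, not taken from TF_Trainer's source, and checkpoint_epochs is a hypothetical helper.

```python
def checkpoint_epochs(start_epoch, n_epochs, save_period):
    """Epoch indices at which a checkpoint would be written (hypothetical rule:
    save after every save_period-th epoch, named after the 0-based epoch index)."""
    epochs = range(start_epoch, start_epoch + n_epochs)
    return [e for e in epochs if (e + 1) % save_period == 0]


# First run: start_epoch=0, n_epochs=50, save_period=50
print(checkpoint_epochs(0, 50, 50))   # [49] -> ckpt-49
# Resumed run: start_epoch=50, same n_epochs and save_period
print(checkpoint_epochs(50, 50, 50))  # [99] -> ckpt-99
```

Under this rule, setting start_epoch to 50 keeps the resumed run's epoch indices continuing from where the first run stopped, so its checkpoint is ckpt-99 rather than a second ckpt-49.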

*Install with pip install tensorflow==2.0.0-beta1 or pip install tensorflow-gpu==2.0.0-beta1.

Contributors: alvinzz