

simple_train

Simple, customizable code to train and evaluate a one-hidden-layer neural network in TensorFlow 2.0. CIFAR-10 is currently the only supported dataset. Two models are available for training:

  • 'One-hidden': one hidden layer neural network with ReLU activations. Implements f(x) = V*ReLU(W*x + b1) + b2, where the bias terms b1 and b2 are optional.
  • 'Linear': same as 'One-hidden', but without ReLU activations.
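The two models above can be illustrated with a small forward pass. This is a minimal sketch in plain Python rather than the repository's TensorFlow code; the helper names (`relu`, `affine`, `forward`) and the toy weights are assumptions made for illustration only.

```python
def relu(values):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]


def affine(matrix, vector, bias=None):
    """Compute matrix * vector, adding the bias only if one is given."""
    out = [sum(w * x for w, x in zip(row, vector)) for row in matrix]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


def forward(x, W, V, b1=None, b2=None, activation=True):
    """f(x) = V * ReLU(W * x + b1) + b2; activation=False gives the 'Linear' model."""
    hidden = affine(W, x, b1)
    if activation:
        hidden = relu(hidden)
    return affine(V, hidden, b2)


# Toy weights (hypothetical): 2 inputs, 2 hidden units, 1 output.
W = [[1.0, -1.0], [0.5, 0.5]]
V = [[1.0, 1.0]]
x = [1.0, 2.0]

one_hidden = forward(x, W, V, b1=[0.0, 0.0], b2=[0.1])  # with biases and ReLU
linear = forward(x, W, V, activation=False)             # no biases, no ReLU
```

With these toy values the hidden pre-activation is [-1.0, 1.5]. ReLU zeroes the first entry in the 'One-hidden' case, while the 'Linear' model passes both entries through unchanged.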

Usage

To see all options available for customizing the model architecture and training, take a look at args.py or run

python train_model.py -h

This repository includes two scripts: train_model.py trains a one-hidden-layer model and run_model.py runs it. You can train a model with

python train_model.py --checkpoint-path $OUTPUT_PATH

and evaluate it on the test dataset using

python run_model.py --checkpoint-path $CHECKPOINT_PATH

Make sure to replace $OUTPUT_PATH with the path where you want the trained model to be saved, and replace $CHECKPOINT_PATH with the path to the checkpoint to be loaded (in this example they can be the same). Here is an example of how to train a one-hidden-layer neural network with 128 hidden units using the Adam optimizer:

python train_model.py --checkpoint-path $OUTPUT_PATH \
--architecture onehidden --hidden-units 128 --use-bias \
--optimizer adam --lr 0.001 --epochs 200
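For readers curious how such flags are typically parsed, here is a minimal argparse sketch. It is not the contents of args.py. The flag names are taken from the command above, but the defaults, the help text, and the `linear` choice value are assumptions.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags used in the example command."""
    parser = argparse.ArgumentParser(
        description="Train a one-hidden-layer network (illustrative sketch)."
    )
    parser.add_argument("--checkpoint-path", required=True,
                        help="where the trained model is saved")
    # 'onehidden' appears in the example command; 'linear' is an assumed value.
    parser.add_argument("--architecture", choices=["onehidden", "linear"],
                        default="onehidden")
    parser.add_argument("--hidden-units", type=int, default=128)
    parser.add_argument("--use-bias", action="store_true",
                        help="include the optional bias terms b1 and b2")
    parser.add_argument("--optimizer", default="adam")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--epochs", type=int, default=200)
    return parser


# Parse the same arguments as the example command, with a placeholder path.
args = build_parser().parse_args([
    "--checkpoint-path", "out",
    "--architecture", "onehidden", "--hidden-units", "128", "--use-bias",
    "--optimizer", "adam", "--lr", "0.001", "--epochs", "200",
])
```

argparse turns dashes in flag names into underscores, so the values are available as `args.checkpoint_path`, `args.hidden_units`, and so on.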

You can easily run the network from the saved model:

python run_model.py --checkpoint-path $CHECKPOINT_PATH

The network architecture and other parameters are restored from the saved config file.
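One plausible way to store and restore such a config is with pickle, which is listed in the requirements. The sketch below is an assumption about the general approach, not the repository's actual code. The file name `config.pkl`, the helper names, and the config keys are hypothetical.

```python
import os
import pickle
import tempfile


def save_config(config, checkpoint_dir, filename="config.pkl"):
    """Write the config dict next to the checkpoint (hypothetical file name)."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, filename)
    with open(path, "wb") as f:
        pickle.dump(config, f)
    return path


def load_config(checkpoint_dir, filename="config.pkl"):
    """Read the config dict back so the model can be rebuilt before loading weights."""
    with open(os.path.join(checkpoint_dir, filename), "rb") as f:
        return pickle.load(f)


# Hypothetical config keys, chosen to mirror the command-line flags.
config = {"architecture": "onehidden", "hidden_units": 128, "use_bias": True}

with tempfile.TemporaryDirectory() as checkpoint_dir:
    save_config(config, checkpoint_dir)
    restored = load_config(checkpoint_dir)
```

Saving the config alongside the weights is what would let run_model.py be called with only --checkpoint-path, since the architecture flags do not need to be repeated.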

Requirements:

  • tensorflow >= 2.0
  • pickle
