
A TensorFlow batteries-included kit to write TensorFlow networks from scratch or use existing ones.

License: MIT License

Topics: tensorflow, tfrecords, estimator, deep-learning, deep-learning-library, deep-learning-tutorial, loss, vgg16

starttf's Introduction

starttf - Simplified Deep Learning for TensorFlow (License: MIT)

This repo aims to contain everything required to quickly develop a deep neural network with TensorFlow. The idea is that if you write a compatible SimpleSequence for data loading and build your networks on StartTFModel, you automatically follow best practices and get super fast training speeds.

Install

First, properly install tensorflow or tensorflow-gpu by carefully following the official instructions.

Then, simply install starttf via pip:

pip install starttf

Datasets

Extensions of SimpleSequence from opendatalake.simple_sequence.SimpleSequence are supported. They work like keras.utils.Sequence, but additionally provide an augmentation and a preprocessing function.

For details, check out the readme of opendatalake.
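
A minimal sketch of such a sequence (the method names num_samples and get_sample are assumptions for illustration; the actual interface is documented in the opendatalake readme):

import numpy as np

from opendatalake.simple_sequence import SimpleSequence


class MySimpleSequence(SimpleSequence):
    def __init__(self, hyperparams):
        super(MySimpleSequence, self).__init__(hyperparams)
        # Load or index your raw data here.

    def num_samples(self):
        # Total number of samples in the dataset.
        return 60000

    def get_sample(self, idx):
        # Return one (feature, label) pair; batching, augmentation and
        # preprocessing are handled by the SimpleSequence machinery.
        feature = {"image": np.zeros((28, 28, 1), dtype=np.uint8)}
        label = {"probs": np.zeros((10,), dtype=np.float32)}
        return feature, label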

Models

Every model returns a dictionary containing output tensors and a dictionary containing debug tensors. The package provides:

  1. Model Base Classes
  2. Common Encoders
  3. Untrained Backbones

Simple to use TensorFlow

Simple Training (No Boilerplate)

There are pre-implemented models which can be glued together and trained with just a few lines. However, before training you have to create tf records as shown in the section Simple TF Record Creation below. The following is a full main file.

# Import helpers
from starttf.estimators.tf_estimator import easy_train_and_evaluate
from starttf.utils.hyperparams import load_params

# Import a/your model (here one for mnist)
from mymodel import MyStartTFModel

# Import your loss (here an example)
from myloss import create_loss

# Load params (here for mnist)
hyperparams = load_params("hyperparams/experiment1.json")

# Train model
easy_train_and_evaluate(hyperparams, MyStartTFModel, create_loss, continue_training=False)

Quick Model Definition

Simply subclass StartTFModel as shown below. The example model is a plain feed-forward model.

The model returns a dictionary containing all layers that should be accessible from outside and a dictionary containing debug values that should be available for the loss or for plotting in tensorboard.

import tensorflow as tf

from starttf.models.model import StartTFModel
from starttf.models.encoders import Encoder

Conv2D = tf.keras.layers.Conv2D


class ExampleModel(StartTFModel):
    def __init__(self, hyperparams):
        super(ExampleModel, self).__init__(hyperparams)
        num_classes = hyperparams.problem.number_of_categories

        # Create the vgg encoder
        self.encoder = Encoder(hyperparams)

        # Use the generated encoder and add a small classification head.
        self.conv6 = Conv2D(filters=1024, kernel_size=(1, 1), padding="same", activation="relu")
        self.conv7 = Conv2D(filters=1024, kernel_size=(1, 1), padding="same", activation="relu")
        self.conv8 = Conv2D(filters=num_classes, kernel_size=(1, 1), padding="same", activation=None, name="probs")

    def call(self, input_tensor, training=False):
        """
        Run the model.
        """
        encoder, debug = self.encoder(input_tensor, training)
        result = self.conv6(encoder["features"])
        result = self.conv7(result)
        logits = self.conv8(result)
        probs = tf.nn.softmax(logits)
        return {"logits": logits, "probs": probs}, debug

Quick Loss Definition

def create_loss(model, labels, mode, hyper_params):
    metrics = {}
    losses = {}

    # Add loss
    labels = tf.reshape(labels["probs"], [-1, hyper_params.problem.number_of_categories])
    ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=model["logits"], labels=labels)
    loss_op = tf.reduce_mean(ce)

    # Add losses to dict. "loss" is the primary loss that is optimized.
    losses["loss"] = loss_op
    # Compare class indices rather than one-hot vectors with probabilities.
    predictions = tf.reshape(model["probs"], [-1, hyper_params.problem.number_of_categories])
    metrics['accuracy'] = tf.metrics.accuracy(labels=tf.argmax(labels, axis=-1),
                                              predictions=tf.argmax(predictions, axis=-1),
                                              name='acc_op')

    return losses, metrics

Simple TF Record Creation

Fast training speeds can be achieved by using tf records. However, tf records are usually a hassle to use; the write_data method makes it simple.

from starttf.utils.hyperparams import load_params
from starttf.data.autorecords import write_data, PHASE_TRAIN, PHASE_VALIDATION

from my_data import MySimpleSequence

# Load the hyper parameters.
hyperparams = load_params("hyperparams/experiment1.json")

# Get a generator and its parameters
training_data = MySimpleSequence(hyperparams)
validation_data = MySimpleSequence(hyperparams)

# Write the data
write_data(hyperparams, PHASE_TRAIN, training_data, 4)
write_data(hyperparams, PHASE_VALIDATION, validation_data, 2)

Tensorboard Integration

Tensorboard integration is simple.

Every loss in the losses dict is automatically added to tensorboard. If you also want debug images, you can add a tf.summary.image() call in your create_loss method.
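
For example (a sketch; "img" is a hypothetical image-shaped tensor that your model would have to expose in its output dictionary):

import tensorflow as tf


def create_loss(model, labels, mode, hyper_params):
    losses = {}
    metrics = {}

    # ... add losses and metrics as shown above ...

    # Log up to 3 images per batch to tensorboard; "img" is a
    # hypothetical tensor your model would expose in its outputs.
    tf.summary.image("debug/img", model["img"], max_outputs=3)

    return losses, metrics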

TF Estimator + Cluster Support

If you use the easy_train_and_evaluate method, a correctly configured TF Estimator is created. The estimator is then trained in a way that supports cluster training if you have a cluster.
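
TF Estimators read the cluster layout from the standard TF_CONFIG environment variable; a sketch for one worker (host names and ports are placeholders):

import json
import os

# Every machine in the cluster sets the same "cluster" dict and its own
# "task" entry before starting training.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222"],
    },
    "task": {"type": "worker", "index": 0},
})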

starttf's People

Contributors

amitlevy21, penguinmenac3, strandtasche


starttf's Issues

SegNet Model + Example

Is your feature request related to a problem? Please describe.
SegNet is one of the most important segmentation networks and nicely illustrates concepts like upconvolution. Having a clean implementation would benefit lots of other potentially interesting network implementations.

Describe the solution you'd like

  1. Implement a SegNet model (create_model function) in starttf.models, using the vgg16_encoder implementation as an encoder (see the upconvolution sketch after this list).
  2. Write a full SegNet example (like the mnist example) where the network is trained on the cityscapes dataset.
    2.1. Write a loss.py
    2.2. Write a prepare_training.py
    2.3. Write a hyper_params.json
    2.4. Write a train.py
  3. If required add cityscapes to the opendatalake.
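
A minimal sketch of such an upconvolution (transposed convolution) stage, in the style of the model definition above (filter sizes are illustrative, not the final SegNet design):

import tensorflow as tf

Conv2DTranspose = tf.keras.layers.Conv2DTranspose

# One decoder stage: double the spatial resolution of the encoder
# features with a stride-2 transposed convolution ("upconvolution").
upconv1 = Conv2DTranspose(filters=256, kernel_size=(3, 3), strides=(2, 2),
                          padding="same", activation="relu")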

Describe alternatives you've considered
None.

Additional context
This might be a good issue for someone who wants to start contributing to starttf.

Add Estimator support

Add support for default tensorflow estimators.

Also consider changing the train method to a ScientificEstimator, which would provide an interface quite similar to the tensorflow estimator.

Example: Imagenet (Baselines)

Is your feature request related to a problem? Please describe.
Write a tutorial on using inception_v3 or vgg to classify ImageNet images. This could help beginners with image classification.

Describe the solution you'd like
An example like mnist that shows how to train vgg16 or inception_v3 on ImageNet or a similar but smaller task. Ideally this would include transfer learning: starting with a pretrained vgg or inception network and learning classification on a smaller dataset. For loading a small dataset that is stored in one folder per class, the opendatalake.classification.named_folders package is ideal.
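
A transfer-learning sketch using the stock Keras VGG16 (this uses tf.keras.applications directly, independent of starttf; the class count is illustrative):

import tensorflow as tf

# Start from an ImageNet-pretrained VGG16 encoder and learn only a new
# classification head on the smaller dataset.
encoder = tf.keras.applications.VGG16(weights="imagenet", include_top=False,
                                      pooling="avg")
encoder.trainable = False  # freeze the pretrained weights

model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Dense(10, activation="softmax"),  # 10 target classes
])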

Describe alternatives you've considered
None

Additional context
The example should have a similar structure to the mnist example.
This is an ideal issue for someone who wants to start contributing.

An example contains:

  • hyper_params.json
  • loss.py
  • prepare_training.py
  • train.py

Transform to lib

Consider transforming the project into a library (one for Keras, one for TF).

This would greatly simplify including it in other projects.

Multi loss support

Make switching between multiple losses and comparing results easy.

Suggestion: change the create_loss pattern as sketched below.
Possibly even add multi-optimizer support.

def create_loss(hyper_params, train_model, validation_model, train_labels, validation_labels=None):
    reports = []
    loss_ops = {}
    validation_loss_ops = {}
    train_labels = tf.reshape(train_labels, [-1, 1000])
    loss_ops["cross_entropy"] = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=train_model["logits"], labels=train_labels))
    if hyper_params.train.loss not in loss_ops.keys():
        raise RuntimeError("Unsupported Loss: " + hyper_params.train.loss)
    if hyper_params.train.optimizer == "RMSProp":
        train_op = tf.train.RMSPropOptimizer(learning_rate=hyper_params.train.learning_rate,
                                             decay=hyper_params.train.decay).minimize(loss_ops[hyper_params.train.loss])
    else:
        raise RuntimeError("Unsupported Optimizer: " + hyper_params.train.optimizer)
    for loss_name in loss_ops.keys():
        tf.summary.scalar('train/' + loss_name, loss_ops[loss_name])
        reports.append(loss_ops[loss_name])

    # Create a validation loss if possible.
    if validation_labels is not None:
        validation_labels = tf.reshape(validation_labels, [-1, 1000])
        validation_loss_ops["cross_entropy"] = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=validation_model["logits"], labels=validation_labels))
        for loss_name in validation_loss_ops.keys():
            tf.summary.scalar('validation/' + loss_name, validation_loss_ops[loss_name])
            reports.append(validation_loss_ops[loss_name])

    return train_op, reports

In the hyperparameters json, add fields for train.loss and train.optimizer.
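
The relevant part of such a hyper_params.json could then look like this (values are illustrative):

{
    "train": {
        "loss": "cross_entropy",
        "optimizer": "RMSProp",
        "learning_rate": 0.001,
        "decay": 0.9
    }
}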
