go-ml-transpiler

go-ml-transpiler provides methods to export trained machine learning models as a Go library, for large-scale applications with real-time constraints.

Supported models

Algorithm                                            | Training language
-----------------------------------------------------|------------------
xgboost.XGBClassifier (Scikit-Learn API only)        | Python
xgboost.XGBRegressor (Scikit-Learn API only)         | Python
ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier      | Scala/Spark
ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor       | Scala/Spark
sklearn.ensemble.RandomForestClassifier              | Python
sklearn.ensemble.RandomForestRegressor               | Python

Installation

$ git clone git@github.com:znly/go-ml-transpiler.git
$ cd go-ml-transpiler
$ pip install --user -e .

Minimum requirements

- numpy

Testing

$ make test

Nota Bene:

To run the tests on your machine, you will need Go and these additional Python packages:

- xgboost
- scikit-learn

Transpile model

from go_ml_transpiler import Transpiler
import xgboost
import numpy

# You can either load a pickled trained model:
# import pickle
# with open("model.pickle", "rb") as f:
#     model = pickle.load(f)
#
# Or train a fresh model:

Z = 15
N = 300
X = numpy.random.rand(N, Z)
Y = numpy.random.rand(N) > 0.5

model = xgboost.XGBClassifier(
    max_depth=7,
    n_estimators=10,
    learning_rate=0.1,
    n_jobs=-1,
    verbose=True,
    missing=-1,
    objective="reg:logistic",
    booster="dart",
    seed=0,
    base_score=0.9,
)

model.fit(X, Y)

transpiler = Transpiler(model=model)
transpiled_model = transpiler.transpile(
    package_name="model",
    export_method=True,
    method_name="Predict",
    indent="    ")
transpiler.write("/tmp")
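
The model above is trained with missing=-1, and the generated Go code shown in the Examples section checks features[i] == -1 on each split. A caller of the generated code would therefore need to encode absent inputs as -1. The sketch below shows one way to do that, assuming absent values arrive as NaN. The helper name encodeFeatures and the constant missing are hypothetical and not part of go-ml-transpiler.

```go
package main

import (
	"fmt"
	"math"
)

// missing is the sentinel used at training time (missing=-1).
// Assumption: the same value must be used when calling the generated code.
const missing = -1.0

// encodeFeatures is a hypothetical helper. It returns a copy of raw in
// which every NaN entry is replaced by the missing-value sentinel.
func encodeFeatures(raw []float64) []float64 {
	out := make([]float64, len(raw))
	for i, v := range raw {
		if math.IsNaN(v) {
			out[i] = missing
		} else {
			out[i] = v
		}
	}
	return out
}

func main() {
	raw := []float64{0.42, math.NaN(), 0.07}
	fmt.Println(encodeFeatures(raw)) // [0.42 -1 0.07]
}
```

Returning a copy leaves the caller's slice untouched, which is convenient when the raw values are reused elsewhere.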

Leverage Spark

With XGBoost, you can also train your model on Spark in a few lines of code:

import org.apache.spark.ml.linalg.Vector
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

case class TrainingSample(
                   features: Vector,
                   label: Int
                 )

val SPARK_MODEL_DIR: String = ...

val dataset = Seq[TrainingSample](...).toDS()

val model = new XGBoostClassifier()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .fit(dataset)

model.save(SPARK_MODEL_DIR)

Then load and transpile your model as before:

from go_ml_transpiler import Transpiler
from go_ml_transpiler.tools.xgboost.spark_tools import load_spark_model

# Replace the $... placeholders with your own paths and file names.
model = load_spark_model(
    "$SPARK_MODEL_DIR/data/$MODEL_FILE_NAME",
    "$SPARK_MODEL_DIR/metadata/$METADATA_FILE_NAME")

transpiler = Transpiler(model=model)
transpiled_model = transpiler.transpile()
transpiler.write("$GOLANG_MODEL_DIR")

Examples

When transpiling XGBoost models, output will look like this for the booster files:

package model

func predict0(features []float64) float64 {
    if (features[0] < 0.5) || (features[0] == -1) {
        if (features[1] < 0.5) || (features[1] == -1) {
            if (features[7] < 0.941666901) || (features[7] == -1) {
                if (features[13] < 0.105796114) || (features[13] == -1) {
                    return -0.114391111
                } else {
                    (...)

The generated prediction API then sums the booster outputs:

package model

import (
    "math"
)

func Predict(features []float64) [2]float64 {
    sum := math.Log(9.0)
    sum += predict0(features)
    sum += predict1(features)
    sum += predict2(features)
    proba := 1.0 / (1.0 + math.Exp(sum))
    distribution := [2]float64{proba, 1.0 - proba}
    return distribution
}
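
The constant math.Log(9.0) in Predict is consistent with base_score=0.9 from the training example: the logit of 0.9 is ln(0.9 / 0.1) = ln 9. The sketch below reproduces the same structure in a self-contained program so the pieces can be run together. It assumes a single booster; predict0 is a hypothetical one-split tree with made-up leaf values, and logit is a helper introduced here for illustration only.

```go
package main

import (
	"fmt"
	"math"
)

// missing is the sentinel used at training time (missing=-1).
const missing = -1.0

// logit converts a base score (a probability) into an initial margin.
func logit(p float64) float64 {
	return math.Log(p / (1.0 - p))
}

// predict0 is a hypothetical one-split tree. The leaf values are made up.
// A missing feature follows the same branch as a value below the threshold.
func predict0(features []float64) float64 {
	if (features[0] < 0.5) || (features[0] == missing) {
		return -0.1
	}
	return 0.2
}

// Predict mirrors the shape of the generated API for one booster.
func Predict(features []float64) [2]float64 {
	sum := logit(0.9) // matches math.Log(9.0) up to rounding
	sum += predict0(features)
	proba := 1.0 / (1.0 + math.Exp(sum))
	return [2]float64{proba, 1.0 - proba}
}

func main() {
	fmt.Println(Predict([]float64{0.3}))
	fmt.Println(Predict([]float64{missing}))
}
```

Both calls in main follow the same branch of predict0, so they return the same distribution.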

Nota Bene:

Output is of type:

  • [k]float64 for k-classes classification
  • float64 for regression
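
For classification, a caller often wants a single predicted class rather than the full distribution. The sketch below picks the index with the highest probability. It assumes that index i of the returned array corresponds to class i, as the binary example above suggests; argmax is a hypothetical helper, not part of the generated code.

```go
package main

import "fmt"

// argmax returns the index of the largest value in probs, or -1 if probs
// is empty. Ties resolve to the lowest index.
func argmax(probs []float64) int {
	best := -1
	for i, p := range probs {
		if best == -1 || p > probs[best] {
			best = i
		}
	}
	return best
}

func main() {
	// Stand-in for the [2]float64 returned by a generated Predict function.
	distribution := [2]float64{0.18, 0.82}

	// Slicing the fixed-size array lets one helper serve any number of classes.
	fmt.Println(argmax(distribution[:]))
}
```

Taking a slice rather than a fixed-size array avoids writing one helper per value of k.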

Authors

See AUTHORS for the list of contributors.

References

[1] Gilbert Bernstein, Morgan Dixon, Amit Levy, "Faster Real-Time Classification Using Compilation", 2010.

License

The Apache License version 2.0 (Apache2) - see LICENSE for more details.

Copyright (c) 2018 Zenly @zenlyapp
