cl-random-forest

Cl-random-forest is an implementation of Random Forest for multiclass classification and univariate regression, written in Common Lisp. It also includes an implementation of Global Refinement of Random Forest (Ren, Cao, Wei and Sun. “Global Refinement of Random Forest” CVPR2015). This refinement makes the model faster and more accurate than a standard Random Forest.

Features and Limitations

  • More accurate than other major implementations such as scikit-learn (Python/Cython) and ranger (R/C++) on the benchmarks below, and faster on most of them (each cell shows accuracy and elapsed time)

    Dataset   scikit-learn        ranger              cl-random-forest
    MNIST     96.95%, 41.72sec    97.17%, 69.34sec    98.29%, 12.68sec
    letter    96.38%, 2.569sec    96.42%, 1.828sec    97.32%, 3.497sec
    covtype   94.89%, 263.7sec    83.95%, 139.0sec    96.01%, 103.9sec
    usps      93.47%, 3.583sec    93.57%, 11.70sec    94.96%, 0.686sec

  • Supports parallelization of training and prediction (tested on SBCL and CCL)
  • Includes the Global Pruning algorithm for Random Forest, which can make the model extremely compact
  • Multivariate regression is not implemented yet

Installation

In quicklisp’s local-projects directory,

git clone https://github.com/masatoi/cl-online-learning.git
git clone https://github.com/masatoi/cl-random-forest.git

In Lisp,

(ql:quickload :cl-random-forest)

When using Roswell,

ros install masatoi/cl-online-learning masatoi/cl-random-forest

Usage

Classification

Prepare training dataset

A dataset consists of a target vector and an input data matrix. For classification, the target vector should be a one-dimensional simple array of fixnums and the data matrix should be a 2-dimensional double-float array in which each row corresponds to one datum. Note that each target value is an integer class id starting from 0. For example, the following dataset is valid for 4-class classification with 2-dimensional input.

(defparameter *target*
  (make-array 11 :element-type 'fixnum
                 :initial-contents '(0 0 1 1 2 2 2 3 3 3 3)))

(defparameter *datamatrix*
  (make-array '(11 2)
              :element-type 'double-float
              :initial-contents '((-1.0d0 -2.0d0)
                                  (-2.0d0 -1.5d0)
                                  (1.0d0 -2.0d0)
                                  (3.0d0 -1.5d0)
                                  (-2.0d0 2.0d0)
                                  (-3.0d0 1.0d0)
                                  (-2.0d0 1.0d0)
                                  (3.0d0 2.0d0)
                                  (2.0d0 2.0d0)
                                  (1.0d0 2.0d0)
                                  (1.0d0 1.0d0))))

./docs/img/clrf-example-simple.png
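
Because the layout requirements above are easy to get wrong, a quick sanity check before training can help. The following is a minimal sketch in plain Common Lisp; CHECK-CLASSIFICATION-DATASET is a hypothetical helper written for this example and is not part of cl-random-forest.

```lisp
;; Hypothetical helper (not part of cl-random-forest): checks the layout
;; described above and returns the number of classes implied by the labels.
(defun check-classification-dataset (datamatrix target)
  "Signal an error if DATAMATRIX or TARGET does not have the documented layout.
Otherwise return (max label + 1), i.e. the implied number of classes."
  ;; The data matrix must be a 2-dimensional double-float array.
  (check-type datamatrix (simple-array double-float (* *)))
  ;; The target must be a 1-dimensional array of fixnums.
  (check-type target (simple-array fixnum (*)))
  ;; One target value per row of the data matrix.
  (assert (= (array-dimension datamatrix 0) (length target)) ()
          "Row count ~D does not match target length ~D."
          (array-dimension datamatrix 0) (length target))
  ;; Class ids start from 0, so negative labels are invalid.
  (assert (every (lambda (label) (>= label 0)) target) ()
          "Class labels must be non-negative integers.")
  (1+ (reduce #'max target :initial-value -1)))
```

Calling (check-classification-dataset *datamatrix* *target*) on the dataset above should return 4, which matches the number of classes passed to the training functions in the next section.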

Make Decision Tree

To construct a decision tree, use the MAKE-DTREE function. It takes the number of classes, the data matrix and the target vector, and returns a decision tree object. It also accepts three optional keyword arguments: the maximum depth of the tree (:max-depth), the minimum number of samples in a region that the tree may split (:min-region-samples), and the number of split trials (:n-trial).

(defparameter *n-class* 4)

(defparameter *dtree*
  (make-dtree *n-class* *datamatrix* *target*
              :max-depth 5 :min-region-samples 1 :n-trial 10))

Next, make a prediction from the constructed decision tree with the PREDICT-DTREE function. For example, to predict the class of the first datum in the data matrix, do as follows.

(predict-dtree *dtree* *datamatrix* 0)
;; => 0 (correct class id)

To make predictions for the entire dataset and calculate the accuracy, use the TEST-DTREE function.

(test-dtree *dtree* *datamatrix* *target*)
;; Accuracy: 100.0%, Correct: 11, Total: 11
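
The figures reported above (accuracy, number correct, total) can also be computed by hand from individual predictions. The sketch below is written in plain Common Lisp and assumes only that some function maps a row index to a predicted class id; ACCURACY-PERCENT is a hypothetical helper, not a library function, and TEST-DTREE itself may be implemented differently.

```lisp
;; Hypothetical helper (not part of cl-random-forest).
(defun accuracy-percent (predict-fn target)
  "PREDICT-FN maps a row index to a predicted class id; TARGET is the vector
of true class ids. Returns three values: accuracy in percent, the number of
correct predictions, and the total number of data."
  (let* ((total (length target))
         ;; Count the rows whose predicted class equals the true class.
         (correct (loop for i from 0 below total
                        count (= (funcall predict-fn i) (aref target i)))))
    (values (if (zerop total) 0.0 (* 100.0 (/ correct total)))
            correct
            total)))

;; With the library loaded, the tree built above could be plugged in like this:
;; (accuracy-percent (lambda (i) (predict-dtree *dtree* *datamatrix* i)) *target*)
```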

Make Random Forest

To construct a random forest, use the MAKE-FOREST function. In addition to the MAKE-DTREE arguments, it optionally accepts the number of decision trees (:n-tree) and the bagging ratio (:bagging-ratio), which controls how much of the training data is sampled to construct each decision tree.

(defparameter *forest*
  (make-forest *n-class* *datamatrix* *target*
               :n-tree 10 :bagging-ratio 1.0
               :max-depth 5 :min-region-samples 1 :n-trial 10))
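
To illustrate what a bagging ratio means, here is a minimal sketch of drawing a bootstrap sample of row indices for one tree. BAGGING-INDICES is a hypothetical helper for this example. Sampling with replacement and rounding the sample size down are assumptions of the sketch; the library's actual sampling procedure may differ.

```lisp
;; Hypothetical helper (not part of cl-random-forest).
;; Assumption: rows are drawn uniformly at random WITH replacement.
(defun bagging-indices (n-data bagging-ratio)
  "Return a list of (floor (* BAGGING-RATIO N-DATA)) row indices, each drawn
uniformly at random from 0 below N-DATA."
  (loop repeat (floor (* bagging-ratio n-data))
        collect (random n-data)))
```

Under these assumptions, a ratio of 1.0 gives each tree a sample as large as the training set (with some rows repeated and some left out), while a smaller ratio gives each tree fewer rows.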

Prediction and testing of a random forest are done in almost the same way as for decision trees, using the PREDICT-FOREST and TEST-FOREST functions respectively.

(predict-forest *forest* *datamatrix* 0)
;; => 0 (correct class id)

(test-forest *forest* *datamatrix* *target*)
;; Accuracy: 100.0%, Correct: 11, Total: 11

Global Refinement of Random Forest

Cl-random-forest can improve a pre-trained random forest by using global information across its decision trees. To do this, we build another dataset from the original dataset and the pre-trained random forest. When an original datum is fed into the random forest, it falls into exactly one region, corresponding to one leaf node, in each decision tree. Each datum of the new dataset records which leaf node the original datum reached in each decision tree. We then train a linear classifier (AROW) on this new dataset and the original target.
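
The leaf-position encoding described above can be sketched in plain Common Lisp. The idea is to number all leaves of all trees consecutively and to list, for one datum, the global indices of the leaves it reached. LEAF-FEATURE-INDICES is a hypothetical helper for illustration only; how the library stores the refine dataset internally is not specified here, so treat this index layout as an assumption.

```lisp
;; Hypothetical helper (not part of cl-random-forest).
(defun leaf-feature-indices (leaf-counts leaf-positions)
  "LEAF-COUNTS lists the number of leaves in each tree. LEAF-POSITIONS lists,
for one datum, the index of the leaf it reached in each tree. Returns the
global indices of the active features, one per tree."
  (loop with offset = 0
        for n-leaves in leaf-counts
        for position in leaf-positions
        ;; Each position must be a valid leaf index within its own tree.
        do (assert (< -1 position n-leaves))
        collect (+ offset position)
        ;; Leaves of the next tree are numbered after those of this tree.
        do (incf offset n-leaves)))

;; Three trees with 3, 4 and 2 leaves; the datum reaches leaf 1, 0 and 1.
(leaf-feature-indices '(3 4 2) '(1 0 1))
;; => (1 3 8)
```

In this sketch every datum has exactly one active feature per tree, so the new representation is sparse: its dimension is the total number of leaves, but only as many entries as there are trees are non-zero.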

;; Make refine learner
(defparameter *forest-learner* (make-refine-learner *forest*))

;; Make refine dataset
(defparameter *forest-refine-dataset* (make-refine-dataset *forest* *datamatrix*))

;; Train refine learner
(train-refine-learner *forest-learner* *forest-refine-dataset* *target*)

;; Test refine learner
(test-refine-learner  *forest-learner* *forest-refine-dataset* *target*)

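A single call to TRAIN-REFINE-LEARNER trains on the whole dataset at once, but it may need to be called several times before learning converges. The TRAIN-REFINE-LEARNER-PROCESS function repeats training until convergence. In the call below, *forest-refine-dev-dataset* and *dev-target* are not defined in the earlier snippets; judging by their names, they are a refine dataset and a target vector for a held-out development set, prepared in the same way as the training ones.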

(train-refine-learner-process *forest-learner* *forest-refine-dataset* *target*
                              *forest-refine-dev-dataset* *dev-target*)

Global Pruning of Random Forest

Pruning

Global pruning is a method for reducing the model size of a random forest using information from the global-refinement learner. A leaf node in a decision tree is no longer necessary when the norm of its corresponding elements in the weight vector of the global-refinement learner is small.

To prune a forest destructively, run the PRUNING! function after training the global-refinement learner.

;; Prune *forest*
(pruning! *forest* *forest-learner* 0.1)

The third argument is the pruning rate. In this case, 10% of the leaf nodes are deleted.
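
The selection rule (small weight norm, fixed pruning rate) can be illustrated with a short sketch in plain Common Lisp. It ranks leaves by the L2 norm of their weights and returns the indices of the lowest-ranked fraction. L2-NORM and PRUNING-CANDIDATES are hypothetical helpers, and representing each leaf as a list of per-class weights is an assumption; PRUNING! itself works on the forest structure and may choose and remove nodes differently.

```lisp
;; Hypothetical helpers (not part of cl-random-forest).
(defun l2-norm (weights)
  "Euclidean norm of a sequence of numbers."
  (sqrt (reduce #'+ weights :key (lambda (w) (* w w)) :initial-value 0)))

(defun pruning-candidates (leaf-weights rate)
  "LEAF-WEIGHTS is a list with one weight sequence per leaf (assumed layout:
one weight per class). Returns the indices of the (floor (* RATE n)) leaves
with the smallest norm, smallest first."
  (let* ((n (length leaf-weights))
         (n-prune (floor (* rate n)))
         ;; Pair each leaf's norm with its index, then sort by norm.
         (ranked (sort (loop for w in leaf-weights
                             for i from 0
                             collect (cons (l2-norm w) i))
                       #'< :key #'car)))
    (mapcar #'cdr (subseq ranked 0 n-prune))))

;; Four leaves, pruning rate 0.5: the two leaves with the smallest norms.
(pruning-candidates '((0.1 0.0) (2.0 1.0) (0.0 0.05) (1.0 1.0)) 0.5)
;; => (2 0)
```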

Re-learning

After pruning, the global-refinement learner must be re-trained, because the set of leaf nodes has changed.

;; Re-learning of refine-learner
(setf *forest-refine-dataset* (make-refine-dataset *forest* *datamatrix*))
(setf *forest-learner* (make-refine-learner *forest*))
(train-refine-learner *forest-learner* *forest-refine-dataset* *target*)
(test-refine-learner  *forest-learner* *forest-refine-dataset* *target*)

The following figure shows the accuracy on the test dataset and the number of leaf nodes when pruning and re-learning are repeated on the MNIST dataset. The performance hardly changes even when the number of leaf nodes decreases to about 1/10.

./docs/img/clrf-mnist-pruning.png

Parallelization

The following functions can be parallelized with lparallel.

  • MAKE-FOREST
  • MAKE-REGRESSION-FOREST
  • MAKE-REFINE-DATASET
  • TRAIN-REFINE-LEARNER

To enable/disable parallelization, set lparallel’s kernel object. For example, to enable parallelization with 4 threads,

;; Enable parallelization
(setf lparallel:*kernel* (lparallel:make-kernel 4))

;; Disable parallelization
(setf lparallel:*kernel* nil)

Regression

Prepare training dataset

In classification the target is a vector of integer values, whereas in regression it is a vector of continuous values (a double-float array in the example below).

(defparameter *n* 100)

(defparameter *datamatrix*
  (let ((arr (make-array (list *n* 1) :element-type 'double-float)))
    (loop for i from 0 below *n* do
      (setf (aref arr i 0) (random-uniform (- pi) pi)))
    arr))

(defparameter *target*
  (let ((arr (make-array *n* :element-type 'double-float)))
    (loop for i from 0 below *n* do
      (setf (aref arr i) (+ (sin (aref *datamatrix* i 0))
                            (random-normal :sd 0.1d0))))
    arr))

(defparameter *test*
  (let ((arr (make-array (list *n* 1) :element-type 'double-float)))
    (loop for i from 0 below *n*
          for x from (- pi) to pi by (/ (* 2 pi) *n*)
          do (setf (aref arr i 0) x))
    arr))

(defparameter *test-target*
  (let ((arr (make-array *n* :element-type 'double-float)))
    (loop for i from 0 below *n* do
      (setf (aref arr i) (sin (aref *test* i 0))))
    arr))
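
The snippets above call RANDOM-UNIFORM and RANDOM-NORMAL, which are not part of standard Common Lisp. If they are not available in your environment, the following stand-ins are a minimal sketch that matches the calls used above. These definitions are assumptions made for this example (Box-Muller is swapped in for the normal sampler) and are not necessarily the ones the author used.

```lisp
;; Hypothetical stand-ins; skip them if your environment already provides
;; functions with these names.
(defun random-uniform (start end)
  "Uniform random double-float in the half-open interval [START, END)."
  (let ((start (coerce start 'double-float))
        (end (coerce end 'double-float)))
    (+ start (random (- end start)))))

(defun random-normal (&key (mean 0d0) (sd 1d0))
  "Normally distributed double-float, generated with the Box-Muller transform."
  (let ((u1 (- 1d0 (random 1d0)))   ; lies in (0, 1], so (log u1) is finite
        (u2 (random 1d0)))
    (coerce (+ mean
               (* sd
                  (sqrt (* -2d0 (log u1)))
                  (cos (* 2d0 (coerce pi 'double-float) u2))))
            'double-float)))
```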

Make Regression Tree

;; Make regression tree
(defparameter *rtree*
  (make-rtree *datamatrix* *target* :max-depth 5 :min-region-samples 5 :n-trial 10))

;; Testing
(test-rtree *rtree* *test* *test-target*)
; RMSE: 0.09220732459820888d0

;; Make a prediction for first data point of test dataset
(predict-rtree *rtree* *test* 0)
; => -0.08374452528780077d0
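
For reference, RMSE (root mean squared error) is the square root of the mean of the squared differences between predictions and targets. The sketch below computes it in plain Common Lisp from two sequences of numbers; RMSE here is a hypothetical helper, and TEST-RTREE may compute its figure differently internally.

```lisp
;; Hypothetical helper (not part of cl-random-forest).
(defun rmse (predictions targets)
  "Root mean squared error between two non-empty sequences of equal length."
  (assert (= (length predictions) (length targets)))
  (assert (plusp (length targets)))
  (sqrt (/ (reduce #'+ (map 'list
                            (lambda (p y) (expt (- p y) 2))
                            predictions targets))
           (length targets))))

;; With the library loaded, per-row predictions could be collected like this:
;; (rmse (loop for i from 0 below *n* collect (predict-rtree *rtree* *test* i))
;;       *test-target*)
```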

Make Random Forest for Regression

;; Make regression tree forest
(defparameter *rforest*
  (make-regression-forest *datamatrix* *target*
                          :n-tree 100 :bagging-ratio 0.6
                          :max-depth 5 :min-region-samples 5 :n-trial 10))

;; Testing
(test-regression-forest *rforest* *test* *test-target*)
; RMSE: 0.05006872795207973d0

;; Make a prediction for first data point of test dataset
(predict-regression-forest *rforest* *test* 0)
; => -0.16540771296145781d0

./docs/img/clrf-regression.png

Author

Satoshi Imai

Licence

This software is released under the MIT License, see LICENSE.txt.
