GrowNet

Original PyTorch implementation of "Gradient Boosting Neural Networks: GrowNet"

Paper at: https://arxiv.org/pdf/2002.07971.pdf

Getting Started

In this paper, we combine the power of gradient boosting with the flexibility and versatility of neural networks and introduce a new modelling paradigm called GrowNet that can build up a DNN layer by layer. Instead of decision trees, we use shallow neural networks as our weak learners in a general gradient boosting framework that can be applied to a wide variety of tasks spanning classification, regression and ranking. We introduce further innovations like adding second order statistics to the training process, and also including a global corrective step that has been shown, both in theory and in empirical evaluation, to provide performance lift and precise fine-tuning to the specific task at hand.
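To make the boosting idea concrete, here is a minimal sketch in pure Python of stage-wise boosting with small neural networks as weak learners. It is not the repository's PyTorch code: it uses plain squared-error boosting on a toy 1-D regression problem and leaves out the second-order statistics and the corrective step mentioned above. The names `TinyNet`, `boost` and `boost_rate`, and all hyperparameter values, are illustrative assumptions.

```python
import math
import random


class TinyNet:
    """One-hidden-layer tanh network mapping a scalar input to a scalar output."""

    def __init__(self, hidden, rng):
        self.w1 = [rng.uniform(-1.0, 1.0) for _ in range(hidden)]
        self.b1 = [rng.uniform(-1.0, 1.0) for _ in range(hidden)]
        self.w2 = [0.0] * hidden
        self.b2 = 0.0

    def _hidden(self, x):
        return [math.tanh(w * x + b) for w, b in zip(self.w1, self.b1)]

    def predict(self, x):
        return self.b2 + sum(w * h for w, h in zip(self.w2, self._hidden(x)))

    def fit(self, xs, targets, epochs=200, lr=0.05):
        # Plain per-sample SGD on the loss 0.5 * (prediction - target)^2.
        for _ in range(epochs):
            for x, t in zip(xs, targets):
                hs = self._hidden(x)
                err = self.b2 + sum(w * h for w, h in zip(self.w2, hs)) - t
                for j, h in enumerate(hs):
                    grad_pre = err * self.w2[j] * (1.0 - h * h)
                    self.w2[j] -= lr * err * h
                    self.w1[j] -= lr * grad_pre * x
                    self.b1[j] -= lr * grad_pre
                self.b2 -= lr * err


def boost(xs, ys, stages=5, boost_rate=0.5, hidden=4, seed=0):
    """Fit `stages` shallow networks, each to the residual left by the earlier ones."""
    rng = random.Random(seed)
    ensemble, losses = [], []
    preds = [0.0] * len(xs)
    for _ in range(stages):
        # For the loss 0.5 * (y - p)^2 the negative gradient is the residual y - p.
        residuals = [y - p for y, p in zip(ys, preds)]
        net = TinyNet(hidden, rng)
        net.fit(xs, residuals)
        ensemble.append(net)
        preds = [p + boost_rate * net.predict(x) for p, x in zip(preds, xs)]
        losses.append(sum((y - p) ** 2 for y, p in zip(ys, preds)) / len(xs))
    return ensemble, losses


# Toy data (an assumption for illustration): y = sin(3x) on a grid in [-1, 1].
xs = [i / 10.0 - 1.0 for i in range(21)]
ys = [math.sin(3.0 * x) for x in xs]
ensemble, losses = boost(xs, ys)
```

Each stage only sees what the previous stages failed to explain, so the training error in `losses` should shrink as weak learners are added.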

Prerequisites

The code was implemented in Python 3.6.10 and uses the packages listed in the requirements.txt file (full list). The platform used was linux-64. The most important packages you need are the following:

cudatoolkit=10.1.243 
numpy=1.18.1 
pandas=1.0.0 
python=3.6.10 
pytorch=1.4.0 

Installing

To run the code, you can create a conda environment (assuming you already have miniconda3 installed) with the following command in the terminal:

conda create --name grownet --file requirements.txt

Data

You can download the datasets used in the paper from Google Drive. Create a data folder under GrowNet and put the data under this folder.

1. Classification:

The original HIGGS data is split into train and test sets (as done in the XGBoost paper) using the higgs2libsvm.py script.
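The higgs2libsvm.py script itself is not reproduced here. As a rough illustration of what a CSV-to-LIBSVM conversion can look like, the sketch below assumes each CSV row holds the label first, followed by the feature values, and writes features with 1-based indices. The helper name `csv_row_to_libsvm` is hypothetical and the actual script may differ.

```python
def csv_row_to_libsvm(row):
    """Convert 'label,f1,f2,...' into 'label 1:f1 2:f2 ...' (LIBSVM text format)."""
    fields = [field.strip() for field in row.strip().split(",")]
    # Assumption: the label is stored as a float such as 1.0 or 1.000e+00.
    label = int(float(fields[0]))
    features = " ".join(
        f"{index}:{value}" for index, value in enumerate(fields[1:], start=1)
    )
    return f"{label} {features}"


converted = csv_row_to_libsvm("1.0,0.5,-1.25,3\n")
```

Feature values are copied as text rather than reparsed, so no precision is lost in the conversion.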

2. Learning to Rank (L2R):

The data loading pipeline for the L2R task is built around the format of the Microsoft MSLR-WEB10K dataset. We also converted the Yahoo dataset into this format (the Jupyter notebook "yahoo2mslr" does this conversion). If you want to use another L2R dataset with GrowNet, please convert it into MSLR format first. Below is a simple sample with just 10 features:

0 qid:10 1:2 2:0 3:0 4:0 5:2 6:0.666667 7:0 8:0 9:0 10:0.666667 

The first field is the label, the second is the query id, and the remaining fields are the data features, written as index:value pairs.
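For readers converting their own data, a line in this layout can be parsed along the following lines. This is a minimal sketch, not the repository's data loader; the function name `parse_mslr_line` and the choice to return features as a dictionary are assumptions.

```python
def parse_mslr_line(line):
    """Split one MSLR-style line into (label, query_id, {feature_index: value})."""
    tokens = line.split()
    label = int(tokens[0])
    key, query_id = tokens[1].split(":", 1)
    if key != "qid":
        raise ValueError(f"expected 'qid:<id>' as second field, got {tokens[1]!r}")
    features = {}
    for token in tokens[2:]:
        index, value = token.split(":", 1)
        features[int(index)] = float(value)
    return label, query_id, features


sample = "0 qid:10 1:2 2:0 3:0 4:0 5:2 6:0.666667 7:0 8:0 9:0 10:0.666667"
label, query_id, features = parse_mslr_line(sample)
```

The query id is kept as a string so that non-numeric ids in other datasets would not break the parser.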

3. Regression:

The training and test splits of the regression datasets are produced in the Jupyter notebook "reg_train_test_split". The data link already contains the split data.
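If you bring your own regression data, a reproducible random split can be sketched as below. The notebook's actual procedure and split ratio are not described here, so the helper name `split_rows`, the 20% test fraction and the seed are all assumptions.

```python
import random


def split_rows(rows, test_fraction=0.2, seed=0):
    """Randomly partition rows into (train, test), preserving original order."""
    indices = list(range(len(rows)))
    # A seeded generator makes the split repeatable across runs.
    random.Random(seed).shuffle(indices)
    n_test = int(round(len(rows) * test_fraction))
    test_indices = set(indices[:n_test])
    train = [row for i, row in enumerate(rows) if i not in test_indices]
    test = [row for i, row in enumerate(rows) if i in test_indices]
    return train, test


rows = list(range(10))
train_rows, test_rows = split_rows(rows)
```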

Experiments

To reproduce the results from the paper, first activate the conda virtual environment:

conda activate grownet

Then navigate to the task folder (Classification, L2R or Regression) and execute the following command in the terminal:

./train.sh

You can change the dataset, number of hidden layers, number of hidden units per hidden layer, batch size, learning rate, etc. in train.sh.

Results may vary by 1% or less between identical runs due to random initialization.

Contact

Feel free to drop me an email if you have any questions: [email protected]
