Gaussian Processes for Keras

KGP extends Keras with Gaussian Process (GP) layers. It allows one to build flexible GP models whose kernels are structured with deep and recurrent networks built in Keras. The structured part of the model (the neural net) runs on Theano or TensorFlow. The GP layers use a custom backend based on the GPML 4.0 library and build on KISS-GP and its extensions. The models can be trained in stages or jointly, using full-batch or semi-stochastic optimization approaches (see our paper). For additional resources and tutorials on Deep Kernel Learning and KISS-GP, see https://people.orie.cornell.edu/andrew/code/
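
To make the idea of a network-structured kernel concrete, here is a minimal NumPy sketch of deep kernel learning. It is not KGP's implementation: a hypothetical two-layer tanh MLP with fixed random weights stands in for a Keras network, a squared-exponential kernel is applied to its output features, and exact GP regression returns the posterior mean together with the negative log marginal likelihood (NLML), the kind of objective a GP layer is trained on. The helper names and hyperparameter values are illustrative assumptions, and exact inference costs O(n^3) time, which is the cost KISS-GP is designed to reduce.

```python
import numpy as np

def feature_map(X, W1, W2):
    """Hypothetical stand-in for a Keras network: a two-layer tanh MLP."""
    return np.tanh(np.tanh(X @ W1) @ W2)

def se_kernel(A, B, log_ell=0.0, log_sf=0.0):
    """Squared-exponential kernel with log-scale length-scale and signal std."""
    ell, sf2 = np.exp(log_ell), np.exp(2.0 * log_sf)
    sq = (A ** 2).sum(1)[:, None] + (B ** 2).sum(1)[None, :] - 2.0 * A @ B.T
    return sf2 * np.exp(-0.5 * np.maximum(sq, 0.0) / ell ** 2)

def gp_regression(Z, y, Z_star, log_sn=-2.0):
    """Exact GP regression on features Z; returns (posterior mean, NLML)."""
    n = Z.shape[0]
    K = se_kernel(Z, Z) + np.exp(2.0 * log_sn) * np.eye(n)  # add noise variance
    L = np.linalg.cholesky(K)                                # K = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))      # alpha = K^{-1} y
    mean = se_kernel(Z_star, Z) @ alpha
    # NLML = 0.5 y^T K^{-1} y + 0.5 log|K| + (n/2) log(2 pi)
    nlml = (0.5 * y @ alpha + np.log(np.diag(L)).sum()
            + 0.5 * n * np.log(2.0 * np.pi))
    return mean, nlml

# Usage on toy data: the kernel only ever sees the network's features.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=50)
W1, W2 = rng.normal(size=(2, 16)), rng.normal(size=(16, 4))
Z = feature_map(X, W1, W2)
mean, nlml = gp_regression(Z, y, Z)
print(mean.shape, float(nlml))
```

In a full deep kernel model, the NLML would also be differentiated with respect to the network weights, which is what allows the network and the GP to be trained jointly.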

KGP is compatible with Python 2.7-3.5.

In particular, this package implements the method described in our paper:
Learning Scalable Deep Kernels with Recurrent Structure
Maruan Al-Shedivat, Andrew Gordon Wilson, Yunus Saatchi, Zhiting Hu, Eric P. Xing
arXiv:1610.08936, 2016.

Getting started

KGP allows one to build models in the same fashion as Keras, using the functional API. For example, a simple GP-RNN model can be built and compiled in just a few lines of code:

from keras.layers import Input, SimpleRNN
from keras.optimizers import Adam

from kgp.layers import GP
from kgp.models import Model
from kgp.losses import gen_gp_loss

input_shape = (10, 2)  # 10 time steps, 2 dimensions
batch_size = 32
nb_train_samples = 512
gp_hypers = {'lik': -2.0, 'cov': [[-0.7], [0.0]]}

# Build the model
inputs = Input(shape=input_shape)
rnn = SimpleRNN(32)(inputs)
gp = GP(gp_hypers,
        batch_size=batch_size,
        nb_train_samples=nb_train_samples)
outputs = [gp(rnn)]
model = Model(inputs=inputs, outputs=outputs)

# Compile the model
# One GP loss per GP output layer (distinct loop name avoids shadowing `gp`)
loss = [gen_gp_loss(layer) for layer in model.output_layers]
model.compile(optimizer=Adam(1e-2), loss=loss)

Note that KGP models support arbitrary off-the-shelf optimizers from Keras.
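
Training data for this model has to match input_shape. The sketch below generates hypothetical toy sequences with NumPy. It assumes one scalar regression target per sequence, stored as a column vector, and a training set whose size equals nb_train_samples; neither assumption is stated above, so verify both before training.

```python
import numpy as np

# Values copied from the example above.
input_shape = (10, 2)  # 10 time steps, 2 dimensions
nb_train_samples = 512

def make_toy_sequences(n, input_shape, seed=0):
    """Hypothetical toy data: target = sum of the last time step's features."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n,) + tuple(input_shape)).astype("float32")
    # Assumption: one scalar target per sequence, shaped (n, 1).
    y = X[:, -1, :].sum(axis=1, keepdims=True)
    return X, y

X_train, y_train = make_toy_sequences(nb_train_samples, input_shape)
print(X_train.shape, y_train.shape)  # (512, 10, 2) (512, 1)
```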

Further resources:

Installation

KGP depends on Keras and requires either Theano or TensorFlow to be installed. The GPML backend requires either MATLAB or Octave, plus a corresponding Python interface package: Oct2Py for Octave or the MATLAB Engine for Python. In general, the MATLAB backend seems to run faster. However, if you compile the latest version of Octave with JIT and OpenBLAS support, the overhead is reduced to a minimum.
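
A quick way to see which of the two Python bridges is installed is to probe for the matlab.engine and oct2py modules with importlib.util.find_spec. The helper below is a hypothetical sketch, not part of KGP, and it only checks the Python packages: Oct2Py additionally needs a working Octave executable, and the MATLAB engine needs a licensed MATLAB installation.

```python
import importlib.util

def available_gpml_bridges():
    """List the Python-to-GPML bridges importable here, MATLAB first.

    Hypothetical helper, not part of KGP; it only checks that the Python
    packages are installed, not that MATLAB or Octave themselves run.
    """
    bridges = []
    try:
        # Resolving "matlab.engine" imports the parent "matlab" package first,
        # so a missing MATLAB engine package surfaces as an ImportError here.
        if importlib.util.find_spec("matlab.engine") is not None:
            bridges.append("matlab")
    except (ImportError, ValueError):
        pass
    # Top-level lookups simply return None when the package is absent.
    if importlib.util.find_spec("oct2py") is not None:
        bridges.append("octave")
    return bridges

print(available_gpml_bridges())
```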

If you are using Octave, you will need the statistics package, which the commands below install from Octave-Forge together with the io package:

$ octave --eval "pkg install -forge -verbose io"
$ octave --eval "pkg install -forge -verbose statistics"

The requirements can be installed via pip as follows (use sudo if necessary):

$ pip install -r requirements.txt

To install the package, clone the repository and run setup.py as follows:

$ git clone --recursive https://github.com/alshedivat/kgp
$ cd kgp
$ python setup.py develop [--user]

The optional --user flag installs the package for the current user only.

Note: A recursive clone is required to fetch the GPML library as a submodule. If you already have a copy of GPML, you can instead set the GPML_PATH environment variable to point to your GPML folder.
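
If you go the GPML_PATH route, a small check like the following can catch a misconfigured path early. It is a hypothetical helper, not KGP's own lookup code: it prefers the environment variable, falls back to a caller-supplied default, and treats the presence of GPML's top-level startup.m script (as shipped with GPML 4.0) as evidence that the folder is a GPML copy.

```python
import os
import tempfile

def resolve_gpml_path(default_path):
    """Return GPML_PATH if set, else default_path, after a basic sanity check.

    Hypothetical helper: default_path stands in for wherever your GPML
    copy lives; KGP's real lookup logic may differ.
    """
    path = os.environ.get("GPML_PATH") or default_path
    path = os.path.abspath(os.path.expanduser(path))
    # A GPML folder has startup.m at its root; fail early if it is missing.
    if not os.path.isfile(os.path.join(path, "startup.m")):
        raise FileNotFoundError("startup.m not found in " + path)
    return path

# Usage: point GPML_PATH at a throwaway folder that mimics a GPML copy.
with tempfile.TemporaryDirectory() as fake_gpml:
    open(os.path.join(fake_gpml, "startup.m"), "w").close()
    os.environ["GPML_PATH"] = fake_gpml
    found = resolve_gpml_path("unused-default")
    expected = os.path.abspath(fake_gpml)
    del os.environ["GPML_PATH"]
print(found == expected)  # True
```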

Contribution

Contributions and especially bug reports are more than welcome.

Citation

@article{alshedivat2016srk,
  title={Learning Scalable Deep Kernels with Recurrent Structure},
  author={Al-Shedivat, Maruan and Wilson, Andrew G and Saatchi, Yunus and Hu, Zhiting and Xing, Eric P},
  journal={arXiv preprint arXiv:1610.08936},
  year={2016}
}

License

For questions about the code and licensing details, please contact Maruan Al-Shedivat and Andrew Gordon Wilson.
