
maka89 / deep-kernel-gp


Deep Kernel Learning: Gaussian process regression where the input is a neural-network mapping of x, trained to maximize the marginal likelihood.



Deep-Kernel-GP

Dependencies

The package depends on numpy and scipy.linalg. The examples also use matplotlib and scikit-learn.

Introduction

Instead of learning a mapping X-->Y with either a neural network or GP regression, we learn the mappings X-->Z-->Y, where the first step is performed by a neural network and the second by GP regression.
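As a minimal sketch of the X-->Z-->Y idea, assuming a one-hidden-layer tanh network and a squared-exponential GP on z, the code below maps x to z with randomly initialized (untrained) weights and then performs ordinary GP regression in z-space. The names feature_map and gp_posterior_mean are hypothetical and not part of the package; in deep kernel learning the weights would be trained rather than left random.

```python
import numpy as np

def rbf(a, b, v=1.0):
    # Squared-exponential covariance v * exp(-0.5 * |a_i - b_j|^2) on 1-d inputs
    d = a.reshape(-1, 1) - b.reshape(1, -1)
    return v * np.exp(-0.5 * d ** 2)

def feature_map(x, w1, b1, w2, b2):
    # X --> Z: one tanh hidden layer followed by a linear 1-d output
    return np.tanh(x @ w1 + b1) @ w2 + b2

def gp_posterior_mean(z_train, y_train, z_test, v=1.0, noise=1e-2):
    # Z --> Y: standard GP regression, but on z instead of x
    K = rbf(z_train, z_train, v) + noise * np.eye(len(z_train))
    alpha = np.linalg.solve(K, y_train)
    return rbf(z_test, z_train, v) @ alpha

rng = np.random.default_rng(0)
# Random (untrained) network weights: 1 input -> 32 hidden units -> 1 output
w1, b1 = rng.normal(size=(1, 32)), rng.normal(size=32)
w2, b2 = rng.normal(size=(32, 1)), np.zeros(1)

x_train = rng.uniform(-0.5, 0.5, size=(20, 1))
y_train = np.sin(6 * x_train).ravel()
x_test = np.linspace(-0.5, 0.5, 5).reshape(-1, 1)

z_train = feature_map(x_train, w1, b1, w2, b2)  # shape (20, 1)
z_test = feature_map(x_test, w1, b1, w2, b2)    # shape (5, 1)
y_pred = gp_posterior_mean(z_train, y_train, z_test)
```

Only the posterior mean is shown; the predictive standard deviation follows from the same kernel matrices in the usual way.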

This way we can use GP regression to learn functions for which the assumption that y(x) is a Gaussian surface, with covariance specified by one of the standard covariance functions, might not be fair. For instance, we can learn functions with image pixels as inputs, or functions whose length scale varies with the input.

The parameters of the neural net are trained by maximizing the log marginal likelihood implied by z(x_train) and y_train.
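For reference, the objective is the standard GP log marginal likelihood, log p(y|z) = -0.5 y^T K^{-1} y - 0.5 log|K| - (n/2) log(2π), where K is the covariance matrix built from z(x_train). The sketch below computes it through a Cholesky factorization with numpy and scipy, assuming a squared-exponential kernel on z with signal variance v and noise variance s; these names are ours and not necessarily the package's parametrization. An optimizer would minimize the negative of this value with respect to the network weights that produce z.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

def rbf_cov(z, v, s):
    # Covariance of the training targets: v * exp(-0.5 * |z_i - z_j|^2) + s * I
    d = z.reshape(-1, 1) - z.reshape(1, -1)
    return v * np.exp(-0.5 * d ** 2) + s * np.eye(len(z))

def log_marginal_likelihood(z, y, v=1.0, s=1e-2):
    # log p(y | z) = -0.5 y^T K^{-1} y - 0.5 log|K| - (n/2) log(2*pi)
    K = rbf_cov(z, v, s)
    c, lower = cho_factor(K, lower=True)           # K = L L^T
    alpha = cho_solve((c, lower), y)               # K^{-1} y
    log_det = 2.0 * np.sum(np.log(np.diag(c)))     # log|K| from the diagonal of L
    n = len(y)
    return -0.5 * y @ alpha - 0.5 * log_det - 0.5 * n * np.log(2 * np.pi)
```

The Cholesky route avoids forming K^{-1} explicitly and gives log|K| almost for free, which is the usual choice for GP likelihoods.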

Deep Kernel Learning - A. G. Wilson et al.

Using Deep Belief Nets to Learn Covariance Kernels for Gaussian Processes - G. Hinton et al.

Examples

Basic usage follows a scikit-learn-style API:

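import numpy as np
from dknet import NNRegressor  # import paths assume the repository's dknet package
from dknet.layers import Dense, CovMat
from dknet.optimizers import Adam, SciPyMin

# Illustrative 1-d data, shape [n_samples, 1]
x_train = np.random.random(size=(70, 1)) - 0.5
y_train = np.sin(64 * (x_train + 0.5) ** 4)
x_test = np.linspace(-0.5, 0.5, 1000).reshape(-1, 1)

layers = []
layers.append(Dense(32, activation='tanh'))  # hidden layer with 32 units
layers.append(Dense(1))                      # 1-d embedding z(x)
layers.append(CovMat(kernel='rbf'))          # GP covariance matrix built from z

opt = Adam(1e-3)  # or opt = SciPyMin('l-bfgs-b')
gp = NNRegressor(layers, opt=opt, batch_size=x_train.shape[0], maxiter=1000, gp=True, verbose=True)
gp.fit(x_train, y_train)
y_pred, std = gp.predict(x_test)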

The example creates a mapping z(x), where both x and z are 1-d, using a neural network with one hidden layer. The CovMat layer builds a covariance matrix from z using the covariance function v*exp(-0.5*|z1-z2|**2) plus a noise term s, where v and s are learned during training.

v and s are available after training as gp.layers[-1].var and gp.layers[-1].s_alpha, respectively. The gp.fast_forward() function can be used to extract the z(x) function (it skips the last layer, which produces an array of size [batch_size, batch_size]).
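The covariance matrix described above can be sketched in numpy as follows. This is a hypothetical stand-in for what the CovMat layer computes, assuming s is a noise variance added to the diagonal; the package may store or transform its parameters differently (for example in s_alpha). The sketch also accepts z with more than one dimension, using the Euclidean distance.

```python
import numpy as np

def cov_mat(z, v, s):
    """Hypothetical stand-in for the CovMat layer (RBF kernel).

    z: array of shape [batch_size, d] (d = 1 in the example above); 1-d input is treated as d = 1
    v: signal variance; s: noise variance added to the diagonal (assumed parametrization)
    """
    z = np.asarray(z, dtype=float).reshape(len(z), -1)
    sq_dist = np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)  # |z1 - z2|^2
    return v * np.exp(-0.5 * sq_dist) + s * np.eye(z.shape[0])
```

For a batch of n points the result is an n x n array, i.e. [batch_size, batch_size], so its memory cost grows quadratically with the batch size.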

Learning a function with varying length scale

In the example.py script, deep kernel learning (DKL) is used to learn from samples of the function sin(64(x+0.5)**4).

Learning this function with a neural network alone would be hard, since rapidly oscillating functions can be challenging to fit with NNs. Learning it with GP regression and a squared-exponential covariance function would also be suboptimal, since we must commit to one fixed length scale. With a length scale short enough to follow the rapid oscillations, we would, unless we have a lot of samples, be forced to give up precision on the slowly varying part of the function.
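The length-scale argument can be made concrete. The target equals sin(z) with z = 64(x+0.5)**4, which oscillates at a constant rate in z, while in x its local period is about 2π/z'(x) with z'(x) = 256(x+0.5)**3. The sketch below uses this hand-picked warp as an assumption (the network learns its own z(x), which need not match it) to show that the local period shrinks by a factor of 8 between x = 0 and x = 0.5 and grows without bound as x approaches -0.5. A GP with a single length scale in such a z-space fits the whole function, while a single length scale in x-space cannot suit both regions.

```python
import numpy as np

def z_warp(x):
    # Hand-picked warp (an assumption, not the learned network): f(x) = sin(z_warp(x))
    return 64.0 * (x + 0.5) ** 4

def local_period(x):
    # sin(z) has period 2*pi in z; in x the period is roughly 2*pi / (dz/dx)
    dz_dx = 256.0 * (x + 0.5) ** 3
    return 2.0 * np.pi / dz_dx

x = np.linspace(-0.4, 0.5, 7)
f = np.sin(64.0 * (x + 0.5) ** 4)        # target function from example.py
g = np.sin(z_warp(x))                    # constant-frequency function of z

print(local_period(0.0), local_period(0.5))  # roughly 0.196 and 0.0245
```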

Figure: DKL prediction.

Figure: z(x) function learned by the neural network.

We see that DKL solves the problem quite well, given the limited data. We also see that for x < -0.5 the standard deviation predicted by the DKL model does not capture the prediction error.
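One way to quantify whether the predicted standard deviation captures the error is to measure how often the true value falls within ±2 std of the prediction. A minimal sketch with a hypothetical helper, interval_coverage; for well-calibrated Gaussian predictions the fraction should be near 0.95.

```python
import numpy as np

def interval_coverage(y_true, y_pred, std, k=2.0):
    """Fraction of points whose absolute error is at most k predicted std devs."""
    y_true, y_pred, std = (np.asarray(a, dtype=float).ravel() for a in (y_true, y_pred, std))
    return float(np.mean(np.abs(y_true - y_pred) <= k * std))
```

Given noise-free targets y_true on x_test, restricting to a region with a mask such as mask = x_test.ravel() < -0.5 and passing y_true[mask], y_pred[mask], std[mask] would check calibration in just that region; a value well below 0.95 there would quantify the under-estimated uncertainty.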

Issues

Optimisation does not converge

Could you please comment on why the optimisation in example.py (1d) does not converge? The final GP is nonetheless a fairly good approximation.

Thank you!

Epoch 9988: 0 %. Loss: -20.101368667360404
Epoch 9989: 0 %. Loss: -38.94774396791739
Epoch 9990: 0 %. Loss: -39.56470825273037
Epoch 9991: 0 %. Loss: -12.318895364429565
Epoch 9992: 0 %. Loss: -19.47904221404525
Epoch 9993: 0 %. Loss: -36.77490064730489
Epoch 9994: 0 %. Loss: 43.20385940066373
Epoch 9995: 0 %. Loss: 20.232737011431396
Epoch 9996: 0 %. Loss: -34.98050499495804
Epoch 9997: 0 %. Loss: -32.932921142632466
Epoch 9998: 0 %. Loss: 739.1653327084057
Epoch 9999: 0 %. Loss: -39.727655479961854
Epoch 10000: 0 %. Loss: -35.62156909134845

Memory error

Hi.
I have 64 GB of RAM on Ubuntu.

I'm not sure why, but when I try to fit on a per-column basis, I get an error:
tmp=X[:,i].reshape(1,-1)-X[:,i].reshape(-1,1)
MemoryError

My dataset shape is (187630, 14), all numerical values. The error refers to the tmp=... line above.
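For scale: the failing line broadcasts a column of length n against itself, producing a dense (n, n) array. Assuming float64 values (8 bytes each), the sketch below estimates its size for n = 187630; the helper name is hypothetical.

```python
def pairwise_matrix_bytes(n_samples, bytes_per_value=8):
    # X[:, i].reshape(1, -1) - X[:, i].reshape(-1, 1) broadcasts to shape (n, n);
    # each float64 entry takes 8 bytes
    return n_samples * n_samples * bytes_per_value

n = 187630
size = pairwise_matrix_bytes(n)
print(f"{size / 1e9:.1f} GB per (n, n) float64 array")  # 281.6 GB per (n, n) float64 array
```

That is about 282 GB for a single such array, more than four times the 64 GB of RAM available.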
