
aaltoml / bayesnewton

Stars: 212, Watchers: 11, Forks: 26, Size: 1.37 MB

Bayes-Newton—A Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's method.

License: Apache License 2.0

Python 99.65% Shell 0.35%
machine-learning signal-processing gaussian-processes state-space-models approximate-bayesian-inference jax sparse-gps markov-gps

bayesnewton's People

Contributors: asolin, fsaad, itskalvik, mathdr, thorewietzke, wil-j-wil


bayesnewton's Issues

How to set initial `Pinf` variable in kernel?

I noticed that the initial Pinf variable for the Matern-5/2 kernel is as follows:

        Pinf = np.array([[self.variance,    0.0,   -kappa],
                         [0.0,    kappa, 0.0],
                         [-kappa, 0.0,   25.0*self.variance / self.lengthscale**4.0]])

Why is it like that? Are there any references I should follow up?

PS: by the way, the data ../data/aq_data.csv is missing.
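A hedged pointer for the Pinf question: in the state-space (SDE) form of a Matérn kernel, Pinf is the stationary covariance of the state [f, f', f''], i.e. the solution of the continuous Lyapunov equation F Pinf + Pinf Fᵀ + L Qc Lᵀ = 0 (see Hartikainen and Särkkä, 2010, on Kalman filtering solutions to temporal GP regression). The sketch below assumes the standard companion-form construction for Matern-5/2 with λ = √5/ℓ and a white-noise spectral density Qc chosen so that Var[f] = σ²; it was not read from kernels.py. It checks SciPy's Lyapunov solver against the closed form quoted above, with κ = 5σ²/(3ℓ²), which is Var[f'] for this kernel (the issue's snippet does not show how kappa is defined).

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov


def matern52_state_space(variance, lengthscale):
    """Companion-form SDE for Matern-5/2 (assumed parameterisation, state = [f, f', f''])."""
    lam = np.sqrt(5.0) / lengthscale
    F = np.array([[0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [-lam**3, -3.0 * lam**2, -3.0 * lam]])  # char. polynomial (s + lam)^3
    L = np.array([[0.0], [0.0], [1.0]])
    Qc = 16.0 / 3.0 * variance * lam**5  # makes the stationary Var[f] equal `variance`
    return F, L, Qc


def pinf_from_lyapunov(variance, lengthscale):
    F, L, Qc = matern52_state_space(variance, lengthscale)
    # SciPy solves A X + X A^H = Q, so pass Q = -L Qc L^T
    return solve_continuous_lyapunov(F, -Qc * (L @ L.T))


def pinf_closed_form(variance, lengthscale):
    kappa = 5.0 / 3.0 * variance / lengthscale**2  # Var[f'] = -k''(0)
    return np.array([[variance, 0.0, -kappa],
                     [0.0, kappa, 0.0],
                     [-kappa, 0.0, 25.0 * variance / lengthscale**4]])


P_num = pinf_from_lyapunov(variance=1.3, lengthscale=0.7)
P_ref = pinf_closed_form(variance=1.3, lengthscale=0.7)
print(np.allclose(P_num, P_ref))
```

The off-diagonal -κ entries reflect Cov(f, f'') = k''(0) = -Var[f'], and the zeros reflect that a stationary process is uncorrelated with its own first derivative at the same time.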

Sparse EP energy is incorrect

The current implementation of the sparse EP energy is not giving sensible results. This is a reminder to look into the reasons why and check against implementations elsewhere. PRs very welcome for this issue.

Note: the EP energy is correct for all other models (GP, Markov GP, SparseMarkovGP)

Missing data for experiments, (electricity ...)

First off, this is an amazing package: really interesting models, and very well done. When running some experiments, I found that some data is missing. I understand it is not always possible to commit larger files to GitHub, but it would be great if there were a README listing the data sources or something similar.

Specifically:
data/electricity.csv

But also many others that are in .gitignore:

data/aq_data.csv
data/electricity.csv
data/fission_normalized_counts.csv
data/normalized_alpha_counts.csv

I tried to find this CSV file elsewhere, but without a hint in the README I can't find the exact file. Is it perhaps this one?
https://datahub.io/machine-learning/electricity

Thanks for the help

jitted predict

Hi, I'm starting to explore your framework. I'm familiar with JAX, but not with objax. I noticed that the training ops are jitted with objax.Jit, but since my goal is fast prediction embedded in some larger JAX code, I wonder whether predict() can also be jitted. Thanks in advance,

Regards,
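A sketch in plain JAX rather than objax, assuming the trained hyperparameters can be extracted as arrays: write prediction as a pure function of its inputs and wrap it in jax.jit, which also lets it be called from inside other jitted JAX code. rbf and gp_predict are hypothetical helpers, and the dense exact-GP posterior below is a stand-in for illustration, not bayesnewton's predict().

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def rbf(x1, x2, variance, lengthscale):
    # squared-exponential kernel on 1-D inputs
    d2 = (x1[:, None] - x2[None, :]) ** 2
    return variance * jnp.exp(-0.5 * d2 / lengthscale**2)


@jax.jit
def gp_predict(x_train, y_train, x_test, variance, lengthscale, noise):
    """Exact GP posterior mean and marginal variance of f at x_test."""
    K = rbf(x_train, x_train, variance, lengthscale) + noise * jnp.eye(x_train.shape[0])
    Ks = rbf(x_test, x_train, variance, lengthscale)
    chol = jnp.linalg.cholesky(K)
    alpha = cho_solve((chol, True), y_train)          # K^{-1} y
    mean = Ks @ alpha
    v = solve_triangular(chol, Ks.T, lower=True)      # L^{-1} K_*^T
    var = variance - jnp.sum(v**2, axis=0)            # k(x*, x*) = variance for RBF
    return mean, var


x_train = jnp.linspace(0.0, 1.0, 20)
y_train = jnp.sin(2.0 * jnp.pi * x_train)
x_test = jnp.linspace(-0.2, 1.2, 50)
mean, var = gp_predict(x_train, y_train, x_test, 1.0, 0.3, 0.01)
```

Because every quantity is an explicit argument, the function is retraced only when input shapes change, and it composes with jax.vmap or an outer jax.jit.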

air_quality.py demo failing

Hi, I am trying to run the air_quality demo, and got the data from here https://drive.google.com/file/d/1QcbzFJyAr95zY-B1r-0MkzcX7TReaF6C/view?pli=1.

On running the code as is, I get this error:

Traceback (most recent call last):
  File "/BayesNewton/demos/air_quality.py", line 433, in <module>
    mu = Y_scaler.inverse_transform(mu)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ssm/lib/python3.11/site-packages/sklearn/preprocessing/_data.py", line 1085, in inverse_transform
    X = check_array(
        ^^^^^^^^^^^^
  File "/ssm/lib/python3.11/site-packages/sklearn/utils/validation.py", line 1043, in check_array
    raise ValueError(
ValueError: Found array with dim 3. None expected <= 2.

Any help is appreciated, Thanks!
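The traceback says scikit-learn's inverse_transform only accepts arrays with at most 2 dimensions, while mu is 3-D here. One workaround, assuming Y_scaler is a fitted scikit-learn scaler (such as StandardScaler) and the last axis of mu is the feature axis, is to flatten to (-1, n_features), invert, and restore the original shape. inverse_transform_nd is a hypothetical helper; n_features_in_ is the attribute scikit-learn sets on fitted estimators.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler


def inverse_transform_nd(scaler, arr):
    """Apply scaler.inverse_transform to an array of any rank whose last axis holds the features."""
    arr = np.asarray(arr)
    flat = arr.reshape(-1, scaler.n_features_in_)
    return scaler.inverse_transform(flat).reshape(arr.shape)


rng = np.random.default_rng(0)
Y = 3.0 * rng.normal(size=(100, 1)) + 2.0
Y_scaler = StandardScaler().fit(Y)
mu = Y_scaler.transform(Y).reshape(100, 1, 1)  # 3-D, like the array in the traceback
mu_orig = inverse_transform_nd(Y_scaler, mu)
```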

Cannot run demo, possible incompatibility with latest Jax

Dear all,

I am trying to run the demo examples, but I run into the following error:


ImportError Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 import bayesnewton
2 import objax
3 import numpy as np

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in <module>
----> 1 from . import (
2 kernels,
3 utils,
4 ops,
5 likelihoods,
6 models,
7 basemodels,
8 inference,
9 cubature
10 )
13 def build_model(model, inf, name='GPModel'):
14 return type(name, (inf, model), {})

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5, in <module>
3 import jax.numpy as np
4 from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
----> 5 from jax.ops import index_add, index
6 from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
7 from warnings import warn

ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/__init__.py)

I think it's related to this note from the JAX website:

The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.
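For reference, the migration the JAX note asks for looks like this. The removed calls appear only in comments, since they no longer exist in current JAX. For kernels.py this would mean deleting the `from jax.ops import index_add, index` line and rewriting each call site in the same way (the call sites themselves are not shown in the traceback).

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Removed API: index_add(x, index[1:3], 1.0)
x_added = x.at[1:3].add(1.0)

# Removed API: index_update(x, index[0], 7.0)
x_set = x.at[0].set(7.0)

# JAX arrays are immutable: .at[...] returns a new array and leaves x unchanged
```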

Double Precision Issues

Hi!

Many thanks for open-sourcing this package.

I've been using code from this amazing package for my research (preprint here) and have found that the default single precision/float32 is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular,

  • the Periodic kernel is rather sensitive to the matrix operations in _sequential_kf() and _sequential_rts();
  • likewise the Matern32 kernel, when its lengthscales are too large.

However, reverting to float64 by setting config.update("jax_enable_x64", True) makes everything quite slow, especially when I use objax neural network modules, because doing so puts all arrays into double precision.

Currently, my solution is to set the neural network weights to float32 manually, and convert input arrays into float32 before entering the network and the outputs back into float64. However, I was wondering if there could be a more elegant solution as is done in https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.

Software and hardware details:

objax==1.6.0
jax==0.3.13
jaxlib==0.3.10+cuda11.cudnn805
NVIDIA-SMI 460.56       Driver Version: 460.56       CUDA Version: 11.2
GeForce RTX 3090 GPUs

Thanks in advance.

Best,
Harrison
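A minimal sketch of the workaround described above, in plain JAX rather than objax: enable x64 globally so that the Kalman filtering and smoothing arrays default to float64, create the network weights explicitly in float32, and cast at the network boundary. init_mlp and mlp_f32 are hypothetical names, and how much speed this recovers depends on how much of the runtime is spent inside the network.

```python
import jax

jax.config.update("jax_enable_x64", True)  # new arrays now default to float64
import jax.numpy as jnp


def init_mlp(key, d_in, d_hidden):
    # network weights are created explicitly in float32
    k1, k2 = jax.random.split(key)
    w1 = jax.random.normal(k1, (d_in, d_hidden), dtype=jnp.float32)
    w2 = jax.random.normal(k2, (d_hidden, 1), dtype=jnp.float32)
    return w1, w2


@jax.jit
def mlp_f32(params, x):
    w1, w2 = params
    h = jnp.tanh(x.astype(jnp.float32) @ w1)  # network runs in float32
    return (h @ w2).astype(jnp.float64)      # hand float64 back to the GP code


params = init_mlp(jax.random.PRNGKey(0), d_in=3, d_hidden=16)
x = jnp.ones((5, 3))  # float64 because x64 is enabled
y = mlp_f32(params, x)
```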

Input-dependent noise model

Hello,
Thank you for open-sourcing your work! It has been incredibly useful to me ^-^
Also, is there a method to model input-dependent observation noise? (Snelson and Ghahramani, 2006)
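For context, one common way to model input-dependent noise is a likelihood driven by two latent functions: one for the mean and one for the noise scale, passed through a positivity link. The heteroscedastic.py demo (the subject of a separate issue on this page) may be the place to look in this repository. Below is a small NumPy sketch of such a log-likelihood; softplus as the link is our assumption, and heteroscedastic_log_lik is a hypothetical name.

```python
import numpy as np


def softplus(x):
    # numerically stable log(1 + exp(x))
    return np.logaddexp(0.0, x)


def heteroscedastic_log_lik(y, f_mean, f_scale):
    """log N(y | f_mean, softplus(f_scale)^2), evaluated elementwise."""
    sigma = softplus(f_scale)
    return -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((y - f_mean) / sigma) ** 2


y = np.array([0.1, -0.4, 2.0])
f_mean = np.array([0.0, 0.0, 1.5])
f_scale = np.array([-1.0, 0.0, 1.0])  # a second latent GP would supply these values
ll = heteroscedastic_log_lik(y, f_mean, f_scale)
```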

issue with SparseVariationalGP method

When I run demos/regression.py with SparseVariationalGP, something goes wrong:

AssertionError: Assignments to variable must be an instance of JaxArray, but received f<class 'numpy.ndarray'>.

There seems to be a mistake in the SparseVariationalGP method. Can you help fix the problem? Thank you very much!

error in heteroscedastic.py with MarkovPosteriorLinearisationQuasiNewtonGP method

As the title said, there is an error after running heteroscedastic.py with MarkovPosteriorLinearisationQuasiNewtonGP method:

  File "heteroscedastic.py", line 101, in train_op
    model.inference(lr=lr_newton, damping=damping)  # perform inference and update variational params
  File "/BayesNewton-main/bayesnewton/inference.py", line 871, in inference
    mean, jacobian, hessian, quasi_newton_state = self.update_variational_params(batch_ind, lr, **kwargs)
  File "/BayesNewton-main/bayesnewton/inference.py", line 1076, in update_variational_params
    jacobian_var = transpose(solve(omega, dmu_dv)) @ residual
ValueError: The arguments to solve must have shapes a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a=(117, 1, 1) and b=(117, 2, 2)

Can you tell me what is wrong? Thanks in advance.

NaN predictions when the inference is outside of the training data.

Hi, @wil-j-wil

I stumbled upon your work while researching temporal GPs and found this cool package (thanks for all the work behind it).

I am running into the following issue when using your package:

  • Whenever I use a more complex kernel, my predictions outside the range of X_train are all NaNs.

For example, on the airline passenger dataset with a default setup:

kernel = QuasiPeriodicMatern32()
likelihood = Gaussian()
model = MarkovVariationalGP(kernel=kernel, likelihood=likelihood, X=X_train, Y=y_train)

  • Whenever I predict a value "within" X_train, it works well:

mean_in, var_in = model.predict(X=X_train[-1])  # inside X_train: E[f] = "perfetto"

  • Whenever I try to "extrapolate", I get NaNs:

mean_out, var_out = model.predict(X=X_test[0])  # outside X_train: E[f] = nan

I also get NaN for the first observation of X_train, (see plots below)

  • This does not happen when using basic kernels such as:
kernel = Matern12()
  • It does happen whenever I use a Sum of kernels or any of the combination kernels.

Plots (images not reproduced): matern12, qperiodmatern32

Am I perhaps misunderstanding the purpose of the model, or doing something wrong? Thanks in advance for your help :) I am just a beginner GP enthusiast exploring what these models are capable of.

P.S. I installed bayesnewton (hopefully) according to the requirements:

[tool.poetry.dependencies]
python = "3.10.11"
tensorflow-macos = "2.13.0"
tensorflow-probability = "^0.21.0"
numba = "^0.58.0"
gpflow = "^2.9.0"
scipy = "^1.11.2"
pandas = "^2.1.1"
jax = "0.4.2"
jaxlib = "0.4.2"
objax = "1.6.0"
ipykernel = "^6.25.2"
plotly = "^5.17.0"
seaborn = "^0.12.2"
nbformat = "^5.9.2"
scikit-learn = "^1.3.1"
convertbng = "^0.6.43"

Periodic Kernel outputs wrong period

The Periodic kernel reports the wrong period for my function, although the prediction result shows the correct period. I've attached my sample code at the end for reproducibility.
Ultimately, I want to train my state-space model with BayesNewton and use the resulting state space in my own models.

import bayesnewton
import objax
import numpy as np
import matplotlib.pyplot as plt
import time


def periodicFunction(t, p):
    noise_var = 0.01
    return np.sin(2*np.pi/p * t) + np.sqrt(noise_var) * np.random.normal(0, 1, t.shape)  # np.math was removed in NumPy 2.0


x = np.arange(0., 4., 0.05)
y = periodicFunction(x, 1.)
var_f = 1.
var_y = 1.
period = 3.
lengthscale_f = 1.

kern = bayesnewton.kernels.Periodic(variance=var_f, period=period, lengthscale=lengthscale_f)
lik = bayesnewton.likelihoods.Gaussian(variance=var_y)
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)

lr_adam = 0.1
lr_newton = 1
iters = 200
opt_hypers = objax.optimizer.Adam(model.vars())
energy = objax.GradValues(model.energy, model.vars())


@objax.Function.with_vars(model.vars() + opt_hypers.vars())
def train_op():
    model.inference(lr=lr_newton)  # perform inference and update variational params
    dE, E = energy()  # compute energy and its gradients w.r.t. hypers
    opt_hypers(lr_adam, dE)
    return E


train_op = objax.Jit(train_op)

t0 = time.time()
for i in range(1, iters + 1):
    loss = train_op()
    print('iter %2d, energy: %1.4f' % (i, loss[0]))
t1 = time.time()
print('optimisation time: %2.2f secs' % (t1-t0))

y_predict, variance = model.predict_y(x)

print(model.kernel.period)

fig, ax = plt.subplots()

ax.plot(x, y, label="True", linewidth=2.0)
ax.plot(x, y_predict, label="Predict", linewidth=2.0)
ax.set(xlim=(x[0], x[-1]))

plt.show()
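One possible explanation, offered as a hypothesis rather than a diagnosis: a function with period 1 is also periodic with period 2, 3, and so on, so a Periodic kernel initialised at period = 3 can fit sin(2πt) well without the period hyperparameter moving to 1. State-space periodic kernels are typically built from the cosine expansion of Solin and Särkkä (2014), k(τ) = Σⱼ qⱼ² cos(2πjτ/p) with q₀² = σ² I₀(ℓ⁻²) e^(−ℓ⁻²) and qⱼ² = 2σ² Iⱼ(ℓ⁻²) e^(−ℓ⁻²) for j ≥ 1. Assuming bayesnewton's Periodic kernel follows this construction, the j = 3 harmonic of a period-3 kernel sits exactly at frequency 1. The sketch computes those weights with scipy.special.ive (the exponentially scaled Bessel function) and checks them against the closed-form kernel; periodic_harmonics is a hypothetical helper.

```python
import numpy as np
from scipy.special import ive


def periodic_kernel(tau, variance, lengthscale, period):
    # k(tau) = s^2 exp(-2 sin^2(pi tau / p) / l^2)
    return variance * np.exp(-2.0 * np.sin(np.pi * tau / period) ** 2 / lengthscale**2)


def periodic_harmonics(variance, lengthscale, period, order):
    """Frequencies j / p and weights q_j^2 of the cosine expansion, j = 0..order."""
    z = lengthscale**-2.0
    j = np.arange(order + 1)
    q2 = 2.0 * variance * ive(j, z)  # ive(j, z) = I_j(z) * exp(-z)
    q2[0] = variance * ive(0, z)
    return j / period, q2


freqs, q2 = periodic_harmonics(variance=1.0, lengthscale=1.0, period=3.0, order=6)
print(freqs[3], q2[3])  # the j = 3 harmonic of a period-3 kernel is at frequency 1

tau = np.linspace(0.0, 6.0, 61)
k_series = np.cos(2.0 * np.pi * np.outer(tau, freqs)) @ q2
k_exact = periodic_kernel(tau, 1.0, 1.0, 3.0)
```

If this is what happens, initialising period near the expected value (or checking whether the learned period is close to an integer multiple of the true one) would be a quick way to confirm it.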

Addition of a Squared Exponential Kernel?

This is not an issue per se, but I was wondering if there was a specific reason that there wasn't a Squared Exponential kernel as part of the package?

If applicable, I would be happy to submit a PR adding one.

Just let me know.

Latest versions of JAX and objax cause compile slow down

It is recommended to use the following versions of jax and objax:

jax==0.2.9
jaxlib==0.1.60
objax==1.3.1

This is because of this objax issue, which causes the model to JIT compile "twice", i.e. on the first two iterations rather than just the first. This causes a slight slowdown for large models, but is not a problem otherwise.

GPU training

Hi,
How can I speed up training on a GPU for models such as VariationalGP?
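A first check, independent of bayesnewton: JAX only uses a GPU if a GPU-enabled jaxlib is installed, and jax.default_backend() and jax.devices() report what it found. If the backend is "cpu", no model-level change will move training onto the GPU until the installation is fixed. Note also that JAX dispatches work asynchronously, so timings should wait for results with block_until_ready().

```python
import jax
import jax.numpy as jnp

print(jax.default_backend())  # e.g. "cpu" or "gpu"
print(jax.devices())

# Arrays are placed on the default device, and jitted functions run there too
x = jnp.ones((1000, 1000))
y = jax.jit(lambda a: a @ a.T)(x)
y.block_until_ready()  # wait for the asynchronous computation before timing anything
```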

SparseMarkovVariationalGP vs S2CVI

I've been using markovflow.SpatioTemporalSparseCVI, which implements S2CVI from this paper, for some modeling, and I'm now experimenting with BayesNewton. Is SparseMarkovVariationalGP with a SpatioTemporalKernel that has inducing points essentially equivalent to S2CVI?

Also, can the current implementations handle 2 spatial dimensions as well as a temporal dimension?

energy function can output NaN if compiled

I'm using the MarkovVariationalGP model to compute the energy function. When I compile the energy function, the returned values differ from those of the non-compiled function. Sometimes I get NaN as output.

My minimal example is the following:

import numpy as np
import objax
import bayesnewton

x = np.arange(0., 4., 0.05)
y = x**2
var_f = 1.
var_y = 1.
lengthscale_f = 1.

kern = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=lengthscale_f)
lik = bayesnewton.likelihoods.Gaussian(variance=var_y)
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)

energy = objax.GradValues(model.energy, model.vars())

energy_Jit = objax.Jit(energy, model.vars())
dE, E = energy()
dE_comp, E_comp = energy_Jit()

print(f"Energy={E}, Gradient of Energy={dE}")
print(f"EnergyComp={E_comp}, Gradient of EnergyComp={dE_comp}")

I dug further, and the NaN comes from gaussian_expected_log_lik in utils. The masks used there are not the same in the compiled and uncompiled versions. The mask used here is mask_pseudo_y from the BaseModel. In particular, mask and maskv can have different values, which results in noise=0. This produces a NaN in mvn_logpdf.

I've included the function here for easier access:

def gaussian_expected_log_lik(Y, q_mu, q_covar, noise, mask=None):
    """
    :param Y: N x 1
    :param q_mu: N x 1
    :param q_covar: N x N
    :param noise: N x N
    :param mask: N x 1
    :return:
        E[log 𝓝(yₙ|fₙ,σ²)] = ∫ log 𝓝(yₙ|fₙ,σ²) 𝓝(fₙ|mₙ,vₙ) dfₙ
    """
    if mask is not None:
        # build a mask for computing the log likelihood of a partially observed multivariate Gaussian
        maskv = mask.reshape(-1, 1)
        jax.debug.print("mask={mask}, maskv={maskv}", mask=mask, maskv=maskv)
        q_mu = np.where(maskv, Y, q_mu)
        noise = np.where(maskv + maskv.T, 0., noise)  # ensure masked entries are independent
        noise = np.where(np.diag(mask), INV2PI, noise)  # ensure masked entries return log like of 0
        q_covar = np.where(maskv + maskv.T, 0., q_covar)  # ensure masked entries are independent
        q_covar = np.where(np.diag(mask), 1e-20, q_covar)  # ensure masked entries return trace term of 0

    ml = mvn_logpdf(Y, q_mu, noise)
    trace_term = -0.5 * np.trace(solve(noise, q_covar))
    return ml + trace_term

In my opinion, the compiled and uncompiled versions should output the same value.
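To narrow this down, a small consistency check can compare the masking step in eager mode and under jax.jit. The sketch reimplements only the noise-masking lines of gaussian_expected_log_lik, with two assumptions labelled in the code: INV2PI = 1/(2π), the value that makes a masked entry's log-likelihood exactly 0 as the function's comment describes, and an explicit jnp.logical_or in place of maskv + maskv.T so the intent does not depend on the mask's dtype. mask_noise is a hypothetical name. If this check passes, the discrepancy presumably enters through the mask values themselves, which matches the observation above that mask differs between compiled and uncompiled runs.

```python
import math

import jax
import jax.numpy as jnp

INV2PI = 1.0 / (2.0 * math.pi)  # assumption: gives log N(y | y, INV2PI) = 0 for masked entries


def mask_noise(noise, mask):
    """Masking step of gaussian_expected_log_lik, written with an explicit logical_or."""
    maskv = mask.reshape(-1, 1)
    either_masked = jnp.logical_or(maskv, maskv.T)
    noise = jnp.where(either_masked, 0.0, noise)      # masked entries become independent
    noise = jnp.where(jnp.diag(mask), INV2PI, noise)  # masked diagonal entries get INV2PI
    return noise


mask = jnp.array([False, True, False])
noise = 0.5 * jnp.eye(3) + 0.1

eager = mask_noise(noise, mask)
jitted = jax.jit(mask_noise)(noise, mask)
print(jnp.allclose(eager, jitted))
```

A zero on the diagonal of the returned noise matrix is exactly the condition that makes mvn_logpdf return NaN, so asserting a strictly positive diagonal is a useful guard in both modes.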
