aaltoml / bayesnewton
Bayes-Newton — A Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's method.
License: Apache License 2.0
I note that the initial Pinf variable for the Matern-5/2 kernel is as follows:
Pinf = np.array([[self.variance, 0.0, -kappa],
                 [0.0, kappa, 0.0],
                 [-kappa, 0.0, 25.0*self.variance / self.lengthscale**4.0]])
Why is it like that? Are there any references I should follow up on?
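For context (standard state-space GP theory, e.g. Hartikainen & Särkkä, 2010, and Särkkä & Solin, Applied Stochastic Differential Equations, 2019, rather than anything specific to this repo): Pinf is the stationary covariance of the Matérn-5/2 SDE, i.e. the solution of the Lyapunov equation F·Pinf + Pinf·Fᵀ + L·Qc·Lᵀ = 0. A minimal sketch checking the closed form numerically, using the standard Matérn-5/2 state-space constants (the hyperparameter values are just examples):

import numpy as np
from scipy.linalg import solve_continuous_lyapunov

variance, lengthscale = 1.0, 0.5                  # example hyperparameters
lam = np.sqrt(5.0) / lengthscale                  # Matern-5/2 rate parameter
kappa = 5.0 / 3.0 * variance / lengthscale**2

# feedback matrix of the Matern-5/2 SDE, noise loading, and spectral density
F = np.array([[0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [-lam**3, -3.0 * lam**2, -3.0 * lam]])
L = np.array([[0.0], [0.0], [1.0]])
Qc = 400.0 * np.sqrt(5.0) / 3.0 * variance / lengthscale**5

# stationary covariance: solve F P + P F^T + L Qc L^T = 0
Pinf = solve_continuous_lyapunov(F, -Qc * L @ L.T)

Pinf_closed_form = np.array([[variance, 0.0, -kappa],
                             [0.0, kappa, 0.0],
                             [-kappa, 0.0, 25.0 * variance / lengthscale**4]])
assert np.allclose(Pinf, Pinf_closed_form)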
P.S. By the way, the data file ../data/aq_data.csv is missing.
The current implementation of the sparse EP energy is not giving sensible results. This is a reminder to look into the reasons why and check against implementations elsewhere. PRs very welcome for this issue.
Note: the EP energy is correct for all other models (GP, Markov GP, SparseMarkovGP)
First off, let me say: this is an amazing package, with amazingly interesting models, and very well done. When running some of the experiments, data files are missing. I understand it is not always possible to commit larger files to GitHub, but it would be great if there were a README listing the data sources.
Specifically:
data/electricity.csv
but also many others that are in .gitignore:
data/aq_data.csv
data/electricity.csv
data/fission_normalized_counts.csv
data/normalized_alpha_counts.csv
I was trying to find this CSV file somewhere else, but without a hint in the README I can't find the exact one. Is it perhaps this one?
https://datahub.io/machine-learning/electricity
Thanks for the help.
Hi, I'm starting to explore your framework. I'm familiar with jax, but not with objax. I noticed that the train ops are jitted with objax.Jit, but as my goal is to have fast prediction embedded in some larger jax code, I wonder whether predict() can also be jitted? Thanks in advance.
Regards,
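A sketch of what may work, not verified against this repo: objax.Jit accepts a callable plus a VarCollection, so the predict method can be wrapped the same way the demos wrap train_op (model and X_test below are placeholders for your own objects):

import objax

predict_jit = objax.Jit(model.predict, model.vars())
mu, var = predict_jit(X=X_test)  # X_test: your test inputs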
Hi, I am trying to run the air_quality demo, and got the data from here: https://drive.google.com/file/d/1QcbzFJyAr95zY-B1r-0MkzcX7TReaF6C/view?pli=1
On running the code as is, I get this error:
Traceback (most recent call last):
File "/BayesNewton/demos/air_quality.py", line 433, in <module>
mu = Y_scaler.inverse_transform(mu)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ssm/lib/python3.11/site-packages/sklearn/preprocessing/_data.py", line 1085, in inverse_transform
X = check_array(
^^^^^^^^^^^^
File "/ssm/lib/python3.11/site-packages/sklearn/utils/validation.py", line 1043, in check_array
raise ValueError(
ValueError: Found array with dim 3. None expected <= 2.
Any help is appreciated, Thanks!
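One possible workaround, untested against this demo: sklearn scalers only accept 2-D input, so flatten mu before inverse-transforming and restore its shape afterwards. This assumes Y_scaler was fit on a single output column:

orig_shape = np.shape(mu)
mu = Y_scaler.inverse_transform(np.reshape(mu, (-1, 1))).reshape(orig_shape)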
Dear all,
I am trying to run the demo examples, but I run into the following error:
ImportError Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 import bayesnewton
2 import objax
3 import numpy as np
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in
----> 1 from . import (
2 kernels,
3 utils,
4 ops,
5 likelihoods,
6 models,
7 basemodels,
8 inference,
9 cubature
10 )
13 def build_model(model, inf, name='GPModel'):
14 return type(name, (inf, model), {})
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5, in
3 import jax.numpy as np
4 from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
----> 5 from jax.ops import index_add, index
6 from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
7 from warnings import warn
ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/__init__.py)
I think it's related to this note from the JAX changelog:
The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.
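A minimal sketch of that migration; applying the same substitution in bayesnewton/kernels.py should resolve the import error:

import jax.numpy as jnp

# old, removed API:
#   from jax.ops import index_add, index
#   x = index_add(x, index[0], 1.0)
# modern equivalent, via the functional .at property:
x = jnp.zeros(3)
x = x.at[0].add(1.0)  # returns a new array with x[0] incremented by 1.0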
Hi, thank you for sharing your great work.
I am a little confused about how to install Newt in a conda virtual environment. I would really appreciate it if you could guide me in this regard. Thank you.
I'm a little confused by equation (64). It is calculated from equation (63), but where do the other terms in equation (64) come from? For example, the denominator that comes from the covariance of p(fₙ|u).
Hi!
Many thanks for open-sourcing this package.
I've been using code from this amazing package for my research (preprint here) and have found that the default single precision (float32) is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular, the Periodic kernel is rather sensitive to the matrix operations in _sequential_kf() and _sequential_rts(), and I have also seen this with the Matern32 kernel. However, reverting to float64 by setting config.update("jax_enable_x64", True) makes everything quite slow, especially when I use objax neural network modules, because doing so puts all arrays into double precision.
Currently, my solution is to set the neural network weights to float32 manually, and to convert input arrays to float32 before they enter the network and the outputs back to float64. However, I was wondering if there could be a more elegant solution, as is done in https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.
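For reference, a sketch of the pattern described above; net stands in for any objax module kept in single precision while the rest of the pipeline runs in float64:

from jax import config
config.update("jax_enable_x64", True)  # Kalman filtering/smoothing in double precision

import jax.numpy as jnp

def apply_net_mixed(net, x):
    # cast inputs down to float32 for the network, promote outputs back to float64
    out = net(jnp.asarray(x, dtype=jnp.float32))
    return jnp.asarray(out, dtype=jnp.float64)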
Software and hardware details:
objax==1.6.0
jax==0.3.13
jaxlib==0.3.10+cuda11.cudnn805
NVIDIA-SMI 460.56 Driver Version: 460.56 CUDA Version: 11.2
GeForce RTX 3090 GPUs
Thanks in advance.
Best,
Harrison
Just as the title says: how can I understand cavity_distribution_tied in the file basemodels.py? Is there any reference I should follow up on? I also note that this code is similar to equation (64) in the Bayes-Newton paper. Where does it come from?
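For context, a sketch grounded in the general power-EP literature rather than this codebase: the standard power-EP cavity removes a fraction $\alpha$ of a single site from the approximate posterior,

$$
q_{\setminus n}(f) \;\propto\; \frac{q(f)}{t_n(f)^{\alpha}},
\qquad
\Lambda_{\setminus n} = \Lambda - \alpha\,\Lambda_n,
\qquad
\eta_{\setminus n} = \eta - \alpha\,\eta_n,
$$

where $(\eta, \Lambda)$ are the natural parameters of the Gaussian $q$ and $(\eta_n, \Lambda_n)$ those of site $n$. "Tied" presumably means the site contributions are shared (averaged) across data points rather than stored per point, in the spirit of the tied-factor sparse power EP of Bui et al. (2017), but that reading should be checked against the code.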
I'm trying to go through the demo in demos/air_quality.py, but noticed that the aq_data.csv file it reads is missing from data/.
Hello,
Thank you for open-sourcing your work! It has been incredibly useful to me ^-^
Also, is there a method to model input-dependent observation noise? (Snelson and Ghahramani, 2006)
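A sketch of what that might look like, based on the demos/heteroscedastic.py demo referenced in another issue here; the class names Independent and HeteroscedasticNoise are assumptions about this library's API, so check them against the demo:

import bayesnewton

# two latent GPs: one for the mean function, one for the (input-dependent) noise
kern_mean = bayesnewton.kernels.Matern32(variance=1.0, lengthscale=1.0)
kern_noise = bayesnewton.kernels.Matern32(variance=1.0, lengthscale=1.0)
kern = bayesnewton.kernels.Independent([kern_mean, kern_noise])   # name assumed
lik = bayesnewton.likelihoods.HeteroscedasticNoise()              # name assumed
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=X, Y=Y)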
When I run the demo file demos/regression.py with SparseVariationalGP, something goes wrong:
AssertionError: Assignments to variable must be an instance of JaxArray, but received f<class 'numpy.ndarray'>.
It seems there is a mistake in SparseVariationalGP. Can you help fix the problem? Thank you very much!
As the title says, there is an error after running heteroscedastic.py with the MarkovPosteriorLinearisationQuasiNewtonGP method:
File "heteroscedastic.py", line 101, in train_op
model.inference(lr=lr_newton, damping=damping)  # perform inference and update variational params
File "/BayesNewton-main/bayesnewton/inference.py", line 871, in inference
mean, jacobian, hessian, quasi_newton_state = self.update_variational_params(batch_ind, lr, **kwargs)
File "/BayesNewton-main/bayesnewton/inference.py", line 1076, in update_variational_params
jacobian_var = transpose(solve(omega, dmu_dv)) @ residual
ValueError: The arguments to solve must have shapes a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a=(117, 1, 1) and b=(117, 2, 2)
Can you tell me what is wrong? Thanks in advance.
Hi, @wil-j-wil
I randomly stumbled upon your work when researching temporal GPs and found this cool package. (Thanks for all the awesome work behind it!)
I am running into the following issue when using your package
For example, on the airline passenger dataset using a fairly default setup:
kernel = QuasiPeriodicMatern32()
likelihood = Gaussian()
model = MarkovVariationalGP(kernel=kernel, likelihood=likelihood, X=X_train, Y=y_train)
mean_in, var_in = model.predict(X=X_train[-1]) # this is inside X_train, E[f] = "perfetto"
mean_out, var_out = model.predict(X=X_test[0]) # this is outside X_train, E[f] = nan
I also get NaN for the first observation of X_train (see the plots, not included here). The same happens with kernel = Matern12().
Am I perhaps misunderstanding the purpose of the model, or doing something wrong? (Thanks in advance for your help! I am just a beginner GP enthusiast looking into what these models are capable of.)
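One thing worth trying first, since another issue in this list reports NaNs in the Kalman filtering ops under single precision: enable double precision before building the model.

from jax import config
config.update("jax_enable_x64", True)  # must run before any JAX arrays are created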
P.S. I installed bayesnewton (hopefully) according to requirements:
[tool.poetry.dependencies]
python = "3.10.11"
tensorflow-macos = "2.13.0"
tensorflow-probability = "^0.21.0"
numba = "^0.58.0"
gpflow = "^2.9.0"
scipy = "^1.11.2"
pandas = "^2.1.1"
jax = "0.4.2"
jaxlib = "0.4.2"
objax = "1.6.0"
ipykernel = "^6.25.2"
plotly = "^5.17.0"
seaborn = "^0.12.2"
nbformat = "^5.9.2"
scikit-learn = "^1.3.1"
convertbng = "^0.6.43"
The periodic kernel learns the wrong period for my function, even though the prediction itself shows the correct period. I've attached my sample code at the end for reproducibility.
In the end I want to train my state space model with BayesNewton and use the resulting state space in my own models.
import bayesnewton
import objax
import numpy as np
import matplotlib.pyplot as plt
import time
def periodicFunction(t, p):
    noise_var = 0.01
    return np.sin(2*np.pi/p * t) + np.sqrt(noise_var) * np.random.normal(0, 1, t.shape)
x = np.arange(0., 4., 0.05)
y = periodicFunction(x, 1.)
var_f = 1.
var_y = 1.
period = 3.
lengthscale_f = 1.
kern = bayesnewton.kernels.Periodic(variance=var_f, period=period, lengthscale=lengthscale_f)
lik = bayesnewton.likelihoods.Gaussian(variance=var_y)
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)
lr_adam = 0.1
lr_newton = 1
iters = 200
opt_hypers = objax.optimizer.Adam(model.vars())
energy = objax.GradValues(model.energy, model.vars())
@objax.Function.with_vars(model.vars() + opt_hypers.vars())
def train_op():
    model.inference(lr=lr_newton)  # perform inference and update variational params
    dE, E = energy()  # compute energy and its gradients w.r.t. hypers
    opt_hypers(lr_adam, dE)
    return E
train_op = objax.Jit(train_op)
t0 = time.time()
for i in range(1, iters + 1):
    loss = train_op()
    print('iter %2d, energy: %1.4f' % (i, loss[0]))
t1 = time.time()
print('optimisation time: %2.2f secs' % (t1-t0))
y_predict, variance = model.predict_y(x)
print(model.kernel.period)
fig, ax = plt.subplots()
ax.plot(x, y, label="True", linewidth=2.0)
ax.plot(x, y_predict, label="Predict", linewidth=2.0)
ax.set(xlim=(x[0], x[-1]))
plt.show()
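A possible explanation, offered as a guess rather than a diagnosis: the true period here is 1.0 and the kernel is initialised with period=3.0, and any signal with period 1 is also periodic with period 3, so the initial value already explains the data and the optimiser has little incentive to move it. Initialising closer to, and not at an integer multiple of, the true period may help, e.g.:

kern = bayesnewton.kernels.Periodic(variance=var_f, period=1.2, lengthscale=lengthscale_f)  # init away from multiples of the true period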
This is not an issue per se, but I was wondering if there was a specific reason that there wasn't a Squared Exponential kernel as part of the package?
If applicable, I would be happy to submit a PR adding one.
Just let me know.
It is recommended to use the following versions of jax and objax:
jax==0.2.9
jaxlib==0.1.60
objax==1.3.1
This is because of this objax issue, which causes the model to JIT compile "twice", i.e. on the first two iterations rather than just the first. This causes a bit of a slowdown for large models, but is not a problem otherwise.
Hi,
How can I speed up training on the GPU for models such as VariationalGP?
I've been using markovflow.SpatioTemporalSparseCVI, which implements S2CVI from this paper, for some modeling, and I'm now experimenting with BayesNewton. Is SparseMarkovVariationalGP with a SpatioTemporalKernel that has inducing points essentially equivalent to S2CVI?
Also, can the current implementations handle 2 spatial dimensions as well as a temporal dimension?
I'm using the MarkovVariationalGP model to compute the energy function. When I JIT-compile the energy function, the returned values are not the same as those from the non-compiled function. Sometimes, I get a NaN as output.
My minimal example is the following:
import bayesnewton
import objax
import numpy as np

x = np.arange(0., 4., 0.05)
y = x**2
var_f = 1.
var_y = 1.
lengthscale_f = 1.
kern = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=lengthscale_f)
lik = bayesnewton.likelihoods.Gaussian(variance=var_y)
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)
energy = objax.GradValues(model.energy, model.vars())
energy_Jit = objax.Jit(energy, model.vars())
dE, E = energy()
dE_comp, E_comp = energy_Jit()
print(f"Energy={E}, Gradient of Energy={dE}")
print(f"EnergyComp={E_comp}, Gradient of EnergyComp={dE_comp}")
I dug further, and the NaN is the result of gaussian_expected_log_lik in utils. The masks used here are not the same in the compiled and uncompiled versions. The mask in question is mask_pseudo_y from the BaseModel. In particular, mask and maskv can take different values, which results in noise=0. This produces a NaN in mvn_logpdf.
I've included the function here for easier access:
def gaussian_expected_log_lik(Y, q_mu, q_covar, noise, mask=None):
    """
    :param Y: N x 1
    :param q_mu: N x 1
    :param q_covar: N x N
    :param noise: N x N
    :param mask: N x 1
    :return:
        E[log 𝓝(yₙ|fₙ,σ²)] = ∫ log 𝓝(yₙ|fₙ,σ²) 𝓝(fₙ|mₙ,vₙ) dfₙ
    """
    if mask is not None:
        # build a mask for computing the log likelihood of a partially observed multivariate Gaussian
        maskv = mask.reshape(-1, 1)
        jax.debug.print("mask={mask}, maskv={maskv}", mask=mask, maskv=maskv)
        q_mu = np.where(maskv, Y, q_mu)
        noise = np.where(maskv + maskv.T, 0., noise)  # ensure masked entries are independent
        noise = np.where(np.diag(mask), INV2PI, noise)  # ensure masked entries return log like of 0
        q_covar = np.where(maskv + maskv.T, 0., q_covar)  # ensure masked entries are independent
        q_covar = np.where(np.diag(mask), 1e-20, q_covar)  # ensure masked entries return trace term of 0
    ml = mvn_logpdf(Y, q_mu, noise)
    trace_term = -0.5 * np.trace(solve(noise, q_covar))
    return ml + trace_term
In my opinion, the compiled and uncompiled versions should output the same value.