
cola's Introduction

Compositional Linear Algebra (CoLA)


CoLA is a framework for scalable linear algebra that automatically exploits the structure often found in machine learning problems and beyond. CoLA natively supports PyTorch and JAX, as well as (limited) NumPy if JAX is not installed.

Installation

pip install cola-ml

Features in CoLA

  • Large-scale linear algebra routines for solve(A,b), eig(A), logdet(A), exp(A), trace(A), diag(A), sqrt(A).
  • Provides (user-extendible) compositional rules to exploit structure through multiple dispatch.
  • Has memory-efficient autodiff rules for iterative algorithms.
  • Works with PyTorch or JAX, supporting GPU hardware acceleration.
  • Supports operators with complex numbers and low precision.
  • Provides linear algebra operations for both symmetric and non-symmetric matrices.

See https://cola.readthedocs.io/en/latest/ for our full documentation and many examples.

Quick start guide

  1. LinearOperators. The core object in CoLA is the LinearOperator. You can add and subtract them (+, -), multiply them by constants (*, /), matrix-multiply them (@), and combine them in other ways: kron, kronsum, block_diag, etc.
import jax.numpy as jnp
import cola

A = cola.ops.Diagonal(jnp.arange(5) + .1)
B = cola.ops.Dense(jnp.array([[2., 1.], [-2., 1.1], [.01, .2]]))
C = B.T @ B
D = C + 0.01 * cola.ops.I_like(C)
E = cola.ops.Kronecker(A, cola.ops.Dense(jnp.ones((2, 2))))
F = cola.ops.BlockDiag(E, D)

v = jnp.ones(F.shape[-1])
print(F @ v)
[0.2       0.2       2.2       2.2       4.2       4.2       6.2
 6.2       8.2       8.2       7.8       2.1    ]
  2. Performing Linear Algebra. With these objects we can perform linear algebra operations even when they are very large.
print(cola.linalg.trace(F))
Q = F.T @ F + 1e-3 * cola.ops.I_like(F)
b = cola.linalg.inv(Q) @ v
print(jnp.linalg.norm(Q @ b - v))
print(cola.linalg.eig(F, k=F.shape[0])[0][:5])
print(cola.linalg.sqrt(A))
31.2701
0.0010193728
[ 2.0000000e-01+0.j  0.0000000e+00+0.j  2.1999998e+00+0.j
 -1.1920929e-07+0.j  4.1999998e+00+0.j]
diag([0.31622776 1.0488088  1.4491377  1.7606816  2.0248456 ])

For many of these functions, if we know additional information about the matrices, we can annotate them so that the algorithms run faster.

Qs = cola.SelfAdjoint(Q)
%timeit cola.linalg.inv(Q) @ v
%timeit cola.linalg.inv(Qs) @ v
  3. JAX and PyTorch. We support both ML frameworks.
import torch
A = cola.ops.Dense(torch.Tensor([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))

import jax.numpy as jnp
A = cola.ops.Dense(jnp.array([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))
tensor(25.)
25.0

CoLA also supports autograd (and jit):

from jax import grad, jit, vmap


def myloss(x):
    A = cola.ops.Dense(jnp.array([[1., 2.], [3., x]]))
    return jnp.ones(2) @ cola.linalg.inv(A) @ jnp.ones(2)


g = jit(vmap(grad(myloss)))(jnp.array([.5, 10.]))
print(g)
[-0.06611571 -0.12499995]

Citing us

If you use CoLA, please cite the following paper:

Andres Potapczynski, Marc Finzi, Geoff Pleiss, and Andrew Gordon Wilson. "CoLA: Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra." arXiv preprint arXiv:2309.03060, 2023.

@article{potapczynski2023cola,
  title={{CoLA: Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra}},
  author={Andres Potapczynski and Marc Finzi and Geoff Pleiss and Andrew Gordon Wilson},
  journal={arXiv preprint arXiv:2309.03060},
  year={2023}
}

Features implemented

  • Linear Algebra: inverse, eig, diag, trace, logdet, exp, sqrt, f(A), SVD, pseudoinverse
  • LinearOperators: Diag, BlockDiag, Kronecker, KronSum, Sparse, Jacobian, Hessian, Fisher, Concatenated, Triangular, FFT, Tridiagonal
  • Annotations: SelfAdjoint, PSD, Unitary
  • Backends: PyTorch, JAX, NumPy (most operations)

Contributing

See the contributing guidelines in docs/CONTRIBUTING.md for information on submitting issues and pull requests.

CoLA is Apache 2.0 licensed.

Support and contact

Please raise an issue if you find a bug or slow performance when using CoLA.

cola's People

Contributors

andpotap, andrewgordonwilson, eltociear, fr0do, gpleiss, mfinzi, pitmonticone, raulpl


cola's Issues

Worksheet GP example fails with ValueError

Complete code:

!wget -O bike.mat "https://www.andpotap.com/static/bike.mat"

from jax import numpy as jnp
import os
import numpy as np
from math import floor
from scipy.io import loadmat
import cola


def load_uci_data(data_dir, dataset, train_p=0.75, test_p=0.15):
    file_path = os.path.join(data_dir, dataset + '.mat')
    data = np.array(loadmat(file_path)['data'])
    X = data[:, :-1]
    y = data[:, -1]

    X = X - X.min(0)[None]
    X = 2.0 * (X / X.max(0)[None]) - 1.0
    y -= y.mean()
    y /= y.std()

    train_n = int(floor(train_p * X.shape[0]))
    valid_n = int(floor((1. - train_p - test_p) * X.shape[0]))

    split = split_dataset(X, y, train_n, valid_n)
    train_x, train_y, valid_x, valid_y, test_x, test_y = split

    return train_x, train_y, test_x, test_y, valid_x, valid_y


def split_dataset(x, y, train_n, valid_n):
    train_x = x[:train_n, :]
    train_y = y[:train_n]

    valid_x = x[train_n:train_n + valid_n, :]
    valid_y = y[train_n:train_n + valid_n]

    test_x = x[train_n + valid_n:, :]
    test_y = y[train_n + valid_n:]
    return train_x, train_y, valid_x, valid_y, test_x, test_y


train_x, train_y, *_, test_x, test_y = load_uci_data(data_dir="./", dataset="bike")

dtype = jnp.float32
train_x, train_y = jnp.array(train_x, dtype=dtype), jnp.array(train_y, dtype=dtype)
test_x, test_y = jnp.array(test_x, dtype=dtype), jnp.array(test_y, dtype=dtype)

train_x, train_y = train_x[:1000], train_y[:1000]

def compute_rbf_cov(xi, xj):
    xi, xj = jnp.expand_dims(xi, -2), jnp.expand_dims(xj, -3)
    res = jnp.exp(jnp.sum((xi - xj)**2, axis=-1))
    return res

ls = jnp.array(100., dtype=dtype)
noise = jnp.array(1., dtype=dtype)
oscale = jnp.array(1., dtype=dtype)
K_train_train = cola.ops.Dense(oscale * compute_rbf_cov(train_x / ls, train_x / ls))
K_test_train = cola.ops.Dense(oscale * compute_rbf_cov(test_x / ls, train_x / ls))
K_test_test = cola.ops.Dense(oscale * compute_rbf_cov(test_x / ls, test_x / ls))
K = cola.ops.PSD(K_train_train + noise * cola.ops.I_like(K_train_train))
inverse = cola.linalg.inv  # assumption: the original snippet relies on an `inverse` alias that is not shown
mu = K_test_train @ inverse(K) @ train_y
Sigma = K_test_test - K_test_train @ inverse(K) @ K_test_train.T

The matrix difference operation (-) in the last line is what fails.

Error logs

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[71], line 10
      8 mu = K_test_train @ inverse(K) @ train_y
      9 A = K_test_train @ inverse(K) @ K_test_train.T
---> 10 Sigma = K_test_test + A

File ~/cola/cola/ops/operator_base.py:119, in LinearOperator.__add__(self, other)
    118 def __add__(self, other):
--> 119     if other == 0:
    120         return self
    121     return cola.fns.add(self, other)

File /.conda/envs/cola/lib/python3.10/site-packages/jax/_src/array.py:257, in ArrayImpl.__bool__(self)
    256 def __bool__(self):
--> 257   return bool(self._value)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Note that it works if the slice of 1000 samples is changed to a slice of 2000 samples.
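A minimal sketch of the kind of fix this suggests (illustrative only, not the maintainers' actual patch): only take the other == 0 shortcut for plain Python scalars, so array-valued operands never hit an ambiguous truth test.

# Illustrative patch to the __add__ shown in the traceback; not CoLA's actual fix.
def __add__(self, other):
    # Only short-circuit for scalar zeros; JAX/torch arrays go through cola.fns.add,
    # avoiding the ambiguous bool(array) raised above.
    if isinstance(other, (int, float)) and other == 0:
        return self
    return cola.fns.add(self, other)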

[Feature Request] Low Rank Dispatch rules (Woodbury identity, Trace rules, etc)

🚀 Feature Request

Dispatch rules for Product[Dense,Dense] or Sum[Product[Dense,Dense], Diagonal].

Examples:

Woodbury identity

Let the Woodbury matrix identity be given by:
$(D + UV)^{-1} = D^{-1} - D^{-1}U(I + VD^{-1}U)^{-1}VD^{-1}$
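To make the payoff concrete, here is a minimal NumPy sketch (illustrative names, not a proposed CoLA API) that applies the identity to solve (D + UV)x = b while only ever factorizing a k x k matrix:

import numpy as np

def woodbury_solve(d, U, V, b):
    """Solve (D + U V) x = b for D = diag(d), U of shape (n, k), V of shape (k, n), k << n."""
    Dinv_b = b / d                              # D^{-1} b
    Dinv_U = U / d[:, None]                     # D^{-1} U
    small = np.eye(U.shape[1]) + V @ Dinv_U     # I + V D^{-1} U, only (k, k)
    return Dinv_b - Dinv_U @ np.linalg.solve(small, V @ Dinv_b)

# quick check against the dense solve
n, k = 500, 3
rng = np.random.default_rng(0)
d = rng.uniform(1.0, 2.0, n)
U, V, b = rng.normal(size=(n, k)), 0.1 * rng.normal(size=(k, n)), rng.normal(size=n)
assert np.allclose((np.diag(d) + U @ V) @ woodbury_solve(d, U, V, b), b)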

Cyclic trace property:

Given the cyclic trace property:
$\text{Tr}(UV) = \text{Tr}(VU)$
The idea is that for a generic Product[LinearOperator,LinearOperator] where $U$ and $V$ are not square, we can rearrange to reduce the dimensionality. If dense, we can further accelerate by performing the elementwise multiplication of $U$ and $V$ summing only over one axis.
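And a minimal NumPy sketch of the trace rearrangement (again illustrative, not a CoLA dispatch rule): for U of shape (n, k) and V of shape (k, n) with k << n, tr(UV) never needs the n x n product.

import numpy as np

n, k = 1000, 5
rng = np.random.default_rng(0)
U, V = rng.normal(size=(n, k)), rng.normal(size=(k, n))

t_naive = np.trace(U @ V)    # materializes an (n, n) matrix
t_cyclic = np.trace(V @ U)   # only a (k, k) matrix
t_elem = np.sum(U * V.T)     # elementwise product summed over both axes, no matmul at all
assert np.allclose(t_naive, t_cyclic) and np.allclose(t_naive, t_elem)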

Pitch

Introduce rules such as:

@dispatch
def inv(A: Sum[Product[Dense,Dense], Diagonal], **kwargs):
    ...

@dispatch(cond=product_faster_if_rearranged)
def trace(A: Product):
    ...

Additional context

Plum-dispatch can work a little differently than one would expect for parametric types.
Some things need to be spelled out more explicitly (and changes may possibly even need to be made to cola-plum-dispatch).

[Bug] Logdet

🐛 Bug

Issue with log determinant jit compilation on large matrices (> 1e6). Perhaps an issue with the iterative method, which I believe is triggered above 1e6.

I worked around this issue by specifying the method="dense" kwarg and seem to have no issues there.

To reproduce

# Jit compiling this function and giving an input larger than the ~1e6 threshold fails
jit(lambda sigma: cola.logdet(sigma))(input_matrix_here)
# Here sigma is a Sum LinearOperator of a Dense LinOp and a Diagonal array.
# This may be an issue with SumLinearOperators.
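The workaround mentioned above, as a minimal sketch (assuming the same cola.logdet call as in the snippet): force the dense path so the iterative branch that fails under jit is never reached.

import cola
from jax import jit

# sigma is the Sum LinearOperator from the snippet above
logdet_dense = jit(lambda sigma: cola.logdet(sigma, method="dense"))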

https://github.com/JaxGaussianProcesses/GPJax/actions/runs/6099328110/job/16551211143?pr=370

--> 189     + cola.logdet(sigma)
    190     + diff.T @ cola.solve(sigma, diff)
    191 )

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/logdet.py:39, in logdet(A, **kwargs)
     17 @export
     18 def logdet(A: LinearOperator, **kwargs):
     19     r""" Computes logdet of a linear operator.
     20
     21     For large inputs (or with method='iterative'),
    (...)
     37         Array: logdet
     38     """
---> 39     _, ld = slogdet(A, **kwargs)
     40     return ld

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/plum/function.py:438, in Function.__call__(self, *args, **kw_args)
    436     method, return_type, loginfo = self.resolve_method(args, types)
    437 logging.info("%s", loginfo)
--> 438 return _convert(method(*args, **kw_args), return_type)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/logdet.py:96, in slogdet(A, **kwargs)
     93 elif 'exact' in method or not stochastic_faster:
     94     # TODO: explicit autograd rule for this case?
     95     logA = cola.linalg.log(A, tol=tol, method='iterative', **kws)
---> 96     trlogA = cola.linalg.trace(logA, method='exact', **kws)
     97 else:
     98     raise ValueError(f"Unknown method {method} or CoLA didn't fit any selection criteria")

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/plum/function.py:438, in Function.__call__(self, *args, **kw_args)
    436     method, return_type, loginfo = self.resolve_method(args, types)
    437 logging.info("%s", loginfo)
--> 438 return _convert(method(*args, **kw_args), return_type)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/diag_trace.py:137, in trace(A, **kwargs)
    117 r""" Compute the trace of a linear operator tr(A).
    118
    119 Uses either :math:`O(\tfrac{1}{\delta^2})` time stochastic estimation (Hutchinson estimator)
    (...)
    134 Returns:
    135     Array: trace"""
    136 assert A.shape[0] == A.shape[1], "Can't trace non square matrix"

(...)

--> 723   return getattr(self.aval, f"_{name}")(self, *args)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4153, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4150       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
   4152 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 4153 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4154                unique_indices, mode, fill_value)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4162, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   4159 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4160             unique_indices, mode, fill_value):
   4161   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 4162   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
   4163   y = arr
   4165   if fill_value is not None:

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4414, in _index_to_gather(x_shape, idx, normalize_indices)
   4405 if not all(_is_slice_element_none_or_constant(elt)
   4406            for elt in (start, stop, step)):
   4407   msg = ("Array slice indices must have static start/stop/step to be used "
   4408          "with NumPy indexing syntax. "
   4409          f"Found slice({start}, {stop}, {step}). "
    (...)
   4412          "dynamic_update_slice (JAX does not support dynamically sized "
   4413          "arrays within JIT compiled functions).")
-> 4414   raise IndexError(msg)
   4415 if not core.is_constant_dim(x_shape[x_axis]):
   4416   msg = ("Cannot use NumPy slice indexing on an array dimension whose "
   4417          f"size is not statically known ({x_shape[x_axis]}). "
   4418          "Try using lax.dynamic_slice/dynamic_update_slice")

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

Error: Process completed with exit code 1.

System information

Please complete the following information:

  • Latest PyPI version of cola (for the above traceback).
  • The same issue also occurs locally on my machine with the latest commits on the main branch.


Incorrect use of __init__ inside the unflatten rule of CoLA custom pytrees causes issues [JAX]

When attempting to incorporate a CoLA LinearOperator as a field in Equinox (https://github.com/patrick-kidger/equinox), as shown in the MVE below, I receive AttributeError: 'bool' object has no attribute 'dtype'.

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import jax.random as jr
import optax
import equinox as eqx
import cola

class MyModule(eqx.Module):

    lazy_A : cola.ops.LinearOperator

    def __init__(self, A):
        self.lazy_A = cola.fns.lazify(A)

    def __call__(self, x):
        return self.lazy_A @ x

seed = jr.PRNGKey(0)
A = jr.normal(seed, (10, 10))
X = jnp.ones((10, 1))
model = MyModule(A)
result = eqx.filter(model, eqx.is_inexact_array)

  File "/media/adam/shared_drive/PycharmProjects/test_equinox_lazy_variable/test_equinox_lazy_variable.py", line 24, in <module>
    result = eqx.filter(model, eqx.is_inexact_array)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 129, in filter
    filter_tree = jtu.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 71, in _filter_tree
    return jtu.tree_map(mask, arg, is_leaf=is_leaf)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 245, in tree_unflatten
    return cls(*new_args, **aux[0])
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 21, in __init__
    super().__init__(dtype=A.dtype, shape=A.shape)
AttributeError: 'bool' object has no attribute 'dtype'

As per patrick-kidger/equinox#453 (comment), the issue appears to be a symptom of a wider bug in CoLA tree unflattening code.

The issue is that they're using __init__ inside the unflatten rule of their pytrees:

return cls(*new_args, **aux[0])

This is a common mistake when implementing custom pytrees in JAX; see the documentation here: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization

They could fix this by either (a) having their LinearOperator inherit from eqx.Module, or (b) unflattening via __new__ and setattr; see the Equinox implementation for reference.
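A minimal sketch of option (b); the aux layout and attribute names here are illustrative, not CoLA's actual flatten format.

# Illustrative unflatten rule that avoids calling __init__ (option (b) above).
@classmethod
def tree_unflatten(cls, aux, children):
    op = object.__new__(cls)                  # bypass __init__ entirely
    static_attrs, leaf_names = aux            # hypothetical aux layout
    for name, leaf in zip(leaf_names, children):
        object.__setattr__(op, name, leaf)    # restore dynamic (possibly traced) leaves
    for name, value in static_attrs.items():
        object.__setattr__(op, name, value)   # restore static metadata (shape, dtype, ...)
    return op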

Adjoint operations move Jacobian from GPU to CPU

🐛 Bug

The adjoint operations in CoLA are moving the Jacobian tensor from the GPU to the CPU, which can lead to performance issues and inconsistencies.

To reproduce

** Code snippet to reproduce **

import torch
import cola

dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

x = torch.randn(100).to(dev)
fn = torch.nn.Sequential(torch.nn.Linear(100, 64), torch.nn.Linear(64, 100)).to(dev)

J = cola.ops.Jacobian(fn, x)
print(J.device, J.T.device, J.H.device, cola.ops.Adjoint(J).device)

** Stack trace/error message **

cuda:0 cpu cpu cpu

Expected Behavior

Output should look like:

cuda:0 cuda:0 cuda:0 cuda:0

System information

Please complete the following information:

  • CoLA version: 0.0.6.dev11+gf3c5494
  • PyTorch version: 2.1.2
  • OS: Springdale Open Enterprise Linux 8.6 (Modena)

Additional context

Possibly an issue here: https://github.com/wilson-labs/cola/blob/main/cola/ops/operators.py#L361, where the device is not being used.

Possible speedup for linalg.solve

Looking at the code, linalg.solve is implemented by explicitly calculating the inverse times a vector. Is there a reason not to use something similar to scipy.linalg.solve to achieve this?

import cola
import numpy as np
import scipy as sp
import timeit

A = np.random.randn(2,2)
A = A @ A.T
v = np.random.randn(2)

cola_A = cola.ops.Dense(A)

print(timeit.timeit(lambda: cola.linalg.solve(cola_A, v), number=10000))
print(timeit.timeit(lambda: sp.linalg.solve(A, v), number=10000))

the outputs are

14.14389366703108
0.0880865000654012

I could also try to submit a pull request later. Thanks
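For reference, a minimal sketch of the kind of dense fast path being suggested (plain SciPy; how this would be wired into cola.linalg.solve via dispatch is left out):

import scipy.linalg as sla

def dense_solve(A, v):
    # For a small/medium Dense operator, materialize it and use a direct
    # LAPACK-backed solve instead of the generic inverse-times-vector path.
    return sla.solve(A.to_dense(), v)

# e.g. dense_solve(cola_A, v) with cola_A and v as defined above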

[Bug] AttributeError: 'DenseLinearOperator' object has no attribute 'astype' when applying cholesky

I have just tried to upgrade from commit 74406c9 to the latest on the main branch, and I am now presented with the following error. I am using the lower_cholesky function provided by GPJax (https://github.com/JaxGaussianProcesses/GPJax/blob/main/gpjax/lower_cholesky.py), and K is of type
<15x15 Sum[cola.ops.operators.Dense, cola.ops.operators.Product[cola.ops.operators.ScalarMul, cola.ops.operators.Identity]] with dtype=float64>

    l_zz = lower_cholesky(K)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/plum/function.py", line 438, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/media/adam/shared_drive/PycharmProjects/Process_Shape_Datasets/lower_cholesky.py", line 38, in lower_cholesky
    return cola.ops.Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 184, in to_dense
    return self @ self.xnp.eye(self.shape[-1], self.shape[-1], dtype=self.dtype, device=self.device)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 211, in __matmul__
    return self._matmat(X)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 159, in _matmat
    return sum(M @ v for M in self.Ms)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 159, in <genexpr>
    return sum(M @ v for M in self.Ms)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 211, in __matmul__
    return self._matmat(X)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 25, in _matmat
    return self.xnp.cast(self.A, dtype) @ self.xnp.cast(X, dtype)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py", line 168, in cast
    return array.astype(dtype)
AttributeError: 'DenseLinearOperator' object has no attribute 'astype'. Did you mean: 'dtype'?

Normally I would provide an MVE, but I am not sure how best to test for this. Let me know how best to assist if required.

Link to Paper?

Hey all,

The README.md references a pre-print paper, but the hyperlink provided seems to just point back to the main GitHub repo rather than link to the actual paper. Is this pre-print available somewhere online to read?

Thanks!

Improving CI

A few ideas:

Minor:

  • Smoke test the example notebooks. I.e., run the example notebooks, fail if any of the cells cause an error. This type of testing doesn't check the output of the notebook cells, but it's better than nothing

Major:

  • Have a matrix strategy that runs the test suite for different base-engine installs. Rather than installing both the stable versions of Jax and PyTorch, you should run CI ~4x times:
    • When only Jax is installed (stable version)
    • When only Jax is installed (unstable/latest version)
    • When only PyTorch is installed (stable version)
    • When only PyTorch is installed (unstable/latest version)
      This grid testing is easy to accomplish with the strategy: matrix: options in the workflow yaml. See the gpytorch CI for an example.

[Feature Request] Efficient Computation of Jacobian products JJ^T

🚀 Feature Request

I recently came across a neat trick to compute Jacobian matrix products of the form JJ^T using VJPs and JVPs. I think this would make for a good addition as a utility in the library; a very obvious but probably useful function to have. Let me know what you think!

Motivation

Is your feature request related to a problem? Please describe.

The matrix product JJ^T comes up in probit classification, where the posterior predictive integral of the usual softmax classification is approximated via a linearization of the pre-softmax function. It also comes up in Function-Space VI. I'm sure there are other places it comes up.

The key idea: for a matrix J of size M x N (M outputs, N inputs/parameters) such that M is much smaller than N, it is useful to construct the symmetric product JJ^T column-by-column: first compute a VJP J^T v, then a JVP J(J^T v). We never materialize the larger M x N matrix, only the smaller M x M matrix, by taking the vectors v to be one-hot vectors of size M.

PyTorch's vmap also allows chunking operations, which might be helpful here to compute this block-by-block and avoid OOM constraints. Unfortunately, I could not find a chunking option for jax.vmap, but perhaps that can be implemented manually.

Pitch

Describe the solution you'd like

I'd like a utility function to be exposed from cola, something to the effect of:

import cola

model = ...    # a PyTorch model
f_model = ...  # a torch.func version of the model that handles a single input and a single output

cola.jjt(f_model)

which computes JJ^T. There are a few design considerations here, such as how one handles batches. Batching should perhaps not be within cola's scope, but rather the user's responsibility.

Describe alternatives you've considered

Here's a reference toy implementation in PyTorch to give you a sense of what would need to be implemented.

import torch
import torch.nn as nn
import torch.func as tf

model = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 5))
f_model = lambda p, *args, **kwargs: tf.functional_call(model, p, *args, **kwargs)

inputs = torch.randn(100, 2)
params = dict(model.named_parameters())
outputs_size = torch.Size([5])  ## Only handles 1-D outputs

## Naive computation using tf.jacrev/tf.jacfwd that constructs full Jacobian. vmap over batch size.
j_fn_naive = tf.vmap(
    lambda J_p: torch.cat([v.view(*outputs_size, -1) for _, v in J_p.items()], dim=-1)
)
jjt_fn_naive = tf.vmap(lambda M: M @ M.T)
jjt_naive = jjt_fn_naive(j_fn_naive(tf.jacrev(f_model)(params, inputs)))

## Smarter computation. Outer vmap over batch size, inner vmap over one-hot co-tangent vectors.
jjt_fn = tf.vmap(
    lambda inp: tf.vmap(
        lambda v: tf.jvp(
            lambda p: f_model(p, inp.unsqueeze(0)).squeeze(0),
            (params,),
            tf.vjp(lambda p: f_model(p, inp.unsqueeze(0)).squeeze(0), params)[1](v),
        )[1]
    )(torch.eye(*outputs_size))
)
jjt = jjt_fn(inputs)

print(jjt.shape, jjt_naive.shape)
assert torch.allclose(jjt, jjt_naive)
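For the JAX side, here is a minimal sketch of the same VJP-then-JVP trick (names are illustrative; f maps (params, x) to a length-M output):

import jax
import jax.numpy as jnp

def jjt(f, params, x):
    """Compute J J^T for J = d f(params, x) / d params, without materializing J."""
    g = lambda p: f(p, x)                 # close over the single input x
    out, vjp_fn = jax.vjp(g, params)      # out has shape (M,)

    def jjt_row(v):
        (jt_v,) = vjp_fn(v)                          # J^T v, a pytree shaped like params
        _, j_jt_v = jax.jvp(g, (params,), (jt_v,))   # J (J^T v), shape (M,)
        return j_jt_v

    return jax.vmap(jjt_row)(jnp.eye(out.shape[0]))  # map over one-hot cotangents

The vmap here plays the role of the inner vmap in the PyTorch snippet; if memory becomes a concern, the loop over one-hot vectors could be chunked manually (e.g. with jax.lax.map) instead.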

Are you willing to open a pull request? (We LOVE contributions!!!)

Sure! Once we have some design decisions figured out.

Add releases

It would be good to start adding alpha/beta releases so that it is easier to pin CoLA as a dependency (instead of a git commit hash).

[Feature Request] versioning + PyPI pip wheel

🚀 Feature Request

Great work on this library! :)

A couple of requests from the GPJax developers:

(1) PyPI pip wheel for installation.
(2) Using git tags & creating versioned releases on GitHub, with corresponding releases made to PyPI.

Versioning would particularly make it easier to monitor incoming developmental changes.

[Bug] Concatenated op axes aren't correct

I think there is a bug in the Concatenated Op, where the wrong axes are being processed.

e.g. A of size (100,2), B of size (100,1)

In CoLA, if you use the Concatenated op via

C = cola.ops.Concatenated(A, B, axis=0)

C will be of size (200,2)...which isn't correct.

and

C = cola.ops.Concatenated(A, B, axis=1)

will give an AssertionError: Trying to concatenate matrices of different sizes [(100, 2), (100, 1)]
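For comparison, a quick NumPy sketch of the behaviour one would expect from those axis arguments:

import numpy as np

A, B = np.ones((100, 2)), np.ones((100, 1))

print(np.concatenate([A, B], axis=1).shape)  # (100, 3): columns stacked side by side
# np.concatenate([A, B], axis=0) raises ValueError: the sizes along the
# non-concatenated axis (2 vs 1 columns) must match.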

Add doc to run tests

It would be good to have a small section on how to run the tests. From the workflow it looks like it is simply:

pytest

[Bug] CoLA custom autograd rules do not work with new functorch interface in pytorch.

🐛 Bug

CoLA custom autograd rules do not work with new functorch interface in pytorch.

To reproduce

** Code snippet to reproduce **

import torch
import cola
f = lambda x: x**2

device = torch.device('cuda:0')
def logdet(theta):
    jac = cola.ops.Jacobian(f, theta.to(device))
    D = jac@jac.T
    D = cola.ops.Dense(D.to_dense())
    emax = cola.linalg.eigmax(D.to(torch.device('cpu')),tol=1e-2)
    D = cola.PSD((D+1e-3*emax*cola.ops.I_like(D)))
    logdetJ = cola.linalg.logdet(D, method='iterative-stochastic')
    return logdetJ

theta = torch.randn(10, requires_grad=True)
grad_log_det = torch.func.grad(logdet)(theta)

** Stack trace/error message **
RuntimeError Traceback (most recent call last)
Cell In[3], line 16
13 return logdetJ
15 theta = torch.randn(10, requires_grad=True)
---> 16 grad_log_det = torch.func.grad(logdet)(theta)

File ~/anaconda3/envs/cola/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py:1380, in grad..wrapper(*args, **kwargs)
1378 @wraps(func)
1379 def wrapper(*args, **kwargs):
-> 1380 results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
1381 if has_aux:
1382 grad, (_, aux) = results

File ~/anaconda3/envs/cola/lib/python3.10/site-packages/torch/_functorch/vmap.py:39, in doesnt_support_saved_tensors_hooks..fn(*args, **kwargs)
36 @functools.wraps(f)
37 def fn(*args, **kwargs):
38 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39 return f(*args, **kwargs)

File ~/anaconda3/envs/cola/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py:1245, in grad_and_value..wrapper(*args, **kwargs)
1242 diff_args = slice_argnums(args, argnums, as_tuple=False)
1243 tree_map
(partial(_create_differentiable, level=level), diff_args)
-> 1245 output = func(*args, **kwargs)
1246 if has_aux:
...
512 'staticmethod. For more details, please see '
513 'https://pytorch.org/docs/master/notes/extending.func.html')
515 return custom_function_call(cls, *args, **kwargs)

RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. For more details, please see https://pytorch.org/docs/master/notes/extending.func.html
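For context, a minimal sketch (not CoLA's code) of what the error message is asking for: an autograd.Function written in the PyTorch 2.x style, where saving tensors happens in a separate setup_context staticmethod, which is what makes it usable under torch.func transforms.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(x):                    # torch.func style: no ctx argument here
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)       # saving moves into setup_context

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(3)
print(torch.func.grad(lambda t: Square.apply(t).sum())(x))  # == 2 * x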

Expected Behavior

Gradients are computed.


Observation about new Kernel Operator

I noticed that the new kernel operator makes use of nested for loops and update operations. For JAX that is a very bad idea; Python for loops should be avoided at all costs.

[Bug] PSD annotation promotion with negative floats.

🐛 Bug PSD annotation promotion with negative floats.

Maybe this is more of a question about the annotation "mental model", but should the following be expected behaviour? :)

To reproduce

A simple example demonstrating negative floats against a PSD matrix:

import jax.numpy as jnp
from cola.ops import Identity

n = 123
I = Identity(shape=(n, n), dtype=jnp.float64) # has {PSD, Unitary} annotations by default.

negative_I = -1.0 * I # negating so now NSD.

print(negative_I.annotations)  # inspecting annotations.
# {PSD, Unitary}

But negative_I is not PSD.

Expected Behavior

  • Would probably expect NSD annotation or no PSD annotation to be present.

System information

Please complete the following information:

  • CoLA version (run print(cola.__version__)): 0.1.0 / commit 56b913b
  • JAX and/or PyTorch version (run print(jax.__version__) and/or print(torch.__version__)): N/A here.
  • Computer OS: N/A here.


Warnings in tests

Running pytest tests/ returns the following warnings, some of which include "invalid value encountered in divide", etc.

========================================================================= warnings summary ==========================================================================
tests/test_linalg.py::test_construct_tridiagonal[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/algorithms/test_lanczos.py::test_lanczos_complex[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/algorithms/test_lanczos.py::test_lanczos_random[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/algorithms/test_lanczos.py::test_lanczos_iter[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/algorithms/test_lanczos.py::test_get_lanczos_coeffs[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/algorithms/test_lanczos.py::test_construct_tridiagonal[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/algorithms/test_stochastic_lanczos_quad.py::test_stochastic_lanczos_quad_random[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/linalg/test_eig.py::test_adjoint[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
  /home/sanyam_s/cola/cola/jax_fns.py:175: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return jnp.array(arr, dtype=dtype)

tests/test_operators.py::test_sparse[<module 'cola.torch_fns' from '/home/sanyam_s/cola/cola/torch_fns.py'>]
  /home/sanyam_s/cola/cola/ops/operators.py:52: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:54.)
    self.A = self.ops.sparse_csr(indptr, indices, data)

tests/test_operators.py::test_householder[<module 'cola.torch_fns' from '/home/sanyam_s/cola/cola/torch_fns.py'>]
  /home/sanyam_s/cola/cola/torch_fns.py:208: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
    return torch.tensor(arr, dtype=dtype)

tests/algorithms/test_arnoldi.py::test_arnoldi_eig[<module 'cola.torch_fns' from '/home/sanyam_s/cola/cola/torch_fns.py'>]
  /home/sanyam_s/cola/cola/torch_fns.py:118: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:276.)
    return array.to(dtype)

tests/algorithms/test_arnoldi.py::test_arnoldi_eig[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/linalg/test_eig.py::test_general[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
  /data/users/sanyam_s/.conda/envs/cola/lib/python3.10/site-packages/jax/_src/lax/lax.py:510: ComplexWarning: Casting complex values to real discards the imaginary part
    return _convert_element_type(operand, new_dtype, weak_type=False)

tests/algorithms/test_arnoldi.py::test_householder_arnoldi_matrix[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/algorithms/test_arnoldi.py::test_numpy_arnoldi
  /home/sanyam_s/cola/cola/jax_fns.py:175: UserWarning: Explicitly requested dtype float64 requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return jnp.array(arr, dtype=dtype)

tests/algorithms/test_cg.py::test_cg_track_easy[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/algorithms/test_cg.py::test_cg_easy_case[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
  /home/sanyam_s/cola/cola/jax_fns.py:175: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return jnp.array(arr, dtype=dtype)

tests/algorithms/test_cg.py::test_cg_track_easy[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
  /data/users/sanyam_s/.conda/envs/cola/lib/python3.10/site-packages/jax/_src/api_util.py:172: SyntaxWarning: Jitted function has static_argnums=(0, 3, 4, 5, 6), but only accepts 6 positional arguments. This warning will be replaced by an error after 2022-08-20 at the earliest.
    warnings.warn(f"Jitted function has {argnums_name}={argnums}, "

tests/algorithms/test_cg.py::test_cg_easy_case[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
  /data/users/sanyam_s/.conda/envs/cola/lib/python3.10/site-packages/jax/_src/api_util.py:172: SyntaxWarning: Jitted function has static_argnums=(0, 3, 4, 5, 6, 7, 8), but only accepts 8 positional arguments. This warning will be replaced by an error after 2022-08-20 at the earliest.
    warnings.warn(f"Jitted function has {argnums_name}={argnums}, "

tests/linalg/test_diagonal.py::test_approx_diag[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/linalg/test_diagonal.py::test_large_trace[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>, 'approx']
  /home/sanyam_s/cola/cola/algorithms/diagonal_estimation.py:92: RuntimeWarning: invalid value encountered in divide
    mean = diag_sum/(i*bs)

tests/linalg/test_diagonal.py::test_approx_diag[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>]
tests/linalg/test_diagonal.py::test_large_trace[<module 'cola.jax_fns' from '/home/sanyam_s/cola/cola/jax_fns.py'>, 'approx']
  /home/sanyam_s/cola/cola/algorithms/diagonal_estimation.py:93: RuntimeWarning: invalid value encountered in divide
    stderr = xnp.sqrt((diag_sumsq/(i*bs) - mean**2)/(i*bs))

Environment:
Python: 3.10.12
pytest: 7.4.0
pluggy: 1.2.0
pytorch: 2.0.1+cu118
jax: 0.4.13

Add "good first issues" to elicit contributions

One easy way to build a zoo of examples/how-to guides would be to have issues tagged as "good first issue" that the authors think would be good additions (others could create new ones too). I'd be happy to contribute some examples/how-to guides myself once such issues come up.

[Bug] Error when attempting to use vmap with Jacobian operator

If we have some function f whose Jacobian we want to take with respect to the rows of an array x:

from functools import partial  # assumed import; the original snippet used `cola.ops.partial`

x = jnp.ones((1000, 3))  # the original snippet had `jnp.array([1000, 3])`, presumably intended as a (1000, 3) array
test_jac_op = partial(cola.ops.Jacobian, f)

this works fine:

test_jac_op(x[0])

this causes an error:

jax.vmap(test_jac_op)(x)

Obviously the error listed below is self-explanatory, but the operation outlined above seems a likely use case (as it's then common to want to multiply these Jacobians by a vector).

 File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/api.py", line 1258, in vmap_f
    out_flat = batching.batch(
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/linear_util.py", line 203, in call_wrapped
    ans = gen.send(ans)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/interpreters/batching.py", line 635, in _batch_inner
    out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/api.py", line 1260, in <lambda>
    lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/api_util.py", line 409, in flatten_axes
    dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 93, in tree_unflatten
    return treedef.unflatten(leaves)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 245, in tree_unflatten
    return cls(*new_args, **aux[0])
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 425, in __init__
    assert len(x.shape) == 1, "x must be a vector"
AttributeError: 'object' object has no attribute 'shape'

Requirements files

It would probably be easier to maintain two different requirements files, requirements.txt and requirements.dev.txt, for the production package and the development package respectively.

I see that you have a requirements.sh which seems to contain commands for different CUDA versions. I think users can be referred to the respective dependency docs, removing the burden from CoLA to explain those details.

In the final requirements.txt, you can keep PyTorch/JAX un-versioned so that it doesn't interfere with the dependency parsing as long as it is installed.

When published to PyPI, the user installs can look like:

pip install cola

and the development installs can look like:

pip install cola[dev]

(This can be achieved by using extras_require = {"dev": ["pytest", ...]} in the setuptools configuration, just like the install_requires list.)
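A minimal setuptools sketch of that layout (package name taken from the README; the dependency lists below are placeholders, not CoLA's actual requirements):

# setup.py sketch; dependency lists are placeholders, not CoLA's actual requirements.
from setuptools import setup, find_packages

setup(
    name="cola-ml",
    packages=find_packages(),
    install_requires=["scipy"],                        # core deps; PyTorch/JAX left un-pinned
    extras_require={"dev": ["pytest", "pytest-cov"]},  # enables pip install "cola-ml[dev]"
)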

[Bug] Deprecated PyTree Function in Pytorch 2.2

The latest version of PyTorch has deprecated torch.utils._pytree._register_pytree_node in favour of torch.utils._pytree.register_pytree_node, which results in many warnings like the following:

...site-packages/cola/backends/backends.py:75: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
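A minimal compatibility sketch (an assumption about a workable shim, not necessarily the fix the maintainers shipped): pick whichever registration function the installed PyTorch provides.

import torch.utils._pytree as pytree

# Prefer the public name added in PyTorch 2.2; fall back to the old private one.
register_pytree_node = getattr(pytree, "register_pytree_node", None) or pytree._register_pytree_node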

[Feature Request] Structured Cholesky decomposition for linear operators

🚀 Feature Request

Add efficient Cholesky decomposition rules for structured linear operators.

Motivation

Presently we can use cola.decompositions.cholesky_decomposed to obtain a "lazy product" of lower and upper triangular linear operators.

def cholesky_decomposed(A: LinearOperator):
    """ Performs a cholesky decomposition A=LL* of a linear operator A.
        The returned operator LL* is the same as A, but represented using
        the triangular structure """
    L = Triangular(A.xnp.cholesky(A.to_dense()), lower=True)
    return L @ L.H

This is really neat; however, it currently assumes a dense structure.

For example, for a diagonal linear operator the lower Cholesky factor can be computed much more efficiently by simply taking the square root of the diagonal. Moreover, this would be critical for efficiency with, e.g., Kronecker product structures.
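A minimal sketch of the diagonal case (a hypothetical dispatch target; it assumes cola.ops.Diagonal exposes its entries as .diag and its backend as .xnp, as used elsewhere in CoLA):

from cola.ops import Diagonal

def cholesky_decomposed_diagonal(A: Diagonal):
    # For D = diag(d) with d >= 0, the lower Cholesky factor is just diag(sqrt(d)),
    # so no O(n^3) dense factorization is needed.
    L = Diagonal(A.xnp.sqrt(A.diag))
    return L @ L.H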

Pitch

Describe the solution you'd like

  • Add a dispatch rule to cholesky_decomposed to provide better efficiency for structured LinearOperators, with the "dense" assumption as a fallback.

Are you willing to open a pull request? (We LOVE contributions!!!)
Happy to help. :)

[Bug] Jax `Sparse` Matrix Example in Documentation Throws Error

🐛 Bug

One of the Jax sparse matrix examples in the documentation (namely https://cola.readthedocs.io/en/latest/package/cola.ops.html#cola.ops.Sparse) throws an error.

To reproduce

** Code snippet to reproduce **

import jax.numpy as jnp
import cola
data = jnp.array([1, 2, 3, 4, 5, 6])
indices = jnp.array([0, 2, 1, 0, 2, 1])
indptr = jnp.array([0, 2, 4, 6])
shape = (3, 3)
op = cola.ops.Sparse(data, indices, indptr, shape)

** Stack trace/error message **

AttributeError: module 'cola.backends.jax_fns' has no attribute 'sparse_csr'

Expected Behavior

That a Sparse matrix is returned.

System information

  • cola version: 0.0.4
  • jax version: 0.4.14
  • OS: Pop!_OS 22.04 LTS

Additional context

First I'd like to say that I think the idea behind the library is really cool and that I can definitely see myself utilising it across a lot of my projects :).

The fix itself should be as simple as defining sparse_csr inside cola.backends.jax_fns (plus adding a unit test, which should probably also be done for any other examples in the documentation that are lacking unit tests), which I'm happy to do over the next couple of days.

As an aside, does cola intend to support sparsity formats other than CSR? I know that both jax and pytorch support COO, CSR, CSC, BSR, and BSC formats (see https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html#other-sparse-data-structures and https://pytorch.org/docs/stable/sparse.html), so I imagine it would make sense to allow users to explicitly specify which sparsity representation they want. Any thoughts on this?

Thanks for any help.

Cheers,
Matt.

[Bug] `numpy` Tests Fail when `jax` is Installed

🐛 Bug

Some of the numpy backend unit tests fail if jax is installed, but pass when jax is not installed (i.e. these tests are 'flaky').

To reproduce

# Set-up
git clone https://github.com/wilson-labs/cola.git
cd cola
python -m venv venv
source venv/bin/activate
pip install -e ".[dev]"
pip install -r docs/requirements.txt
# Run tests without `jax` installed:
pytest -m "numpy"  -k "test_unary"
# Re-run tests with `jax` installed:
pip install jax jaxlib
pytest -m "numpy" -k "test_unary"

Test results when jax is not installed:

============================= test session starts ==============================
platform linux -- Python 3.10.6, pytest-7.4.2, pluggy-1.3.0
rootdir: /home/mabilton/cola
configfile: setup.cfg
plugins: anyio-4.0.0, cov-4.1.0
collected 460 items / 428 deselected / 32 selected

tests/linalg/test_unary.py ......F.........................              [100%]

=================================== FAILURES ===================================
...
=========================== short test summary info ============================
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_big', 'exp', marks=[MarkDecoratormark=Markname='big', args=, kwargs={}, MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
=========== 1 failed, 31 passed, 428 deselected, 9 warnings in 1.95s ===========

Test results when jax is installed:

============================= test session starts ==============================
platform linux -- Python 3.10.6, pytest-7.4.2, pluggy-1.3.0
rootdir: /home/mabilton/cola
configfile: setup.cfg
plugins: anyio-4.0.0, cov-4.1.0
collected 460 items / 428 deselected / 32 selected

tests/linalg/test_unary.py ..FFF...........FF...FFF....FF..              [100%]

=================================== FAILURES ===================================
...
=========================== short test summary info ============================
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'square_lowertriangular', 'exp', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'square_lowertriangular', 'sqrt', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'square_kronsum', 'exp', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'square_dense', 'exp', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'square_dense', 'sqrt', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_kron', 'sqrt', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_big', 'exp', marks=[MarkDecoratormark=Markname='big', args=, kwargs={}, MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_big', 'sqrt', marks=[MarkDecoratormark=Markname='big', args=, kwargs={}, MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_blockdiag', 'exp', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_blockdiag', 'sqrt', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
========= 10 failed, 22 passed, 428 deselected, 70 warnings in 15.03s ==========

Note that similar results are observed when pytest -m "numpy" -k "test_get_lu_from_tridiagonal" is run.

Expected Behavior

Ideally, unit tests should run in a predictable and consistent manner, with the result of a given test not depending on which optional dependencies the user may or may not have installed on their machine.

System information

  • Cola version: 0.0.4
  • jax version: 0.4.16
  • OS: Pop!_OS 22.04 LTS

Additional context

I encountered this issue when running the test suite for the first time before starting work on #75. It appears that the current CI workflow doesn't pick up on this problem because the numpy tests are only executed when jax is not installed.

From my own experiments, it seems that the source of the flakiness in these numpy tests is that cola.backends.get_library_fns correctly infers the back-end of a numpy array to be numpy_fns when jax is not installed, but incorrectly infers the back-end to be jax_fns when jax is installed. We can see why this occurs by considering the current implementation of get_library_fns:

def get_library_fns(dtype):
    try:
        from jax import numpy as jnp
        if dtype in [jnp.float32, jnp.float64, jnp.complex64, jnp.complex128, jnp.int32, jnp.int64]:
            from cola.backends import jax_fns as fns
            return fns
    except ImportError:
        pass
    ...
    if dtype in [np.float32, np.float64, np.complex64, np.complex128, np.int32, np.int64]:
        from cola.backends import np_fns as fns
        return fns
    raise ImportError("No supported array library found")

i.e. get_library_fns will infer the back-end to be jax if jax can be imported and if dtype matches with a jax.numpy type. Unfortunately, it turns out (much to my surprise) that jax.numpy types are basically just aliases for numpy types, which means that Python evaluates jax.numpy and numpy types as equal to one another:

import numpy as np
import jax.numpy as jnp
print(jnp.float32 == np.float32)
# Prints: True

This means get_library_fns will always return jax_fns when provided with a numpy array if jax is installed. Even more surprisingly, the dtype property of a jax.numpy array is not even guaranteed to be a jax.numpy type:

import jax.numpy as jnp
x = jnp.array([1.,2.,3.], dtype=jnp.float32)
print(type(x.dtype))
# Prints: <class 'numpy.dtype[float32]'>

I think these observations illustrate that the 'premise' behind the get_library_fns function (i.e. that you can determine which back-end to use purely based on the dtype property of an array) probably isn't sound.

Proposed Solutions

Two potential fixes come to mind:

  1. Deprecate the get_library_fns function and replace it with a similar function that requires the user to explicitly name the back-end they wish to use (see the sketch after this list).
  2. Add an additional flag to get_library_fns to 'force' it to return the numpy backend, even when jax is installed; this flag can then be used during the numpy tests to ensure that they're consistent.
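A minimal sketch of option 1 (the name get_backend_fns is hypothetical; the module paths follow the ones referenced in this issue, with torch_fns assumed by analogy):

def get_backend_fns(name: str):
    # Hypothetical replacement for get_library_fns: the caller names the backend
    # explicitly instead of CoLA guessing it from a dtype.
    if name == "jax":
        from cola.backends import jax_fns as fns
    elif name == "torch":
        from cola.backends import torch_fns as fns
    elif name == "numpy":
        from cola.backends import np_fns as fns
    else:
        raise ValueError(f"Unknown backend {name!r}")
    return fns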

I'm more than happy to work on this issue, but it would be great to hear what others think about all this first. Thanks in advance for any help.

Cheers,
Matt.

[Bug] New interaction issue when Tree Flattening / Equinox Functionality

It appears that the changes to the code over the past couple of days have broken its compatibility with pytree flattening when called via equinox.filter functionality.

    opt_state = opt_init(eqx.filter(model, eqx.is_inexact_array))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 130, in filter
    return jtu.tree_map(
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 251, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 251, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Mismatch custom node data: [('A',), ('dtype', dtype('float64')), ('shape', (45, 45)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set()), ('device', gpu(id=0))] != [('device', gpu(id=0)), ('A',), ('dtype', dtype('float64')), ('shape', (45, 45)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set())]; value: <45x45 Dense with dtype=float64>.

The issue appears to arise when you want to store a linear operator as a field in an Equinox class. Here is a silly MVE:

import equinox as eqx
import jax
import cola

class Linear(eqx.Module):
    weight: cola.ops.LinearOperator #jax.Array 
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = cola.ops.Dense(jax.random.normal(wkey, (out_size, in_size)))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        y = cola.ops.Dense(x) @ cola.ops.Transpose(self.weight) + cola.ops.Dense(self.bias)
        return y.to_dense()
        # return self.weight @ x + self.bias

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)


model = Linear(2, 3, key=jax.random.PRNGKey(0))

eqx.filter(model, eqx.is_inexact_array)

ValueError: Mismatch custom node data: [('A',), ('dtype', dtype('float32')), ('shape', (3, 2)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set()), ('device', gpu(id=0))] != [('device', gpu(id=0)), ('A',), ('dtype', dtype('float32')), ('shape', (3, 2)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set())]; value: <3x2 Dense with dtype=float32>.
