google-deepmind / kfac-jax

Second Order Optimization and Curvature Estimation with K-FAC in JAX.

License: Apache License 2.0

Python 99.81% Shell 0.19%
machine-learning optimization bayesian-deep-learning

kfac-jax's Introduction

KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX

Installation | Quickstart | Documentation | Examples | Citing KFAC-JAX

KFAC-JAX is a library built on top of JAX for second-order optimization of neural networks and for computing scalable curvature approximations. The main goal of the library is to provide researchers with an easy-to-use implementation of the K-FAC optimizer and curvature estimator.

Installation

KFAC-JAX is written in pure Python, but depends on C++ code via JAX.

First, follow these instructions to install JAX with the relevant accelerator support.

Then, install KFAC-JAX using pip:

$ pip install git+https://github.com/google-deepmind/kfac-jax

Alternatively, you can install via PyPI:

$ pip install -U kfac-jax

Our examples rely on additional libraries, all of which you can install using:

$ pip install kfac-jax[examples]

Quickstart

Let's take a look at a simple example of training a neural network, defined using Haiku, with the K-FAC optimizer:

import haiku as hk
import jax
import jax.numpy as jnp
import kfac_jax

# Hyper parameters
NUM_CLASSES = 10
L2_REG = 1e-3
NUM_BATCHES = 100


def make_dataset_iterator(batch_size):
  # Dummy dataset, in practice this should be your dataset pipeline
  for _ in range(NUM_BATCHES):
    yield jnp.zeros([batch_size, 100]), jnp.ones([batch_size], dtype="int32")


def softmax_cross_entropy(logits: jnp.ndarray, targets: jnp.ndarray):
  """Softmax cross entropy loss."""
  # We assume integer labels
  assert logits.ndim == targets.ndim + 1

  # Tell KFAC-JAX this model represents a classifier
  # See https://kfac-jax.readthedocs.io/en/latest/overview.html#supported-losses
  kfac_jax.register_softmax_cross_entropy_loss(logits, targets)
  log_p = jax.nn.log_softmax(logits, axis=-1)
  return - jax.vmap(lambda x, y: x[y])(log_p, targets)


def model_fn(x):
  """A Haiku MLP model function - three hidden layer network with tanh."""
  return hk.nets.MLP(
    output_sizes=(50, 50, 50, NUM_CLASSES),
    with_bias=True,
    activation=jax.nn.tanh,
  )(x)


# The Haiku transformed model
hk_model = hk.without_apply_rng(hk.transform(model_fn))


def loss_fn(model_params, model_batch):
  """The loss function to optimize."""
  x, y = model_batch
  logits = hk_model.apply(model_params, x)
  loss = jnp.mean(softmax_cross_entropy(logits, y))

  # The optimizer assumes that the function you provide has already added
  # the L2 regularizer to its gradients.
  return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0


# Create the optimizer
optimizer = kfac_jax.Optimizer(
  value_and_grad_func=jax.value_and_grad(loss_fn),
  l2_reg=L2_REG,
  value_func_has_aux=False,
  value_func_has_state=False,
  value_func_has_rng=False,
  use_adaptive_learning_rate=True,
  use_adaptive_momentum=True,
  use_adaptive_damping=True,
  initial_damping=1.0,
  multi_device=False,
)

input_dataset = make_dataset_iterator(128)
rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
rng, key = jax.random.split(rng)
params = hk_model.init(key, dummy_images)
rng, key = jax.random.split(rng)
opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))

# Training loop
for i, batch in enumerate(input_dataset):
  rng, key = jax.random.split(rng)
  params, opt_state, stats = optimizer.step(
      params, opt_state, key, batch=batch, global_step_int=i)
  print(i, stats)

Do not stage (jit or pmap) the optimizer

You should not apply jax.jit or jax.pmap to the call to Optimizer.step. This is already done for you automatically by the optimizer class. To control the staging behaviour of the optimizer set the flag multi_device to True for pmap and to False for jit.

Do not stage (jit or pmap) the loss function

The value_and_grad_func argument provided to the optimizer should compute the loss function value and its gradients. Since the optimizer already stages its step function internally, applying jax.jit to value_and_grad_func is NOT recommended. Importantly, applying jax.pmap is WRONG and most likely will lead to errors.

Registering the model loss function

In order for KFAC-JAX to be able to correctly approximate the curvature matrix of the model it needs to know the precise loss function that you want to optimize. This is done via registration with certain functions provided by the library. For instance, in the example above this is done via the call to kfac_jax.register_softmax_cross_entropy_loss, which tells the optimizer that the loss is the standard softmax cross-entropy. If you don't do this you will get an error when you try to call the optimizer. For all supported loss functions please read the documentation.

Important: The optimizer assumes that the loss is averaged over examples in the minibatch. It is crucial that you follow this convention.
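
To make the averaging convention concrete, here is a minimal sketch on a toy linear model with made-up data (nothing in it is part of KFAC-JAX; `per_example_loss`, `mean_loss` and `sum_loss` are names chosen for illustration). It shows that a summed loss produces gradients that are larger than those of the averaged loss by exactly the batch size, which is the scaling mismatch the convention above is meant to avoid.

```python
import jax
import jax.numpy as jnp

BATCH_SIZE = 8  # hypothetical batch size, for illustration only


def per_example_loss(w, x, y):
  # Squared error of a toy linear model; one loss value per example.
  return (x @ w - y) ** 2


def mean_loss(w, x, y):
  # The convention the optimizer assumes: average over the minibatch.
  return jnp.mean(per_example_loss(w, x, y))


def sum_loss(w, x, y):
  # Not the assumed convention: gradients scale with the batch size.
  return jnp.sum(per_example_loss(w, x, y))


x = jnp.arange(BATCH_SIZE * 3, dtype=jnp.float32).reshape(BATCH_SIZE, 3) / 10.0
y = jnp.ones([BATCH_SIZE])
w = jnp.array([0.5, -0.25, 0.1])

grad_mean = jax.grad(mean_loss)(w, x, y)
grad_sum = jax.grad(sum_loss)(w, x, y)

# The two gradients differ by a factor of BATCH_SIZE.
ratio_ok = bool(jnp.allclose(grad_sum, BATCH_SIZE * grad_mean, rtol=1e-4))
```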

Other model function options

Oftentimes, one will want to output some auxiliary statistics or metrics in addition to the loss value. This can already be done in the value_and_grad_func, in which case we follow the same conventions as JAX and expect the output to be (loss, aux), grads. Similarly, the loss function can take an additional function state (batch norm layers usually have this) or a PRNG key (used in stochastic layers). All of these, however, must be explicitly declared to the optimizer via its arguments value_func_has_aux, value_func_has_state and value_func_has_rng.
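
As a small sketch of the (loss, aux), grads convention, the snippet below uses only jax.value_and_grad with has_aux=True on a toy regression loss. The model, the data and the "mean_prediction" metric are assumptions made up for illustration. A function of this shape is what one would pass as value_and_grad_func together with value_func_has_aux=True.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(params, batch):
  """Toy loss that returns (loss, aux), following the JAX has_aux convention."""
  x, y = batch
  preds = x @ params["w"] + params["b"]
  loss = jnp.mean((preds - y) ** 2)
  # Hypothetical auxiliary metric, not differentiated.
  aux = {"mean_prediction": jnp.mean(preds)}
  return loss, aux


# has_aux=True tells JAX that only the first output is the differentiated value.
value_and_grad_func = jax.value_and_grad(loss_with_aux, has_aux=True)

params = {"w": jnp.ones([3]), "b": jnp.zeros([])}
batch = (jnp.ones([4, 3]), jnp.zeros([4]))

# The output structure is ((loss, aux), grads).
(loss, aux), grads = value_and_grad_func(params, batch)
```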

Verify optimizer registrations

We strongly encourage the user to pay attention to the logging messages produced by the automatic registration system, in order to ensure that it has correctly understood your model. For the example above, the output looks like this:

==================================================
Graph parameter registrations:
{'mlp/~/linear_0': {'b': 'Auto[dense_with_bias_3]',
                    'w': 'Auto[dense_with_bias_3]'},
 'mlp/~/linear_1': {'b': 'Auto[dense_with_bias_2]',
                    'w': 'Auto[dense_with_bias_2]'},
 'mlp/~/linear_2': {'b': 'Auto[dense_with_bias_1]',
                    'w': 'Auto[dense_with_bias_1]'},
 'mlp/~/linear_3': {'b': 'Auto[dense_with_bias_0]',
                    'w': 'Auto[dense_with_bias_0]'}}
==================================================

As can be seen from this message, the library has correctly detected all parameters of the model to be part of dense layers.

Further reading

For a high level overview of the optimizer, the different curvature approximations, and the supported layers, please see the documentation.

Citing KFAC-JAX

To cite this repository:

@software{kfac-jax2022github,
  author = {Aleksandar Botev and James Martens},
  title = {{KFAC-JAX}},
  url = {https://github.com/google-deepmind/kfac-jax},
  version = {0.0.2},
  year = {2022},
}

In this bibtex entry, the version number is intended to be from kfac_jax/__init__.py, and the year corresponds to the project's open-source release.

kfac-jax's People

Contributors

botev, chsigg, fabianp, hawkinsp, hbq1, james-martens, joeljennings, rchen152, sauravmaheshkar, sharadmv, superbobry

kfac-jax's Issues

Value functions returning state objects not supported

Hi - thank you for the great project!

I'm working with models which carry an internal state, so their forward pass functions return an extra state object. This causes a crash from KFAC-JAX, because that state object is only accounted for at the input, rather than the output.

In optimizer.convert_value_and_grad_to_value_func, the flag has_aux is used to decide whether to return the value function's output directly or take its index-0 element. As the docstring says, this is similar behaviour to jax.grad(), but the flag really refers to any extra output given by the value function, not just an aux dictionary. This hits a snag where the function is called in the Optimizer constructor (optimizer.py line 358), because only value_func_has_aux is considered, so value_func_has_state doesn't cause the index-0 behaviour like I think it should.

I've fixed this locally by changing that call to use has_aux=value_func_has_aux or value_func_has_state - I haven't come across any other problems with state outputs, but I guess they might still exist. This is the patch file I use for the 0.0.3 release:
kfac_jax.txt - I can submit it as a PR if that's helpful.
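
The behaviour described in this issue can be sketched in plain Python. The helper below is a hypothetical stand-in written for illustration, not the library's convert_value_and_grad_to_value_func. It only mirrors the index-0 logic the issue describes and shows the effect of the proposed has_aux=value_func_has_aux or value_func_has_state change on a toy stateful function.

```python
def make_value_func(value_and_grad_func, has_extra_outputs):
  """Hypothetical stand-in for the conversion described in the issue."""

  def value_func(*args, **kwargs):
    out, _ = value_and_grad_func(*args, **kwargs)  # drop the gradients
    # With extra outputs, `out` is (loss, extras...), so keep only the loss.
    return out[0] if has_extra_outputs else out

  return value_func


def value_and_grad_with_state(params, state, batch):
  # Toy function returning ((loss, new_state), grads), like a stateful model.
  loss = sum(p * p for p in params)
  grads = [2 * p for p in params]
  return (loss, state + 1), grads


value_func_has_aux = False
value_func_has_state = True

# Considering only value_func_has_aux leaves the state in the output.
without_fix = make_value_func(value_and_grad_with_state, value_func_has_aux)
loss_and_state = without_fix([1.0, 2.0], 0, None)

# The change proposed in the issue: treat state as an extra output too.
with_fix = make_value_func(
    value_and_grad_with_state,
    has_extra_outputs=value_func_has_aux or value_func_has_state,
)
loss_only = with_fix([1.0, 2.0], 0, None)
```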

TypeError: 'ShapedArray' object is not iterable

Hi,

I tried to run the example code, but the code stops at primal_output = self.bind(*arg_values, **kwargs), and returns the error "TypeError: 'ShapedArray' object is not iterable". Could you please help me to solve this problem? Thanks.

Add Support for KFAC Optimization in LSTM and GRU Layers

Feature

I kindly request the addition of support for the Kronecker-Factored Approximate Curvature (KFAC) optimization technique in LSTM and GRU layers within the existing KFAC Optimizer. Currently, most of the KFAC Optimizer classes are tailored for linear and 2D convolution layers. Extending its capabilities to encompass RNN layers would be a significant enhancement.

Proposal

The proposal entails integrating KFAC optimization support for LSTM and GRU layers into the KFAC optimizer. This would involve adapting the KFAC Optimizer to compute the requisite statistics and the chain-structured linear Gaussian graphical model for LSTM and GRU layers, for which I could not find any public implementation.

Motivation

LSTM and GRU layers are foundational components for handling sequential data and time-series analysis. I wonder how much KFAC could improve the training of models with LSTM and GRU layers by providing accurate approximations of the Fisher information matrix. By integrating support for LSTM and GRU layers within the KFAC Optimizer, researchers would gain the ability to apply the KFAC optimization technique to a wider array of models, including reinforcement learning algorithms.

Additional Context

I have full confidence that the repository maintainers, particularly the first author of the paper titled

I appreciate your consideration of this feature request. Thank you.

Using K-FAC with physics-based losses

Hey,

Thank you for the implementation.

From the guide, I saw that I have to register loss functions to be able to use K-FAC.
For my specific case, the loss function is a FEM simulation on the outputs of the network along with some other functions (postprocessing, filtering etc).

Will it be possible to use K-FAC?

Can this be used for Laplace approximation?

In Laplace approximation, the Hessian of the loss function is computed for quadratic approximation. Can this package be used to do a block-diagonal approximation of the Hessian at the minimum? If yes, could you please show (using jax and flax) how to approximate it and define a quadratic approximation of the loss function (which should be something like 1/2 (theta - theta_star)^T H(L)(theta_star) (theta - theta_star), where theta_star is the minimum and H(L) is the Hessian of the loss function)?
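
For reference, the quadratic form in the question can be written out on a toy problem. The sketch below is only an illustration of the formula: it uses the exact dense Hessian from jax.hessian on a made-up two-parameter loss, not KFAC-JAX's block-diagonal curvature approximations and not flax. The matrix A, the vector b and the function names are assumptions chosen for the example, and the constant term L(theta_star) is included so that the approximation can be compared with the loss directly.

```python
import jax
import jax.numpy as jnp

# Toy convex quadratic loss with a known minimum (illustrative values only).
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])


def loss(theta):
  return 0.5 * theta @ A @ theta - b @ theta


# For this loss the minimum solves A theta = b.
theta_star = jnp.linalg.solve(A, b)

# Exact Hessian at the minimum (a dense matrix, feasible only for tiny models).
hessian = jax.hessian(loss)(theta_star)


def quadratic_approx(theta):
  # L(theta_star) + 1/2 (theta - theta_star)^T H (theta - theta_star)
  delta = theta - theta_star
  return loss(theta_star) + 0.5 * delta @ hessian @ delta


theta = jnp.array([0.3, 0.7])
approx_error = float(jnp.abs(quadratic_approx(theta) - loss(theta)))
```

Because the toy loss is itself quadratic, the approximation matches the loss up to floating-point error. For a neural network the Hessian would have to be replaced by a structured approximation.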

Using kfac inside jitted function

Dear All,

In the problem I am working on, I need to use Optimizer.step inside a broader function that is jitted for performance reasons. Is there a canonical way to do this without dramatically hurting performance or introducing errors?

Best,
Honza

KFAC Norm Constraint

Hi,

In the documentation of applying a norm constraint to the update gradient, it says:

norm_constraint: Scalar. If specified, the update is scaled down so that
        its approximate squared Fisher norm ``v^T F v`` is at most the specified
        value. (Note that here ``F`` is the approximate curvature matrix, not
        the exact.)

and the corresponding part of the code:

preconditioned_grads = self.estimator.multiply_inverse(
    state=state.estimator_state,
    parameter_structured_vector=grads,
    identity_weight=self.l2_reg + damping,
    exact_power=self._use_exact_inverses,
    use_cached=self._use_cached_inverses,
    pmap_axis_name=self.pmap_axis_name,
)
if self._norm_constraint is not None:

  assert not self._use_adaptive_learning_rate
  assert coefficient is not None

  sq_norm_grads = utils.inner_product(preconditioned_grads, grads)

  sq_norm_scaled_grads = sq_norm_grads * coefficient ** 2

  max_coefficient = jnp.sqrt(self._norm_constraint / sq_norm_scaled_grads)
  coefficient = jnp.minimum(max_coefficient, 1)

  preconditioned_grads = utils.scalar_mul(preconditioned_grads, coefficient)

However, as far as I am aware, preconditioned_grads is F^-1 v, so sq_norm_grads is actually computing v^T F^-1 v instead of v^T F v as documented. Did I understand this correctly, or is this the intended behaviour?
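
One way to examine this question numerically is the NumPy sketch below. It rests on two assumptions that are not confirmed by the excerpt: that the `v` in the docstring denotes the preconditioned update u = F^-1 g rather than the raw gradient g, and that the final update is the rescaled preconditioned gradient times the incoming coefficient. Here F is a made-up symmetric positive-definite matrix standing in for the (damped) curvature. Under those assumptions, the inner product of the preconditioned gradient with the gradient equals both g^T F^-1 g and u^T F u, so the two readings coincide.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up symmetric positive-definite matrix standing in for the curvature F.
M = rng.normal(size=(4, 4))
F = M @ M.T + 4.0 * np.eye(4)
g = rng.normal(size=4)  # stand-in for the gradient

u = np.linalg.solve(F, g)  # preconditioned gradient, u = F^-1 g

inner_u_g = u @ g                          # inner_product(preconditioned_grads, grads)
fisher_norm_of_u = u @ F @ u               # u^T F u
inv_norm_of_g = g @ np.linalg.solve(F, g)  # g^T F^-1 g

# Assumed rescaling, mirroring the excerpt: `learning_rate` plays the role
# of `coefficient` on entry, and the update is scale * learning_rate * u.
learning_rate = 0.5
norm_constraint = 0.01
sq_norm_scaled = inner_u_g * learning_rate ** 2
scale = min(np.sqrt(norm_constraint / sq_norm_scaled), 1.0)
update = scale * learning_rate * u
update_sq_norm = update @ F @ update
```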

How to use kfac to train two probabilistic models jointly?

In my application, I need to jointly optimize two probabilistic models. They contribute to two different terms in the final loss function.

What would be the recommended pattern for using kfac in this setting?
More specifically, does it make sense to invoke kfac_jax.register_normal_predictive_distribution twice (once for each of the two probabilistic models)?

Thanks in advance!

Quick question on "layer_tag_vjp"

Hey KFAC team,

First of all, thanks a lot for this awesome project and all the hard work!

Got a quick question on the implementation of _layer_tag_vjp function in "tracer.py". The version is 0.0.3.

For the returned function vjp_func, my understanding is that it reads primal and tangent value of the "layer inputs" from previously constructed information, specifically primals_dict and tangents_dict. My questions are:

  1. "primals_dict" contains not only info for all the layer inputs, but also info for the inputs of the whole jaxpr. See here. It seems to me that the latter is not needed, since we are only retrieving info for layer inputs in vjp_func. So is the info about the inputs of the whole jaxpr really necessary here, and why?
  2. "tangents_dict" is constructed from aux_vjp. The implementation reads:
    all_tangents = aux_vjp(tangents)
    tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:]
    inputs_tangents = jax.tree_util.tree_leaves(inputs_tangents)
    tangents_dict.update(zip(processed_jaxpr.jaxpr.invars, inputs_tangents))

Here aux_vjp is the vjp function for forward_aux (See here). Therefore the output of aux_vjp, namely all_tangents should have the same structure as the input of forward_aux. Here forward_aux only has a single argument, and thus all_tangents should be just a tuple with a single element. If that's the case, then inputs_tangents is always empty, and we can simplify the implementation to

tangents_dict, = aux_vjp(tangents)

Am I missing anything here? Or in which case will we have a non-empty inputs_tangents?
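
The general jax.vjp behaviour this question relies on can be checked in isolation: the VJP function returns a tuple with one cotangent per positional argument of the differentiated function. The sketch below uses two toy functions, `single_arg` and `two_args`, invented for illustration. It says nothing by itself about how many arguments forward_aux receives in tracer.py.

```python
import jax
import jax.numpy as jnp


def single_arg(x):
  # One positional argument, pytree output.
  return {"a": jnp.sin(x), "b": 2.0 * x}


def two_args(x, y):
  # Two positional arguments.
  return x * y


x = jnp.ones([3])
y = jnp.full([3], 2.0)

_, vjp_single = jax.vjp(single_arg, x)
# The cotangent must match the output structure of `single_arg`.
cotangents_single = vjp_single({"a": jnp.ones([3]), "b": jnp.ones([3])})

_, vjp_double = jax.vjp(two_args, x, y)
cotangents_double = vjp_double(jnp.ones([3]))

# With a single argument, slicing off the first element leaves an empty tuple.
rest_single = cotangents_single[1:]
rest_double = cotangents_double[1:]
```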

Quickstart example with different NN libraries does not tag correctly

Hey,

I am adapting the quickstart example to equinox to use it in my project. However, it seems only the bias is correctly tagged (with Auto[scale_and_shift_tag_0]); the weight matrix is tagged as 'Orphan'. I also tried it with flax, in which case only the weight matrix is tagged correctly (with Auto[dense_tag_0]), but the bias is tagged as 'Orphan'. And lastly, for pure jax it seems again only to tag the bias correctly.

I am not sure what I am doing wrong here; I introduced minimal changes. In the test script below the four different libraries can be switched between using the lib_type variable in the code. I would appreciate any help.

import haiku as hk
import jax
import jax.numpy as jnp
import kfac_jax
from absl import logging
import sys
import equinox as eqx
import flax
import jax.random as random

logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Hyper parameters
NUM_CLASSES = 10
L2_REG = 1e-3
NUM_BATCHES = 5
NUM_FEATURES = 20
rng = jax.random.PRNGKey(42)


def make_dataset_iterator(batch_size):
  # Dummy dataset, in practice this should be your dataset pipeline
  for _ in range(NUM_BATCHES):
    yield jnp.zeros([batch_size, NUM_FEATURES]), jnp.ones([batch_size], dtype="int32")


def softmax_cross_entropy(logits: jnp.ndarray, targets: jnp.ndarray):
  """Softmax cross entropy loss."""
  # We assume integer labels
  assert logits.ndim == targets.ndim + 1

  # Tell KFAC-JAX this model represents a classifier
  # See https://kfac-jax.readthedocs.io/en/latest/overview.html#supported-losses
  kfac_jax.register_softmax_cross_entropy_loss(logits, targets)
  log_p = jax.nn.log_softmax(logits, axis=-1)
  return - jax.vmap(lambda x, y: x[y])(log_p, targets)


lib_type = 'hk'  # one of 'hk', 'eqx', 'flax', 'jax'
#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
if lib_type == 'hk':
  def model_fn(x):
    return hk.nets.MLP(
      output_sizes=(50, 50, NUM_CLASSES),
      with_bias=True,
      activation=jax.nn.tanh,
    )(x)

  hk_model = hk.without_apply_rng(hk.transform(model_fn))

  def loss_fn(model_params, model_batch):
    """The loss function to optimize."""
    x, y = model_batch
    logits = hk_model.apply(model_params, x)
    loss = jnp.mean(softmax_cross_entropy(logits, y))

    return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0

  optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=L2_REG,
    value_func_has_aux=False,
    value_func_has_state=False,
    value_func_has_rng=False,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=1.0,
    multi_device=False,
  )

  input_dataset = make_dataset_iterator(128)
  dummy_images, dummy_labels = next(input_dataset)
  rng, key = jax.random.split(rng)
  params = hk_model.init(key, dummy_images)
  rng, key = jax.random.split(rng)
  opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))

#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
elif lib_type == 'eqx':
  class simple_net(eqx.Module):
      net: callable

      def __init__(self, key=None):
        keys = jax.random.split(key,  2)
        self.net = eqx.nn.MLP(NUM_FEATURES, NUM_CLASSES, 100, 0, activation=jax.nn.tanh, key=keys[1])

      def __call__(self, x):
        return self.net(x)

  eqx_model = simple_net(rng)
  params, static = eqx.partition(eqx_model, eqx.is_inexact_array)


  def loss_fn(model_params, static, model_batch):
    """The loss function to optimize."""
    x, y = model_batch
    eqx_model = eqx.combine(model_params, static)
    logits = jax.vmap(eqx_model)(x)
    loss = jnp.mean(softmax_cross_entropy(logits, y))

    return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0
  kfac_loss_fn = lambda params, batch: loss_fn(params, static, batch)

  optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(kfac_loss_fn),
    l2_reg=L2_REG,
    value_func_has_aux=False,
    value_func_has_state=False,
    value_func_has_rng=False,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=1.0,
    multi_device=False,
  )

  input_dataset = make_dataset_iterator(128)
  dummy_images, dummy_labels = next(input_dataset)
  rng, key = jax.random.split(rng)
  opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))

#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
elif lib_type == 'flax':
  class MLP(flax.linen.Module):

    def setup(self):
      self.dense1 = flax.linen.Dense(32)
      self.dense2 = flax.linen.Dense(NUM_CLASSES)

    def __call__(self, x):
      x = self.dense1(x)
      x = flax.linen.relu(x)
      x = self.dense2(x)
      return x


  flax_model = MLP()
  params = flax_model.init(rng, jnp.zeros([128, NUM_FEATURES]))

  def loss_fn(model_params, model_batch):
    """The loss function to optimize."""
    x, y = model_batch
    logits = flax_model.apply(model_params, x)
    loss = jnp.mean(softmax_cross_entropy(logits, y))

    return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0

  optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=L2_REG,
    value_func_has_aux=False,
    value_func_has_state=False,
    value_func_has_rng=False,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=1.0,
    multi_device=False,
  )

  input_dataset = make_dataset_iterator(128)
  dummy_images, dummy_labels = next(input_dataset)
  rng, key = jax.random.split(rng)
  opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))

#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
elif lib_type == 'jax':
  def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))


  # Initialize all layers for a fully-connected neural network with sizes "sizes"
  def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

  def relu(x):
    return jnp.maximum(0, x)

  def predict(params, x):
    activations = x
    for w, b in params[:-1]:
      outputs = jnp.dot(w, activations) + b
      activations = relu(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits


  params = init_network_params([NUM_FEATURES, 100, NUM_CLASSES], random.PRNGKey(0))


  def loss_fn(model_params, model_batch):
    """The loss function to optimize."""
    x, y = model_batch
    logits = jax.vmap(predict, in_axes=(None, 0))(model_params, x)
    loss = jnp.mean(softmax_cross_entropy(logits, y))

    return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0

  optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=L2_REG,
    value_func_has_aux=False,
    value_func_has_state=False,
    value_func_has_rng=False,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=1.0,
    multi_device=False,
  )

  input_dataset = make_dataset_iterator(128)
  dummy_images, dummy_labels = next(input_dataset)
  rng, key = jax.random.split(rng)
  opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))


# Training loop
for i, batch in enumerate(input_dataset):
  rng, key = jax.random.split(rng)
  params, opt_state, stats = optimizer.step(
      params, opt_state, key, batch=batch, global_step_int=i)
  print(i, stats)

Compatibility with `pallas` attention

First, thank you for creating and releasing this invaluable resource.

What I am trying to do

I would like to combine kfac-jax with fused attention from pallas.

As far as I understand, this should theoretically be trivial: the attention operation itself contains no tagged parameters and no tagged losses, so KFAC should simply propagate the forward and backward passes.

Reproducing failure with the current repo

Environment and set-up

  • CUDA 12.2
  • CuDNN 8.9
  • Python 3.11
  • jax==0.4.28
  • jaxlib==0.4.28+cuda12.cudnn89
  • dm-haiku==0.0.12
  • kfac-jax cloned from main

Reproduction

scripts/kfac.py

import haiku as hk
import jax
import jax.experimental.pallas.ops.attention as attention
import jax.numpy as jnp
import kfac_jax


if __name__ == "__main__":

    def model(inputs):
        # shape [batch, nodes, features]
        k = hk.Linear(16 * 16)(inputs).reshape((*inputs.shape[:-1], 16, 16))
        attended = attention.mha(k, k, k, None)
        # reduce to shape batch
        y_hat = attended.mean([-1, -2, -3])
        return y_hat

    # The Haiku transformed model
    hk_model = hk.without_apply_rng(hk.transform(model))

    def loss_fn(model_params, model_batch):
        """The loss function to optimize."""
        x, y = model_batch
        preds = hk_model.apply(model_params, x)
        errs = (y - preds) ** 2
        kfac_jax.register_normal_predictive_distribution(errs)
        loss = jnp.mean(errs)

        return loss

    x = jnp.zeros((16, 16, 16))
    y = jnp.zeros(16)

    rng = jax.random.PRNGKey(42)
    rng, rng_init = jax.random.split(rng)
    params = hk_model.init(rng_init, x)

    # KFAC

    # Create the optimizer
    optimizer = kfac_jax.Optimizer(
        value_and_grad_func=jax.value_and_grad(loss_fn),
        l2_reg=0.0,
        value_func_has_aux=False,
        value_func_has_state=False,
        value_func_has_rng=False,
        use_adaptive_learning_rate=True,
        use_adaptive_momentum=False,
        use_adaptive_damping=True,
        initial_damping=1.0,
        multi_device=False,
    )

    rng, rng_opt = jax.random.split(rng)
    opt_state = optimizer.init(params, rng_opt, (x, y))
    params, opt_state, stats = optimizer.step(
        params, opt_state, rng, batch=(x, y), global_step_int=0, momentum=0
    )

This fails with the following error message

Traceback (most recent call last):
  File "./scripts/kfac.py", line 167, in <module>
    opt_state = optimizer.init(params, rng_opt, (x, y))
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1023, in init
    return self._init(params, rng, batch, func_state)
  File "./extern/kfac-jax/kfac_jax/_src/utils/staging.py", line 255, in decorated
    outs = jitted_func(instance, *args)
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 988, in _init
    estimator_state=self.estimator.init(
  File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
    return method(instance, *args, **kwargs)
  File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1182, in init
    self.finalize(func_args)
  File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 266, in finalize
    self._finalize(*args, **kwargs)
  File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1167, in _finalize
    self._jaxpr = self._jaxpr_extractor(func_args)
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 459, in get_processed_jaxpr
    closed_jaxpr, _ = retrieve(func_args)
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 425, in retrieve
    processed_jaxpr = ProcessedJaxpr.make_from_func(
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 314, in make_from_func
    func = tgm.auto_register_tags(
  File "./extern/kfac-jax/kfac_jax/_src/tag_graph_matcher.py", line 1614, in auto_register_tags
    graph = make_jax_graph(
  File "./extern/kfac-jax/kfac_jax/_src/tag_graph_matcher.py", line 336, in make_jax_graph
    closed_jaxpr, out_shapes = jax.make_jaxpr(func, return_shape=True)(*func_args)
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1633, in value_func
    out, _ = value_and_grad_func(*args, **kwargs)
  File "./scripts/kfac.py", line 25, in loss_fn
    preds = hk_model.apply(model_params, x)
  File "/opt/env/lib/python3.11/site-packages/haiku/_src/multi_transform.py", line 314, in apply_fn
    return f.apply(params, None, *args, **kwargs)
  File "/opt/env/lib/python3.11/site-packages/haiku/_src/transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/opt/env/lib/python3.11/site-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  File "./scripts/kfac.py", line 13, in model
    attended = attention.mha(k, k, k, None)
  File "/opt/env/lib/python3.11/site-packages/jax/experimental/pallas/ops/attention.py", line 287, in _mha_forward
    out, l, m = pl.pallas_call(
  File "/opt/env/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 589, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: `JaxprInputEffect` Read<7> does not have corresponding input: Var(id=125959648193344):float32[16,16].
 Equation: a:i32[] b:f32[16,16] c:f32[16] d:f32[16] e:f32[16,16] f:f32[16] g:f32[16] = scan[
  _split_transpose=False
  jaxpr={ lambda ; h:MemRef<None>{float32[16,16]} i:f32[16,16] j:MemRef<None>{float32[16,16]}
      k:MemRef<None>{float32[16,16]} l:f32[16,16] m:MemRef<None>{float32[16,16]}
      n:i32[] o:f32[16,16] p:f32[16] q:f32[16] r:f32[16,16] s:f32[16] t:f32[16]. let
      u:i32[] = add n 1
      v:i32[] = mul n 16
      w:f32[16,16] <- h[v:v+16,:]
      x:f32[16,16] <- k[v:v+16,:]
      y:f32[16,16] = transpose[permutation=(1, 0)] w
      z:f32[16,16] = transpose[permutation=(1, 0)] x
      ba:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] i y
      bb:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] l y
      bc:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] i z
      bd:f32[16,16] = add_any bb bc
      be:f32[16] = reduce_max[axes=(1,)] ba
      bf:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] be
      bg:bool[16,16] = eq ba bf
      bh:f32[16,16] = convert_element_type[new_dtype=float32 weak_type=False] bg
      bi:f32[16] = reduce_sum[axes=(1,)] bh
      bj:f32[16,16] = mul bd bh
      bk:f32[16] = reduce_sum[axes=(1,)] bj
      bl:f32[16] = div bk bi
      bm:f32[16] = max p be
      bn:bool[16] = eq p bm
      bo:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bp:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
      bq:f32[16] = select_n bn bp bo
      br:bool[16] = eq be bm
      bs:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
      bt:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bu:f32[16] = select_n br bt bs
      bv:f32[16] = div bq bu
      bw:f32[16] = mul s bv
      bx:bool[16] = eq be bm
      by:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bz:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
      ca:f32[16] = select_n bx bz by
      cb:bool[16] = eq p bm
      cc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
      cd:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      ce:f32[16] = select_n cb cd cc
      cf:f32[16] = div ca ce
      cg:f32[16] = mul bl cf
      ch:f32[16] = add_any bw cg
      ci:f32[16] = sub p bm
      cj:f32[16] = sub s ch
      ck:f32[16] = exp ci
      cl:f32[16] = mul cj ck
      cm:f32[16] = mul ck q
      cn:f32[16] = mul cl q
      co:f32[16] = mul ck t
      cp:f32[16] = add_any cn co
      cq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] bm
      cr:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] ch
      cs:f32[16,16] = sub ba cq
      ct:f32[16,16] = sub bd cr
      cu:f32[16,16] = exp cs
      cv:f32[16,16] = mul ct cu
      cw:f32[16] = reduce_sum[axes=(1,)] cu
      cx:f32[16] = reduce_sum[axes=(1,)] cv
      cy:f32[16] = add cm cw
      cz:f32[16] = add cp cx
      da:f32[16] = div 1.0 cy
      db:f32[16] = neg cz
      dc:f32[16] = mul db 1.0
      dd:f32[16] = integer_pow[y=-2] cy
      de:f32[16] = mul dc dd
      df:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] da
      dg:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] de
      dh:f32[16,16] = mul cu df
      di:f32[16,16] = mul cv df
      dj:f32[16,16] = mul cu dg
      dk:f32[16,16] = add_any di dj
      dl:f32[16] = mul cm da
      dm:f32[16] = mul cp da
      dn:f32[16] = mul cm de
      do:f32[16] = add_any dm dn
      dp:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] dl
      dq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] do
      dr:f32[16,16] = mul dp o
      ds:f32[16,16] = mul dq o
      dt:f32[16,16] = mul dp r
      du:f32[16,16] = add_any ds dt
      dv:f32[16,16] <- j[v:v+16,:]
      dw:f32[16,16] <- m[v:v+16,:]
      dx:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dh dv
      dy:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dk dv
      dz:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dh dw
      ea:f32[16,16] = add_any dy dz
      eb:f32[16,16] = add dr dx
      ec:f32[16,16] = add du ea
    in (u, eb, bm, cy, ec, ch, cz) }
  length=1
  linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
  num_carry=7
  num_consts=6
  reverse=False
  unroll=1
] ed ee ef eg eh ei ej ek el em en eo ep

 Jaxpr: { lambda a:f32[] b:f32[] c:f32[] d:f32[] e:i32[] f:f32[] g:f32[] h:f32[] i:i32[]; j:MemRef<None>{float32[16,16]}
    k:MemRef<None>{float32[16,16]} l:MemRef<None>{float32[16,16]} m:MemRef<None>{float32[16,16]}
    n:MemRef<None>{float32[16]} o:MemRef<None>{float32[16]} p:MemRef<None>{float32[16,16]}
    q:MemRef<None>{float32[16,16]} r:MemRef<None>{float32[16,16]} s:MemRef<None>{float32[16,16]}
    t:MemRef<None>{float32[16]} u:MemRef<None>{float32[16]}. let
    v:i32[] = program_id[axis=0] 
    w:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] a
    x:f32[16] = sub w b
    y:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] c
    z:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] d
    ba:i32[] = mul v e
    bb:f32[16,16] <- j[ba:ba+16,:]
    bc:f32[16,16] <- p[ba:ba+16,:]
    bd:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
    be:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] bd
    bf:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
    bg:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bf
    bh:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
    bi:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bh
    bj:i32[] bk:f32[16,16] bl:f32[16] bm:f32[16] bn:f32[16,16] bo:f32[16] bp:f32[16] = scan[
      _split_transpose=False
      jaxpr={ lambda ; bq:MemRef<None>{float32[16,16]} br:f32[16,16] bs:MemRef<None>{float32[16,16]}
          bt:MemRef<None>{float32[16,16]} bu:f32[16,16] bv:MemRef<None>{float32[16,16]}
          bw:i32[] bx:f32[16,16] by:f32[16] bz:f32[16] ca:f32[16,16] cb:f32[16] cc:f32[16]. let
          cd:i32[] = add bw 1
          ce:i32[] = mul bw 16
          cf:f32[16,16] <- bq[ce:ce+16,:]
          cg:f32[16,16] <- bt[ce:ce+16,:]
          ch:f32[16,16] = transpose[permutation=(1, 0)] cf
          ci:f32[16,16] = transpose[permutation=(1, 0)] cg
          cj:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] br ch
          ck:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] bu ch
          cl:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] br ci
          cm:f32[16,16] = add_any ck cl
          cn:f32[16] = reduce_max[axes=(1,)] cj
          co:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] cn
          cp:bool[16,16] = eq cj co
          cq:f32[16,16] = convert_element_type[
            new_dtype=float32
            weak_type=False
          ] cp
          cr:f32[16] = reduce_sum[axes=(1,)] cq
          cs:f32[16,16] = mul cm cq
          ct:f32[16] = reduce_sum[axes=(1,)] cs
          cu:f32[16] = div ct cr
          cv:f32[16] = max by cn
          cw:bool[16] = eq by cv
          cx:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          cy:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
          cz:f32[16] = select_n cw cy cx
          da:bool[16] = eq cn cv
          db:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
          dc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          dd:f32[16] = select_n da dc db
          de:f32[16] = div cz dd
          df:f32[16] = mul cb de
          dg:bool[16] = eq cn cv
          dh:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          di:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
          dj:f32[16] = select_n dg di dh
          dk:bool[16] = eq by cv
          dl:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
          dm:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          dn:f32[16] = select_n dk dm dl
          do:f32[16] = div dj dn
          dp:f32[16] = mul cu do
          dq:f32[16] = add_any df dp
          dr:f32[16] = sub by cv
          ds:f32[16] = sub cb dq
          dt:f32[16] = exp dr
          du:f32[16] = mul ds dt
          dv:f32[16] = mul dt bz
          dw:f32[16] = mul du bz
          dx:f32[16] = mul dt cc
          dy:f32[16] = add_any dw dx
          dz:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] cv
          ea:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] dq
          eb:f32[16,16] = sub cj dz
          ec:f32[16,16] = sub cm ea
          ed:f32[16,16] = exp eb
          ee:f32[16,16] = mul ec ed
          ef:f32[16] = reduce_sum[axes=(1,)] ed
          eg:f32[16] = reduce_sum[axes=(1,)] ee
          eh:f32[16] = add dv ef
          ei:f32[16] = add dy eg
          ej:f32[16] = div 1.0 eh
          ek:f32[16] = neg ei
          el:f32[16] = mul ek 1.0
          em:f32[16] = integer_pow[y=-2] eh
          en:f32[16] = mul el em
          eo:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] ej
          ep:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] en
          eq:f32[16,16] = mul ed eo
          er:f32[16,16] = mul ee eo
          es:f32[16,16] = mul ed ep
          et:f32[16,16] = add_any er es
          eu:f32[16] = mul dv ej
          ev:f32[16] = mul dy ej
          ew:f32[16] = mul dv en
          ex:f32[16] = add_any ev ew
          ey:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] eu
          ez:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] ex
          fa:f32[16,16] = mul ey bx
          fb:f32[16,16] = mul ez bx
          fc:f32[16,16] = mul ey ca
          fd:f32[16,16] = add_any fb fc
          fe:f32[16,16] <- bs[ce:ce+16,:]
          ff:f32[16,16] <- bv[ce:ce+16,:]
          fg:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] eq fe
          fh:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] et fe
          fi:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] eq ff
          fj:f32[16,16] = add_any fh fi
          fk:f32[16,16] = add fa fg
          fl:f32[16,16] = add fd fj
        in (cd, fk, cv, eh, fl, dq, ei) }
      length=1
      linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
      num_carry=7
      num_consts=6
      reverse=False
      unroll=1
    ] k bb l q bc r i z x y be bg bi
    fm:f32[16], n[ba:ba+16] <- n[ba:ba+16], bm
    fn:f32[16], t[ba:ba+16] <- t[ba:ba+16], bp
    fo:f32[16], o[ba:ba+16] <- o[ba:ba+16], bl
    fp:f32[16], u[ba:ba+16] <- u[ba:ba+16], bo
    fq:f32[16,16], m[ba:ba+16,:] <- m[ba:ba+16,:], bk
    fr:f32[16,16], s[ba:ba+16,:] <- s[ba:ba+16,:], bn
  in () }

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "./scripts/kfac.py", line 168, in <module>
    params, opt_state, stats = optimizer.step(
                               ^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1339, in step
    return self._step(params, state, rng, batch, func_state, learning_rate, momentum, damping)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/utils/staging.py", line 255, in decorated
    outs = jitted_func(instance, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
    return method(instance, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1130, in _step
    state = self._maybe_update_estimator_curvature(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 783, in _maybe_update_estimator_curvature
    return self._maybe_update_estimator_state(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 735, in _maybe_update_estimator_state
    state.estimator_state = lax.cond(
                            ^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 755, in _update_estimator_curvature
    state = self.estimator.update_curvature_matrix_estimate(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
    return method(instance, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1422, in update_curvature_matrix_estimate
    losses, losses_vjp = self._compute_losses_vjp(func_args)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
    return method(instance, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1106, in _compute_losses_vjp
    return self._vjp(func_args)
           ^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 456, in wrapped_transformation
    return f(func_args, *args)
           ^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 871, in _layer_tag_vjp
    _, aux_vjp, losses_inputs = jax.vjp(forward_aux, aux_dict, has_aux=True)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 833, in forward_aux
    write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, input_values))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/tag_graph_matcher.py", line 72, in eval_jaxpr_eqn
    output = eqn.primitive.bind(*subfuns, *in_values, **bind_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/env/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 257, in _pallas_call_jvp_rule
    jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: `JaxprInputEffect` Read<7> does not have corresponding input: Var(id=125959648193344):float32[16,16].
 Equation: a:i32[] b:f32[16,16] c:f32[16] d:f32[16] e:f32[16,16] f:f32[16] g:f32[16] = scan[
  _split_transpose=False
  jaxpr={ lambda ; h:MemRef<None>{float32[16,16]} i:f32[16,16] j:MemRef<None>{float32[16,16]}
      k:MemRef<None>{float32[16,16]} l:f32[16,16] m:MemRef<None>{float32[16,16]}
      n:i32[] o:f32[16,16] p:f32[16] q:f32[16] r:f32[16,16] s:f32[16] t:f32[16]. let
      u:i32[] = add n 1
      v:i32[] = mul n 16
      w:f32[16,16] <- h[v:v+16,:]
      x:f32[16,16] <- k[v:v+16,:]
      y:f32[16,16] = transpose[permutation=(1, 0)] w
      z:f32[16,16] = transpose[permutation=(1, 0)] x
      ba:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] i y
      bb:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] l y
      bc:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] i z
      bd:f32[16,16] = add_any bb bc
      be:f32[16] = reduce_max[axes=(1,)] ba
      bf:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] be
      bg:bool[16,16] = eq ba bf
      bh:f32[16,16] = convert_element_type[new_dtype=float32 weak_type=False] bg
      bi:f32[16] = reduce_sum[axes=(1,)] bh
      bj:f32[16,16] = mul bd bh
      bk:f32[16] = reduce_sum[axes=(1,)] bj
      bl:f32[16] = div bk bi
      bm:f32[16] = max p be
      bn:bool[16] = eq p bm
      bo:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bp:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
      bq:f32[16] = select_n bn bp bo
      br:bool[16] = eq be bm
      bs:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
      bt:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bu:f32[16] = select_n br bt bs
      bv:f32[16] = div bq bu
      bw:f32[16] = mul s bv
      bx:bool[16] = eq be bm
      by:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bz:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
      ca:f32[16] = select_n bx bz by
      cb:bool[16] = eq p bm
      cc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
      cd:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      ce:f32[16] = select_n cb cd cc
      cf:f32[16] = div ca ce
      cg:f32[16] = mul bl cf
      ch:f32[16] = add_any bw cg
      ci:f32[16] = sub p bm
      cj:f32[16] = sub s ch
      ck:f32[16] = exp ci
      cl:f32[16] = mul cj ck
      cm:f32[16] = mul ck q
      cn:f32[16] = mul cl q
      co:f32[16] = mul ck t
      cp:f32[16] = add_any cn co
      cq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] bm
      cr:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] ch
      cs:f32[16,16] = sub ba cq
      ct:f32[16,16] = sub bd cr
      cu:f32[16,16] = exp cs
      cv:f32[16,16] = mul ct cu
      cw:f32[16] = reduce_sum[axes=(1,)] cu
      cx:f32[16] = reduce_sum[axes=(1,)] cv
      cy:f32[16] = add cm cw
      cz:f32[16] = add cp cx
      da:f32[16] = div 1.0 cy
      db:f32[16] = neg cz
      dc:f32[16] = mul db 1.0
      dd:f32[16] = integer_pow[y=-2] cy
      de:f32[16] = mul dc dd
      df:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] da
      dg:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] de
      dh:f32[16,16] = mul cu df
      di:f32[16,16] = mul cv df
      dj:f32[16,16] = mul cu dg
      dk:f32[16,16] = add_any di dj
      dl:f32[16] = mul cm da
      dm:f32[16] = mul cp da
      dn:f32[16] = mul cm de
      do:f32[16] = add_any dm dn
      dp:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] dl
      dq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] do
      dr:f32[16,16] = mul dp o
      ds:f32[16,16] = mul dq o
      dt:f32[16,16] = mul dp r
      du:f32[16,16] = add_any ds dt
      dv:f32[16,16] <- j[v:v+16,:]
      dw:f32[16,16] <- m[v:v+16,:]
      dx:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dh dv
      dy:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dk dv
      dz:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dh dw
      ea:f32[16,16] = add_any dy dz
      eb:f32[16,16] = add dr dx
      ec:f32[16,16] = add du ea
    in (u, eb, bm, cy, ec, ch, cz) }
  length=1
  linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
  num_carry=7
  num_consts=6
  reverse=False
  unroll=1
] ed ee ef eg eh ei ej ek el em en eo ep

 Jaxpr: { lambda a:f32[] b:f32[] c:f32[] d:f32[] e:i32[] f:f32[] g:f32[] h:f32[] i:i32[]; j:MemRef<None>{float32[16,16]}
    k:MemRef<None>{float32[16,16]} l:MemRef<None>{float32[16,16]} m:MemRef<None>{float32[16,16]}
    n:MemRef<None>{float32[16]} o:MemRef<None>{float32[16]} p:MemRef<None>{float32[16,16]}
    q:MemRef<None>{float32[16,16]} r:MemRef<None>{float32[16,16]} s:MemRef<None>{float32[16,16]}
    t:MemRef<None>{float32[16]} u:MemRef<None>{float32[16]}. let
    v:i32[] = program_id[axis=0] 
    w:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] a
    x:f32[16] = sub w b
    y:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] c
    z:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] d
    ba:i32[] = mul v e
    bb:f32[16,16] <- j[ba:ba+16,:]
    bc:f32[16,16] <- p[ba:ba+16,:]
    bd:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
    be:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] bd
    bf:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
    bg:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bf
    bh:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
    bi:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bh
    bj:i32[] bk:f32[16,16] bl:f32[16] bm:f32[16] bn:f32[16,16] bo:f32[16] bp:f32[16] = scan[
      _split_transpose=False
      jaxpr={ lambda ; bq:MemRef<None>{float32[16,16]} br:f32[16,16] bs:MemRef<None>{float32[16,16]}
          bt:MemRef<None>{float32[16,16]} bu:f32[16,16] bv:MemRef<None>{float32[16,16]}
          bw:i32[] bx:f32[16,16] by:f32[16] bz:f32[16] ca:f32[16,16] cb:f32[16] cc:f32[16]. let
          cd:i32[] = add bw 1
          ce:i32[] = mul bw 16
          cf:f32[16,16] <- bq[ce:ce+16,:]
          cg:f32[16,16] <- bt[ce:ce+16,:]
          ch:f32[16,16] = transpose[permutation=(1, 0)] cf
          ci:f32[16,16] = transpose[permutation=(1, 0)] cg
          cj:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] br ch
          ck:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] bu ch
          cl:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] br ci
          cm:f32[16,16] = add_any ck cl
          cn:f32[16] = reduce_max[axes=(1,)] cj
          co:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] cn
          cp:bool[16,16] = eq cj co
          cq:f32[16,16] = convert_element_type[
            new_dtype=float32
            weak_type=False
          ] cp
          cr:f32[16] = reduce_sum[axes=(1,)] cq
          cs:f32[16,16] = mul cm cq
          ct:f32[16] = reduce_sum[axes=(1,)] cs
          cu:f32[16] = div ct cr
          cv:f32[16] = max by cn
          cw:bool[16] = eq by cv
          cx:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          cy:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
          cz:f32[16] = select_n cw cy cx
          da:bool[16] = eq cn cv
          db:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
          dc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          dd:f32[16] = select_n da dc db
          de:f32[16] = div cz dd
          df:f32[16] = mul cb de
          dg:bool[16] = eq cn cv
          dh:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          di:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
          dj:f32[16] = select_n dg di dh
          dk:bool[16] = eq by cv
          dl:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
          dm:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          dn:f32[16] = select_n dk dm dl
          do:f32[16] = div dj dn
          dp:f32[16] = mul cu do
          dq:f32[16] = add_any df dp
          dr:f32[16] = sub by cv
          ds:f32[16] = sub cb dq
          dt:f32[16] = exp dr
          du:f32[16] = mul ds dt
          dv:f32[16] = mul dt bz
          dw:f32[16] = mul du bz
          dx:f32[16] = mul dt cc
          dy:f32[16] = add_any dw dx
          dz:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] cv
          ea:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] dq
          eb:f32[16,16] = sub cj dz
          ec:f32[16,16] = sub cm ea
          ed:f32[16,16] = exp eb
          ee:f32[16,16] = mul ec ed
          ef:f32[16] = reduce_sum[axes=(1,)] ed
          eg:f32[16] = reduce_sum[axes=(1,)] ee
          eh:f32[16] = add dv ef
          ei:f32[16] = add dy eg
          ej:f32[16] = div 1.0 eh
          ek:f32[16] = neg ei
          el:f32[16] = mul ek 1.0
          em:f32[16] = integer_pow[y=-2] eh
          en:f32[16] = mul el em
          eo:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] ej
          ep:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] en
          eq:f32[16,16] = mul ed eo
          er:f32[16,16] = mul ee eo
          es:f32[16,16] = mul ed ep
          et:f32[16,16] = add_any er es
          eu:f32[16] = mul dv ej
          ev:f32[16] = mul dy ej
          ew:f32[16] = mul dv en
          ex:f32[16] = add_any ev ew
          ey:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] eu
          ez:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] ex
          fa:f32[16,16] = mul ey bx
          fb:f32[16,16] = mul ez bx
          fc:f32[16,16] = mul ey ca
          fd:f32[16,16] = add_any fb fc
          fe:f32[16,16] <- bs[ce:ce+16,:]
          ff:f32[16,16] <- bv[ce:ce+16,:]
          fg:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] eq fe
          fh:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] et fe
          fi:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] eq ff
          fj:f32[16,16] = add_any fh fi
          fk:f32[16,16] = add fa fg
          fl:f32[16,16] = add fd fj
        in (cd, fk, cv, eh, fl, dq, ei) }
      length=1
      linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
      num_carry=7
      num_consts=6
      reverse=False
      unroll=1
    ] k bb l q bc r i z x y be bg bi
    fm:f32[16], n[ba:ba+16] <- n[ba:ba+16], bm
    fn:f32[16], t[ba:ba+16] <- t[ba:ba+16], bp
    fo:f32[16], o[ba:ba+16] <- o[ba:ba+16], bl
    fp:f32[16], u[ba:ba+16] <- u[ba:ba+16], bo
    fq:f32[16,16], m[ba:ba+16,:] <- m[ba:ba+16,:], bk
    fr:f32[16,16], s[ba:ba+16,:] <- s[ba:ba+16,:], bn
  in () }
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
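
The last line of the log mentions JAX_TRACEBACK_FILTERING. A minimal sketch of one way to turn filtering off from Python, assuming the variable is read when JAX is imported and so has to be set beforehand (setting it in the shell before launching the script should be equivalent):

```python
import os

# Set this before `import jax`; JAX is assumed to read it at import time.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

# import jax  # import JAX only after the variable is set, then rerun the script
```

With filtering off, the traceback should also include JAX's internal frames between the kfac-jax frames and the Pallas JVP rule.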

My own attempts to investigate

We can see that the issue arises when running vjp on the manipulated version of the model. To diagnose it, I wrote a minimal re-implementation of this part of the KFAC algorithm:

# With the same model, etc from scripts/kfac.py above

import functools

primal_func_args = [params, (x, y)]

def read_env(
    env,
    variables,
):
    """Reads from the variable-to-array environment during tracing."""
    result = []
    assert isinstance(variables, list)
    for v in variables:
        if isinstance(v, jax.core.Literal):
            # Literals are values baked into the Jaxpr
            result.append(v.val)
        else:
            result.append(env[v])
    return result

def write_env(
    env,
    variables,
    values,
) -> None:
    """Writes to the variable-to-array environment during tracing."""
    assert len(variables) == len(values)
    # Use a distinct loop name so the `variables` argument is not shadowed.
    for var, val in zip(variables, values):
        env[var] = val

def eval_jaxpr_eqn(eqn, in_values):
    """Computes the outputs of the given Jaxpr equation."""

    subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)

    output = eqn.primitive.bind(*subfuns, *in_values, **bind_params)

    if not isinstance(output, list):
        return [output]
    else:
        return output

processed_jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(params, (x, y))

layer_input_vars = list(u for eqn in processed_jaxpr.eqns for u in eqn.invars)

def forward():
    """Computes the values of all inputs to all **layer** tags."""

    own_func_args = primal_func_args

    # Mapping from variable -> value
    env = {}
    read = functools.partial(read_env, env)
    write = functools.partial(write_env, env)

    # Bind args and consts to environment
    write(processed_jaxpr.jaxpr.invars, jax.tree_util.tree_leaves(own_func_args))
    write(processed_jaxpr.jaxpr.constvars, processed_jaxpr.consts)

    # Loop through equations and evaluate them
    for eqn in processed_jaxpr.jaxpr.eqns:

        write(eqn.outvars, eval_jaxpr_eqn(eqn, read(eqn.invars)))

    return tuple(read(layer_input_vars))

input_values = forward()

def forward_aux(aux):

    own_func_args = primal_func_args

    # Mapping from variable -> value
    env = {}
    read = functools.partial(read_env, env)

    def write(variables: list[jax.core.Var], values) -> None:
        # if not isinstance(variables, list):
        #   variables = [variables]
        write_env(env, variables, values)

        for v in variables:
            if not isinstance(v, jax.core.Literal) and v in aux:
                env[v] = env[v] + aux[v]

    # Bind args and consts to environment
    write(processed_jaxpr.jaxpr.invars, jax.tree_util.tree_leaves(own_func_args))

    write(processed_jaxpr.jaxpr.constvars, processed_jaxpr.consts)

    # Loop through equations and evaluate primitives using `bind`
    losses_p_dependants = []
    losses_inputs_values = []

    for eqn in processed_jaxpr.jaxpr.eqns:

        input_values = read(eqn.invars)
        out = eval_jaxpr_eqn(eqn, input_values)
        write(eqn.outvars, out)

        losses_inputs_values.append(tuple(input_values))

    return tuple(losses_p_dependants), tuple(losses_inputs_values)

aux_dict = jax.tree_util.tree_map(jnp.zeros_like, input_values)
my_outputs, my_outputs_aux = forward_aux(aux_dict)
_, aux_vjp, losses_inputs = jax.vjp(forward_aux, aux_dict, has_aux=True)
print("It worked.")

This runs without error. The re-implementation is based on my reading of the kfac-jax codebase and may miss something important. In particular, it omits the loss and layer tagging; I had hoped that wasn't relevant.
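
For comparison, here is a self-contained toy version of the same idea with no kfac-jax and no Pallas: re-evaluate a jaxpr one equation at a time with `primitive.bind`, add a zero-valued perturbation to the input, and differentiate with respect to that perturbation. The names `toy_fn` and `interpret` are hypothetical, and detecting literals via a `val` attribute is an assumption made to avoid depending on where `Literal` lives in a given JAX version. It is only a sanity check of the interpreter pattern, not a reproduction of the failure.

```python
import jax
import jax.numpy as jnp


def toy_fn(x):
    # Stand-in for the model: plain elementwise primitives, no Pallas kernel.
    return jnp.sum(jnp.sin(x) * 2.0 + x)


closed = jax.make_jaxpr(toy_fn)(jnp.ones(3))


def interpret(x, perturbation):
    """Re-evaluates the jaxpr equation by equation via `primitive.bind`."""
    env = {}

    def read(v):
        # Assumption: literals expose their value as `.val`, variables do not.
        return v.val if hasattr(v, "val") else env[v]

    for var, val in zip(closed.jaxpr.constvars, closed.consts):
        env[var] = val
    # Add the (zero) perturbation to the input, as forward_aux does with aux.
    env[closed.jaxpr.invars[0]] = x + perturbation

    for eqn in closed.jaxpr.eqns:
        subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
        in_values = [read(v) for v in eqn.invars]
        out = eqn.primitive.bind(*subfuns, *in_values, **bind_params)
        outs = out if eqn.primitive.multiple_results else [out]
        for var, val in zip(eqn.outvars, outs):
            env[var] = val

    return read(closed.jaxpr.outvars[0])


x = jnp.arange(3.0)
value, vjp_fn = jax.vjp(lambda p: interpret(x, p), jnp.zeros_like(x))
(grad_from_interpreter,) = vjp_fn(jnp.ones_like(value))
grad_reference = jax.grad(toy_fn)(x)
```

If the pattern is sound, `grad_from_interpreter` should agree with `grad_reference`. Unlike the full kfac-jax tracer, this sketch has no tagging step and never binds a `pallas_call` equation.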

Request for assistance

I would really appreciate your advice on this. Specifically:

  1. What is the root cause of the current failure?
  2. Would it be trivial to fix kfac-jax to work with Pallas attention? I would be happy to help work on a fix, given guidance on where to look.
  3. If it is not trivial, would it be possible to hack kfac-jax to work specifically with the attention operation, assuming that this operation contains no layer tags and no loss tags?

Unpack Error when using KFAC with block-diagonal for Dense networks

Hi,

I was trying to get the example code in the README working with the block-diagonal approximation; the default simply uses the plain diagonal one. However, when I try to define my optimizer like this:

opt = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(partial(expected_model_likelihood, l2=0.001)),
    l2_reg=0.001,
    use_adaptive_learning_rate=True,
    use_adaptive_damping=True,
    use_adaptive_momentum=True,
    initial_damping=1.0,
    min_damping=0.0001,
    layer_tag_to_block_ctor={'generic_tag': kfac_jax.DenseTwoKroneckerFactored},  # Specify the approximation type here
    estimation_mode='ggn_curvature_prop',
    multi_device=False
)

then when I try to use this optimizer I get the following ValueError:

del pmap_axis_name
x, = estimation_data["inputs"]
dy, = estimation_data["outputs_tangent"]
assert utils.first_dim_is_size(batch_size, x, dy)

ValueError: not enough values to unpack (expected 1, got 0)

This corresponds to the curvature update method of the class DenseTwoKroneckerFactored (line 1165 of _src/curvature_blocks.py). The estimation data dictionary is filled with the parameters and their tangents, but I do not understand the codebase well enough to see why the inputs and outputs_tangent keys are not filled.
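
The error message itself can be reproduced in plain Python. The sketch below assumes that the two keys hold empty tuples, which is what "expected 1, got 0" suggests. The dictionary is a hypothetical stand-in, not the real kfac-jax structure.

```python
# Hypothetical stand-in for the block's estimation data: the keys exist, but
# nothing was recorded under them.
estimation_data = {
    "inputs": (),
    "outputs_tangent": (),
}

try:
    # Single-element unpacking requires exactly one item in the sequence.
    x, = estimation_data["inputs"]
except ValueError as err:
    message = str(err)

print(message)  # not enough values to unpack (expected 1, got 0)
```

So the unpacking fails because no inputs were recorded for the block, not because of the shapes of the recorded arrays.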

As a result, I cannot get the actual K-FAC approximation of this repo working. Are there any examples that use DenseTwoKroneckerFactored? As far as I can tell, all provided examples use the diagonal Fisher for optimization rather than K-FAC, but I may be wrong.
