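First, thank you for creating and releasing this invaluable resource. I am running into the error below when using the kfac-jax optimizer with a Haiku model that calls the Pallas attention kernel (`jax.experimental.pallas.ops.attention.mha`). The first traceback shows where the failing operation was originally traced (during `optimizer.init`); the second shows the exception itself, raised in `optimizer.step`: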
Traceback (most recent call last):
File "./scripts/kfac.py", line 167, in <module>
opt_state = optimizer.init(params, rng_opt, (x, y))
File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1023, in init
return self._init(params, rng, batch, func_state)
File "./extern/kfac-jax/kfac_jax/_src/utils/staging.py", line 255, in decorated
outs = jitted_func(instance, *args)
File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 988, in _init
estimator_state=self.estimator.init(
File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
return method(instance, *args, **kwargs)
File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1182, in init
self.finalize(func_args)
File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 266, in finalize
self._finalize(*args, **kwargs)
File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1167, in _finalize
self._jaxpr = self._jaxpr_extractor(func_args)
File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 459, in get_processed_jaxpr
closed_jaxpr, _ = retrieve(func_args)
File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 425, in retrieve
processed_jaxpr = ProcessedJaxpr.make_from_func(
File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 314, in make_from_func
func = tgm.auto_register_tags(
File "./extern/kfac-jax/kfac_jax/_src/tag_graph_matcher.py", line 1614, in auto_register_tags
graph = make_jax_graph(
File "./extern/kfac-jax/kfac_jax/_src/tag_graph_matcher.py", line 336, in make_jax_graph
closed_jaxpr, out_shapes = jax.make_jaxpr(func, return_shape=True)(*func_args)
File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1633, in value_func
out, _ = value_and_grad_func(*args, **kwargs)
File "./scripts/kfac.py", line 25, in loss_fn
preds = hk_model.apply(model_params, x)
File "/opt/env/lib/python3.11/site-packages/haiku/_src/multi_transform.py", line 314, in apply_fn
return f.apply(params, None, *args, **kwargs)
File "/opt/env/lib/python3.11/site-packages/haiku/_src/transform.py", line 183, in apply_fn
out, state = f.apply(params, None, *args, **kwargs)
File "/opt/env/lib/python3.11/site-packages/haiku/_src/transform.py", line 456, in apply_fn
out = f(*args, **kwargs)
File "./scripts/kfac.py", line 13, in model
attended = attention.mha(k, k, k, None)
File "/opt/env/lib/python3.11/site-packages/jax/experimental/pallas/ops/attention.py", line 287, in _mha_forward
out, l, m = pl.pallas_call(
File "/opt/env/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 589, in wrapped
out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: `JaxprInputEffect` Read<7> does not have corresponding input: Var(id=125959648193344):float32[16,16].
Equation: a:i32[] b:f32[16,16] c:f32[16] d:f32[16] e:f32[16,16] f:f32[16] g:f32[16] = scan[
_split_transpose=False
jaxpr={ lambda ; h:MemRef<None>{float32[16,16]} i:f32[16,16] j:MemRef<None>{float32[16,16]}
k:MemRef<None>{float32[16,16]} l:f32[16,16] m:MemRef<None>{float32[16,16]}
n:i32[] o:f32[16,16] p:f32[16] q:f32[16] r:f32[16,16] s:f32[16] t:f32[16]. let
u:i32[] = add n 1
v:i32[] = mul n 16
w:f32[16,16] <- h[v:v+16,:]
x:f32[16,16] <- k[v:v+16,:]
y:f32[16,16] = transpose[permutation=(1, 0)] w
z:f32[16,16] = transpose[permutation=(1, 0)] x
ba:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] i y
bb:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] l y
bc:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] i z
bd:f32[16,16] = add_any bb bc
be:f32[16] = reduce_max[axes=(1,)] ba
bf:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] be
bg:bool[16,16] = eq ba bf
bh:f32[16,16] = convert_element_type[new_dtype=float32 weak_type=False] bg
bi:f32[16] = reduce_sum[axes=(1,)] bh
bj:f32[16,16] = mul bd bh
bk:f32[16] = reduce_sum[axes=(1,)] bj
bl:f32[16] = div bk bi
bm:f32[16] = max p be
bn:bool[16] = eq p bm
bo:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
bp:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
bq:f32[16] = select_n bn bp bo
br:bool[16] = eq be bm
bs:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
bt:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
bu:f32[16] = select_n br bt bs
bv:f32[16] = div bq bu
bw:f32[16] = mul s bv
bx:bool[16] = eq be bm
by:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
bz:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
ca:f32[16] = select_n bx bz by
cb:bool[16] = eq p bm
cc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
cd:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
ce:f32[16] = select_n cb cd cc
cf:f32[16] = div ca ce
cg:f32[16] = mul bl cf
ch:f32[16] = add_any bw cg
ci:f32[16] = sub p bm
cj:f32[16] = sub s ch
ck:f32[16] = exp ci
cl:f32[16] = mul cj ck
cm:f32[16] = mul ck q
cn:f32[16] = mul cl q
co:f32[16] = mul ck t
cp:f32[16] = add_any cn co
cq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] bm
cr:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] ch
cs:f32[16,16] = sub ba cq
ct:f32[16,16] = sub bd cr
cu:f32[16,16] = exp cs
cv:f32[16,16] = mul ct cu
cw:f32[16] = reduce_sum[axes=(1,)] cu
cx:f32[16] = reduce_sum[axes=(1,)] cv
cy:f32[16] = add cm cw
cz:f32[16] = add cp cx
da:f32[16] = div 1.0 cy
db:f32[16] = neg cz
dc:f32[16] = mul db 1.0
dd:f32[16] = integer_pow[y=-2] cy
de:f32[16] = mul dc dd
df:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] da
dg:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] de
dh:f32[16,16] = mul cu df
di:f32[16,16] = mul cv df
dj:f32[16,16] = mul cu dg
dk:f32[16,16] = add_any di dj
dl:f32[16] = mul cm da
dm:f32[16] = mul cp da
dn:f32[16] = mul cm de
do:f32[16] = add_any dm dn
dp:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] dl
dq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] do
dr:f32[16,16] = mul dp o
ds:f32[16,16] = mul dq o
dt:f32[16,16] = mul dp r
du:f32[16,16] = add_any ds dt
dv:f32[16,16] <- j[v:v+16,:]
dw:f32[16,16] <- m[v:v+16,:]
dx:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] dh dv
dy:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] dk dv
dz:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] dh dw
ea:f32[16,16] = add_any dy dz
eb:f32[16,16] = add dr dx
ec:f32[16,16] = add du ea
in (u, eb, bm, cy, ec, ch, cz) }
length=1
linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
num_carry=7
num_consts=6
reverse=False
unroll=1
] ed ee ef eg eh ei ej ek el em en eo ep
Jaxpr: { lambda a:f32[] b:f32[] c:f32[] d:f32[] e:i32[] f:f32[] g:f32[] h:f32[] i:i32[]; j:MemRef<None>{float32[16,16]}
k:MemRef<None>{float32[16,16]} l:MemRef<None>{float32[16,16]} m:MemRef<None>{float32[16,16]}
n:MemRef<None>{float32[16]} o:MemRef<None>{float32[16]} p:MemRef<None>{float32[16,16]}
q:MemRef<None>{float32[16,16]} r:MemRef<None>{float32[16,16]} s:MemRef<None>{float32[16,16]}
t:MemRef<None>{float32[16]} u:MemRef<None>{float32[16]}. let
v:i32[] = program_id[axis=0]
w:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] a
x:f32[16] = sub w b
y:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] c
z:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] d
ba:i32[] = mul v e
bb:f32[16,16] <- j[ba:ba+16,:]
bc:f32[16,16] <- p[ba:ba+16,:]
bd:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
be:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] bd
bf:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
bg:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bf
bh:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
bi:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bh
bj:i32[] bk:f32[16,16] bl:f32[16] bm:f32[16] bn:f32[16,16] bo:f32[16] bp:f32[16] = scan[
_split_transpose=False
jaxpr={ lambda ; bq:MemRef<None>{float32[16,16]} br:f32[16,16] bs:MemRef<None>{float32[16,16]}
bt:MemRef<None>{float32[16,16]} bu:f32[16,16] bv:MemRef<None>{float32[16,16]}
bw:i32[] bx:f32[16,16] by:f32[16] bz:f32[16] ca:f32[16,16] cb:f32[16] cc:f32[16]. let
cd:i32[] = add bw 1
ce:i32[] = mul bw 16
cf:f32[16,16] <- bq[ce:ce+16,:]
cg:f32[16,16] <- bt[ce:ce+16,:]
ch:f32[16,16] = transpose[permutation=(1, 0)] cf
ci:f32[16,16] = transpose[permutation=(1, 0)] cg
cj:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] br ch
ck:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] bu ch
cl:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] br ci
cm:f32[16,16] = add_any ck cl
cn:f32[16] = reduce_max[axes=(1,)] cj
co:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] cn
cp:bool[16,16] = eq cj co
cq:f32[16,16] = convert_element_type[
new_dtype=float32
weak_type=False
] cp
cr:f32[16] = reduce_sum[axes=(1,)] cq
cs:f32[16,16] = mul cm cq
ct:f32[16] = reduce_sum[axes=(1,)] cs
cu:f32[16] = div ct cr
cv:f32[16] = max by cn
cw:bool[16] = eq by cv
cx:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
cy:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
cz:f32[16] = select_n cw cy cx
da:bool[16] = eq cn cv
db:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
dc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
dd:f32[16] = select_n da dc db
de:f32[16] = div cz dd
df:f32[16] = mul cb de
dg:bool[16] = eq cn cv
dh:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
di:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
dj:f32[16] = select_n dg di dh
dk:bool[16] = eq by cv
dl:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
dm:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
dn:f32[16] = select_n dk dm dl
do:f32[16] = div dj dn
dp:f32[16] = mul cu do
dq:f32[16] = add_any df dp
dr:f32[16] = sub by cv
ds:f32[16] = sub cb dq
dt:f32[16] = exp dr
du:f32[16] = mul ds dt
dv:f32[16] = mul dt bz
dw:f32[16] = mul du bz
dx:f32[16] = mul dt cc
dy:f32[16] = add_any dw dx
dz:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] cv
ea:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] dq
eb:f32[16,16] = sub cj dz
ec:f32[16,16] = sub cm ea
ed:f32[16,16] = exp eb
ee:f32[16,16] = mul ec ed
ef:f32[16] = reduce_sum[axes=(1,)] ed
eg:f32[16] = reduce_sum[axes=(1,)] ee
eh:f32[16] = add dv ef
ei:f32[16] = add dy eg
ej:f32[16] = div 1.0 eh
ek:f32[16] = neg ei
el:f32[16] = mul ek 1.0
em:f32[16] = integer_pow[y=-2] eh
en:f32[16] = mul el em
eo:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] ej
ep:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] en
eq:f32[16,16] = mul ed eo
er:f32[16,16] = mul ee eo
es:f32[16,16] = mul ed ep
et:f32[16,16] = add_any er es
eu:f32[16] = mul dv ej
ev:f32[16] = mul dy ej
ew:f32[16] = mul dv en
ex:f32[16] = add_any ev ew
ey:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] eu
ez:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] ex
fa:f32[16,16] = mul ey bx
fb:f32[16,16] = mul ez bx
fc:f32[16,16] = mul ey ca
fd:f32[16,16] = add_any fb fc
fe:f32[16,16] <- bs[ce:ce+16,:]
ff:f32[16,16] <- bv[ce:ce+16,:]
fg:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] eq fe
fh:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] et fe
fi:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] eq ff
fj:f32[16,16] = add_any fh fi
fk:f32[16,16] = add fa fg
fl:f32[16,16] = add fd fj
in (cd, fk, cv, eh, fl, dq, ei) }
length=1
linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
num_carry=7
num_consts=6
reverse=False
unroll=1
] k bb l q bc r i z x y be bg bi
fm:f32[16], n[ba:ba+16] <- n[ba:ba+16], bm
fn:f32[16], t[ba:ba+16] <- t[ba:ba+16], bp
fo:f32[16], o[ba:ba+16] <- o[ba:ba+16], bl
fp:f32[16], u[ba:ba+16] <- u[ba:ba+16], bo
fq:f32[16,16], m[ba:ba+16,:] <- m[ba:ba+16,:], bk
fr:f32[16,16], s[ba:ba+16,:] <- s[ba:ba+16,:], bn
in () }
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "./scripts/kfac.py", line 168, in <module>
params, opt_state, stats = optimizer.step(
^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1339, in step
return self._step(params, state, rng, batch, func_state, learning_rate, momentum, damping)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/utils/staging.py", line 255, in decorated
outs = jitted_func(instance, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
return method(instance, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1130, in _step
state = self._maybe_update_estimator_curvature(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 783, in _maybe_update_estimator_curvature
return self._maybe_update_estimator_state(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 735, in _maybe_update_estimator_state
state.estimator_state = lax.cond(
^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 755, in _update_estimator_curvature
state = self.estimator.update_curvature_matrix_estimate(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
return method(instance, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1422, in update_curvature_matrix_estimate
losses, losses_vjp = self._compute_losses_vjp(func_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
return method(instance, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1106, in _compute_losses_vjp
return self._vjp(func_args)
^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 456, in wrapped_transformation
return f(func_args, *args)
^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 871, in _layer_tag_vjp
_, aux_vjp, losses_inputs = jax.vjp(forward_aux, aux_dict, has_aux=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 833, in forward_aux
write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, input_values))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./extern/kfac-jax/kfac_jax/_src/tag_graph_matcher.py", line 72, in eval_jaxpr_eqn
output = eqn.primitive.bind(*subfuns, *in_values, **bind_params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/env/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 257, in _pallas_call_jvp_rule
jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: `JaxprInputEffect` Read<7> does not have corresponding input: Var(id=125959648193344):float32[16,16].
Equation: a:i32[] b:f32[16,16] c:f32[16] d:f32[16] e:f32[16,16] f:f32[16] g:f32[16] = scan[
_split_transpose=False
jaxpr={ lambda ; h:MemRef<None>{float32[16,16]} i:f32[16,16] j:MemRef<None>{float32[16,16]}
k:MemRef<None>{float32[16,16]} l:f32[16,16] m:MemRef<None>{float32[16,16]}
n:i32[] o:f32[16,16] p:f32[16] q:f32[16] r:f32[16,16] s:f32[16] t:f32[16]. let
u:i32[] = add n 1
v:i32[] = mul n 16
w:f32[16,16] <- h[v:v+16,:]
x:f32[16,16] <- k[v:v+16,:]
y:f32[16,16] = transpose[permutation=(1, 0)] w
z:f32[16,16] = transpose[permutation=(1, 0)] x
ba:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] i y
bb:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] l y
bc:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] i z
bd:f32[16,16] = add_any bb bc
be:f32[16] = reduce_max[axes=(1,)] ba
bf:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] be
bg:bool[16,16] = eq ba bf
bh:f32[16,16] = convert_element_type[new_dtype=float32 weak_type=False] bg
bi:f32[16] = reduce_sum[axes=(1,)] bh
bj:f32[16,16] = mul bd bh
bk:f32[16] = reduce_sum[axes=(1,)] bj
bl:f32[16] = div bk bi
bm:f32[16] = max p be
bn:bool[16] = eq p bm
bo:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
bp:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
bq:f32[16] = select_n bn bp bo
br:bool[16] = eq be bm
bs:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
bt:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
bu:f32[16] = select_n br bt bs
bv:f32[16] = div bq bu
bw:f32[16] = mul s bv
bx:bool[16] = eq be bm
by:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
bz:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
ca:f32[16] = select_n bx bz by
cb:bool[16] = eq p bm
cc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
cd:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
ce:f32[16] = select_n cb cd cc
cf:f32[16] = div ca ce
cg:f32[16] = mul bl cf
ch:f32[16] = add_any bw cg
ci:f32[16] = sub p bm
cj:f32[16] = sub s ch
ck:f32[16] = exp ci
cl:f32[16] = mul cj ck
cm:f32[16] = mul ck q
cn:f32[16] = mul cl q
co:f32[16] = mul ck t
cp:f32[16] = add_any cn co
cq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] bm
cr:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] ch
cs:f32[16,16] = sub ba cq
ct:f32[16,16] = sub bd cr
cu:f32[16,16] = exp cs
cv:f32[16,16] = mul ct cu
cw:f32[16] = reduce_sum[axes=(1,)] cu
cx:f32[16] = reduce_sum[axes=(1,)] cv
cy:f32[16] = add cm cw
cz:f32[16] = add cp cx
da:f32[16] = div 1.0 cy
db:f32[16] = neg cz
dc:f32[16] = mul db 1.0
dd:f32[16] = integer_pow[y=-2] cy
de:f32[16] = mul dc dd
df:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] da
dg:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] de
dh:f32[16,16] = mul cu df
di:f32[16,16] = mul cv df
dj:f32[16,16] = mul cu dg
dk:f32[16,16] = add_any di dj
dl:f32[16] = mul cm da
dm:f32[16] = mul cp da
dn:f32[16] = mul cm de
do:f32[16] = add_any dm dn
dp:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] dl
dq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] do
dr:f32[16,16] = mul dp o
ds:f32[16,16] = mul dq o
dt:f32[16,16] = mul dp r
du:f32[16,16] = add_any ds dt
dv:f32[16,16] <- j[v:v+16,:]
dw:f32[16,16] <- m[v:v+16,:]
dx:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] dh dv
dy:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] dk dv
dz:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] dh dw
ea:f32[16,16] = add_any dy dz
eb:f32[16,16] = add dr dx
ec:f32[16,16] = add du ea
in (u, eb, bm, cy, ec, ch, cz) }
length=1
linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
num_carry=7
num_consts=6
reverse=False
unroll=1
] ed ee ef eg eh ei ej ek el em en eo ep
Jaxpr: { lambda a:f32[] b:f32[] c:f32[] d:f32[] e:i32[] f:f32[] g:f32[] h:f32[] i:i32[]; j:MemRef<None>{float32[16,16]}
k:MemRef<None>{float32[16,16]} l:MemRef<None>{float32[16,16]} m:MemRef<None>{float32[16,16]}
n:MemRef<None>{float32[16]} o:MemRef<None>{float32[16]} p:MemRef<None>{float32[16,16]}
q:MemRef<None>{float32[16,16]} r:MemRef<None>{float32[16,16]} s:MemRef<None>{float32[16,16]}
t:MemRef<None>{float32[16]} u:MemRef<None>{float32[16]}. let
v:i32[] = program_id[axis=0]
w:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] a
x:f32[16] = sub w b
y:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] c
z:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] d
ba:i32[] = mul v e
bb:f32[16,16] <- j[ba:ba+16,:]
bc:f32[16,16] <- p[ba:ba+16,:]
bd:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
be:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] bd
bf:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
bg:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bf
bh:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
bi:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bh
bj:i32[] bk:f32[16,16] bl:f32[16] bm:f32[16] bn:f32[16,16] bo:f32[16] bp:f32[16] = scan[
_split_transpose=False
jaxpr={ lambda ; bq:MemRef<None>{float32[16,16]} br:f32[16,16] bs:MemRef<None>{float32[16,16]}
bt:MemRef<None>{float32[16,16]} bu:f32[16,16] bv:MemRef<None>{float32[16,16]}
bw:i32[] bx:f32[16,16] by:f32[16] bz:f32[16] ca:f32[16,16] cb:f32[16] cc:f32[16]. let
cd:i32[] = add bw 1
ce:i32[] = mul bw 16
cf:f32[16,16] <- bq[ce:ce+16,:]
cg:f32[16,16] <- bt[ce:ce+16,:]
ch:f32[16,16] = transpose[permutation=(1, 0)] cf
ci:f32[16,16] = transpose[permutation=(1, 0)] cg
cj:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] br ch
ck:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] bu ch
cl:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] br ci
cm:f32[16,16] = add_any ck cl
cn:f32[16] = reduce_max[axes=(1,)] cj
co:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] cn
cp:bool[16,16] = eq cj co
cq:f32[16,16] = convert_element_type[
new_dtype=float32
weak_type=False
] cp
cr:f32[16] = reduce_sum[axes=(1,)] cq
cs:f32[16,16] = mul cm cq
ct:f32[16] = reduce_sum[axes=(1,)] cs
cu:f32[16] = div ct cr
cv:f32[16] = max by cn
cw:bool[16] = eq by cv
cx:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
cy:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
cz:f32[16] = select_n cw cy cx
da:bool[16] = eq cn cv
db:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
dc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
dd:f32[16] = select_n da dc db
de:f32[16] = div cz dd
df:f32[16] = mul cb de
dg:bool[16] = eq cn cv
dh:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
di:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
dj:f32[16] = select_n dg di dh
dk:bool[16] = eq by cv
dl:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
dm:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
dn:f32[16] = select_n dk dm dl
do:f32[16] = div dj dn
dp:f32[16] = mul cu do
dq:f32[16] = add_any df dp
dr:f32[16] = sub by cv
ds:f32[16] = sub cb dq
dt:f32[16] = exp dr
du:f32[16] = mul ds dt
dv:f32[16] = mul dt bz
dw:f32[16] = mul du bz
dx:f32[16] = mul dt cc
dy:f32[16] = add_any dw dx
dz:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] cv
ea:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] dq
eb:f32[16,16] = sub cj dz
ec:f32[16,16] = sub cm ea
ed:f32[16,16] = exp eb
ee:f32[16,16] = mul ec ed
ef:f32[16] = reduce_sum[axes=(1,)] ed
eg:f32[16] = reduce_sum[axes=(1,)] ee
eh:f32[16] = add dv ef
ei:f32[16] = add dy eg
ej:f32[16] = div 1.0 eh
ek:f32[16] = neg ei
el:f32[16] = mul ek 1.0
em:f32[16] = integer_pow[y=-2] eh
en:f32[16] = mul el em
eo:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] ej
ep:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] en
eq:f32[16,16] = mul ed eo
er:f32[16,16] = mul ee eo
es:f32[16,16] = mul ed ep
et:f32[16,16] = add_any er es
eu:f32[16] = mul dv ej
ev:f32[16] = mul dy ej
ew:f32[16] = mul dv en
ex:f32[16] = add_any ev ew
ey:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] eu
ez:f32[16,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(16, 1)
] ex
fa:f32[16,16] = mul ey bx
fb:f32[16,16] = mul ez bx
fc:f32[16,16] = mul ey ca
fd:f32[16,16] = add_any fb fc
fe:f32[16,16] <- bs[ce:ce+16,:]
ff:f32[16,16] <- bv[ce:ce+16,:]
fg:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] eq fe
fh:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] et fe
fi:f32[16,16] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] eq ff
fj:f32[16,16] = add_any fh fi
fk:f32[16,16] = add fa fg
fl:f32[16,16] = add fd fj
in (cd, fk, cv, eh, fl, dq, ei) }
length=1
linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
num_carry=7
num_consts=6
reverse=False
unroll=1
] k bb l q bc r i z x y be bg bi
fm:f32[16], n[ba:ba+16] <- n[ba:ba+16], bm
fn:f32[16], t[ba:ba+16] <- t[ba:ba+16], bp
fo:f32[16], o[ba:ba+16] <- o[ba:ba+16], bl
fp:f32[16], u[ba:ba+16] <- u[ba:ba+16], bo
fq:f32[16,16], m[ba:ba+16,:] <- m[ba:ba+16,:], bk
fr:f32[16,16], s[ba:ba+16,:] <- s[ba:ba+16,:], bn
in () }
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I would really appreciate your advice on this task. Specifically