I am trying to use some JAX code inside a Pallas kernel, but for some reason my code no longer works.
import functools
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array
from jax.experimental import sparse

key = jax.random.PRNGKey(52)
other = jax.random.normal(key, (10, 10))
diags = jax.random.normal(key, (3, 10))
offsets = (-2, 1, 2)

def dia_matmul_kernel(diags_ref, offsets, other_ref, o_ref):
    diags, other = diags_ref[...], other_ref[...]
    N = other.shape[0]
    out = jnp.zeros((N, N))
    print(offsets)
    for offset, diag in zip(offsets, diags):
        start = jax.lax.max(0, offset)
        end = min(N, N + offset)
        top = max(0, -offset)
        bottom = top + end - start
        out = out.at[top:bottom, :].add(
            diag[start:end, None] * other[start:end, :]
        )
    o_ref[...] = out

@functools.partial(jax.jit, static_argnums=(1, ))
def dia_matmul(diags: Array, offsets: tuple[int], other: Array) -> Array:
    return pl.pallas_call(
        dia_matmul_kernel,
        out_shape=jax.ShapeDtypeStruct(other.shape, other.dtype)
    )(diags, offsets, other)

dia_matmul(diags, offsets, other)
I understand that it is not best practice to print things inside a JAX JIT-compiled function, but when I print my offsets, which should stay static because of static_argnums=(1,), it shows:
(Traced<MemRef<None>{int32[]}>with<DynamicJaxprTrace(level=3/0)>, Traced<MemRef<None>{int32[]}>with<DynamicJaxprTrace(level=3/0)>, Traced<MemRef<None>{int32[]}>with<DynamicJaxprTrace(level=3/0)>)
I don't understand why that is the case. I'm new to JAX and Pallas, so I'm not yet fully confident with the whole tracing concept. Also, the last operation of the for loop (the update of out) is not working, so if anyone has an idea about that as well :D
Many thanks!
The arguments passed to pallas_call will always be traced, regardless of whether or not they were static before being passed to pallas_call. This is true any time you pass arguments to a JAX primitive or a transformed function: all inputs are traced unless they are explicitly marked as static in the function you are calling. pallas_call doesn't currently have any way of marking static arguments, and will trace all arguments passed to the wrapped function. If you want some arguments to be static, you should be able to do this by closing over them in the function that you pass to pallas_call.