I tried
```python
import jax
import jax.numpy as jnp

def AssertShape(x: jax.Array, shape) -> None:
    if not jnp.array_equal(x.shape, shape):  # the `if` is where the boolean conversion fails
        raise ValueError(f'Shape mismatch: found {x.shape}, expected: {shape}')
```
and got
```
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
```
(BTW, note the double period at the end of the error message.)
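A likely cause, assuming `AssertShape` is called while JAX is tracing (for example inside a `jax.jit`-compiled function; the post doesn't say): `jnp.array_equal` turns the two shapes into JAX arrays, so its result is a traced scalar `bool[]`, and `if` cannot convert a tracer to a Python `bool`. The shape of a traced array, however, is an ordinary tuple of Python ints that is known at trace time. The sketch below compares the shapes as plain tuples instead. The names `assert_shape` and `double` are hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp


def assert_shape(x: jax.Array, shape) -> None:
    # x.shape is a static tuple of Python ints even for a tracer, so `!=`
    # returns a plain Python bool and `if` never sees a traced value.
    if tuple(x.shape) != tuple(shape):
        raise ValueError(f'Shape mismatch: found {x.shape}, expected: {shape}')


@jax.jit
def double(x):
    # The check runs once per trace (i.e. per new input shape), not per call.
    assert_shape(x, (2, 3))
    return 2 * x


print(double(jnp.ones((2, 3))).shape)  # (2, 3)
```

Because the check uses only static shape information, a mismatched input such as `jnp.ones((3, 2))` raises the `ValueError` at trace time rather than a `TracerBoolConversionError`.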