all_gather in pmap fails with new-style random keys, works with old-style #23647
Description
I am encountering failures when using jax.lax.all_gather inside jax.pmap on a pytree containing random keys created with jax.random.key. Interestingly, this does not happen when using old-style random keys created with jax.random.PRNGKey.

Minimal Reproducible Example
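The sketch below is an illustration, not the original reproducer. It contrasts the two key styles. jax.random.PRNGKey returns a raw uint32 array, which is shape (2,) for the default threefry implementation. jax.random.key returns a scalar typed key array with a dedicated key dtype. The sketch then shows one possible workaround, which this issue does not confirm: convert typed keys to raw uint32 data with jax.random.key_data, run all_gather on that plain array inside pmap, and re-wrap the result with jax.random.wrap_key_data. The axis name "i" and the helper gather_raw are hypothetical. The code assumes a JAX version that provides the typed-key APIs.

```python
import jax
import jax.numpy as jnp

# Old-style key: a plain uint32 array (shape (2,) for the default threefry impl).
old_key = jax.random.PRNGKey(0)
# New-style key: a scalar typed key array whose dtype is a PRNG key dtype.
new_key = jax.random.key(0)

n = jax.local_device_count()

# One typed key per device, stacked along the leading (mapped) axis.
keys = jax.random.split(jax.random.key(0), n)
# Hypothetical workaround: strip the key type, leaving raw uint32 data of shape (n, 2).
raw = jax.random.key_data(keys)

def gather_raw(k):
    # k is this device's uint32 key data; all_gather stacks every device's copy.
    return jax.lax.all_gather(k, axis_name="i")

gathered_raw = jax.pmap(gather_raw, axis_name="i")(raw)  # shape (n, n, 2)
# Every device holds the same gathered data; re-wrap one copy as typed keys.
gathered = jax.random.wrap_key_data(gathered_raw[0])
```

This keeps typed key arrays outside the collective entirely, so it sidesteps the failure only if the problem is specific to key dtypes inside all_gather.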
And here is the error message I get:
Traceback
Questions
System info (python version, jaxlib version, accelerator, etc.)