all_gather in pmap fails with new-style random keys, works with old-style #23647

Open
Lookatator opened this issue Sep 15, 2024 · 2 comments
Labels
bug Something isn't working

Description

I am encountering an error when calling jax.lax.all_gather inside jax.pmap on a pytree containing random keys created with jax.random.key.
Interestingly, this does not happen with old-style random keys created with jax.random.PRNGKey.

Minimal Reproducible Example

import jax
import jax.numpy as jnp

def fn(x):
    return jax.lax.all_gather(x, axis_name="p")

devices = jax.devices("cpu")

key = jax.random.key(0)
subkeys = jax.random.split(key, num=2)
subkeys = jnp.stack(subkeys)
subkeys = jnp.expand_dims(subkeys, axis=0)

jax.pmap(fn, axis_name="p", devices=devices)(jax.random.key_data(subkeys))  # Runs fine

jax.pmap(fn, axis_name="p", devices=devices)(subkeys)  # Error

And here is the error message I get:

Traceback
Traceback (most recent call last):
  File "/Users/lg4615/projects/QDax/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1079, in lower_jaxpr_to_module
    if not ctx.module.operation.verify():
jaxlib.mlir._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
error: "pmap(fn)/jit(main)/all_gather[all_gather_dimension=0 axis_name=p axis_index_groups=None axis_size=1 tiled=False]"(callsite("fn"("/Users/lg4615/projects/QDax/test.py":6:0) at "<module>"("/Users/lg4615/projects/QDax/test.py":18:0))): broadcast_dimensions size (1) does not match operand rank (2)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/lg4615/projects/QDax/test.py", line 18, in <module>
    jax.pmap(fn, axis_name="p", devices=devices)(subkeys)
  File "/Users/lg4615/projects/QDax/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lg4615/projects/QDax/venv/lib/python3.10/site-packages/jax/_src/api.py", line 1779, in cache_miss
    execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
  File "/Users/lg4615/projects/QDax/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 299, in xla_pmap_impl_lazy
    compiled_fun, fingerprint = parallel_callable(
  File "/Users/lg4615/projects/QDax/venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 352, in memoized_fun
    ans = call(fun, *args)
  File "/Users/lg4615/projects/QDax/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 588, in parallel_callable
    pmap_computation = lower_parallel_callable(
  File "/Users/lg4615/projects/QDax/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/Users/lg4615/projects/QDax/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 826, in lower_parallel_callable
    lowering_result = mlir.lower_jaxpr_to_module(
  File "/Users/lg4615/projects/QDax/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1092, in lower_jaxpr_to_module
    raise ValueError("\n".join(msg_lines) + "\n" +
ValueError: Cannot lower jaxpr with verifier errors:
	broadcast_dimensions size (1) does not match operand rank (2)
		at loc("pmap(fn)/jit(main)/all_gather[all_gather_dimension=0 axis_name=p axis_index_groups=None axis_size=1 tiled=False]"(callsite("fn"("/Users/lg4615/projects/QDax/test.py":6:0) at "<module>"("/Users/lg4615/projects/QDax/test.py":18:0))))
Define JAX_DUMP_IR_TO to dump the module.

Questions

  • Is this a bug in JAX, or am I missing something in how pmap handles PRNGKey arrays?
  • If this is not a bug, would it be worth improving the error message?

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.10.13 (main, Aug 26 2024, 14:04:33) [Clang 15.0.0 (clang-1500.3.9.4)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='IC-XQG42K24NV', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:04 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6020', machine='arm64')

Lookatator added the bug label Sep 15, 2024
justinjfu self-assigned this Sep 16, 2024

justinjfu (Collaborator) commented Sep 19, 2024

This is most likely a bug in JAX. JAX has an extended typing system in which keys are a "logical" type backed by a "physical" type (uint32 seeds). During MLIR lowering we convert the PRNG keys into the underlying uint32 array, but we also need to update the axes/dimension parameters on the operations accordingly, and that update seems to be going wrong here.
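To make the logical/physical distinction concrete, here is a minimal sketch, assuming the default threefry key implementation. It only inspects shapes and dtypes through public jax.random functions; it does not reproduce the internal lowering code, which is not shown in this thread.

```python
import jax
import jax.numpy as jnp

# New-style typed keys: each key is a single logical element.
key = jax.random.key(0)
keys = jax.random.split(key, num=2)

# Logical view: one axis of length 2, with a PRNG key dtype.
print(keys.shape)  # (2,)
print(jnp.issubdtype(keys.dtype, jax.dtypes.prng_key))  # True

# Physical view: the backing uint32 data has one extra trailing axis.
# With the default threefry implementation that axis has length 2.
raw = jax.random.key_data(keys)
print(raw.shape, raw.dtype)  # (2, 2) uint32
```

The physical array has one more dimension than the logical one, which is consistent with the "size (1) does not match operand rank (2)" message in the report: a dimension parameter computed for the logical rank was apparently applied to the physical operand.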

We're currently overhauling this system, so we will fix this bug in the process, but it won't be an immediate fix.
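Until then, one possible workaround follows from the observation in the report that gathering the raw key data runs fine: pass the uint32 data through pmap and rebuild typed keys afterwards with jax.random.wrap_key_data. This is a sketch under the assumptions of a single CPU device and the default key implementation, not an officially recommended pattern.

```python
import jax

def fn(x):
    # Same collective as in the report, applied to plain uint32 data.
    return jax.lax.all_gather(x, axis_name="p")

# Assumption: use exactly one CPU device so the leading axis of size 1 matches.
devices = jax.devices("cpu")[:1]

key = jax.random.key(0)
subkeys = jax.random.split(key, num=2)[None]  # key array, logical shape (1, 2)

# Unwrap to the physical representation before entering pmap.
raw = jax.random.key_data(subkeys)            # uint32, shape (1, 2, 2)
gathered_raw = jax.pmap(fn, axis_name="p", devices=devices)(raw)

# Re-wrap outside pmap to get typed keys back.
gathered = jax.random.wrap_key_data(gathered_raw)
print(gathered.shape)  # (1, 1, 2)
```

Wrapping outside pmap keeps typed keys out of the collective entirely, so the failing lowering path is never exercised. If a non-default key implementation is in use, wrap_key_data would need its impl argument set to match.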

Lookatator (Author) commented

Oh I see! Thank you very much for the info, and best of luck with the upcoming changes :)
