OptimSteps not compatible with shard_map due to lax.cond #23693

Open
Rocamonde opened this issue Sep 17, 2024 · 2 comments
Labels: bug (Something isn't working)

Rocamonde commented Sep 17, 2024

Description

When trying to use optax.MultiSteps on a data-parallel setup with shard_map, I am getting the following error:

NotImplementedError: No replication rule for cond. As a workaround, pass the `check_rep=False` argument to `shard_map`. To get this fixed, open an issue at https://github.com/google/jax/issues

A simple example can be found below:

from functools import partial

import jax
import optax
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

optimizer = optax.MultiSteps(optax.adamw(learning_rate), every_k_schedule=10)

...

@partial(shard_map, mesh=mesh, in_specs=(P(), P(), P("data")), out_specs=P())
def make_step(key: jax.Array, state: State, data: Batch):
    loss, grads = jax.value_and_grad(loss_fn)(state["params"], data, key)
    loss = jax.lax.pmean(loss, "data")
    grads = jax.lax.pmean(grads, "data")
    updates, new_opt_state = optimizer.update(grads, state["opt_state"], state["params"])
    new_params = optax.apply_updates(state["params"], updates)
    return dict(params=new_params, opt_state=new_opt_state)

...

for step, batch in enumerate(data_loader):
    key, subkey = jax.random.split(key)
    state = make_step(subkey, state, batch)

Removing the optax.MultiSteps wrapper (passing the inner optimizer directly) makes the error go away.
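
As the error message suggests, passing check_rep=False to shard_map works around the failure by skipping the replication check entirely; it does not fix the root cause. A minimal sketch, with the rest of the setup above unchanged:

# Workaround from the error message: disable shard_map's replication check
# so the lax.cond inside optax.MultiSteps is allowed through.
@partial(shard_map, mesh=mesh, in_specs=(P(), P(), P("data")), out_specs=P(),
         check_rep=False)
def make_step(key: jax.Array, state: State, data: Batch):
    ...  # body identical to the version above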

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='192-222-54-171', release='6.2.0-37-generic', version='#38~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Nov  2 18:01:13 UTC 2', machine='x86_64')


$ nvidia-smi
Tue Sep 17 13:08:32 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:61:00.0 Off |                    0 |
| N/A   30C    P0             116W / 700W |    541MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:62:00.0 Off |                    0 |
| N/A   31C    P0             115W / 700W |    541MiB / 81559MiB |      1%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:63:00.0 Off |                    0 |
| N/A   31C    P0             112W / 700W |    541MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:64:00.0 Off |                    0 |
| N/A   30C    P0             118W / 700W |    541MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:6A:00.0 Off |                    0 |
| N/A   30C    P0             121W / 700W |    541MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:6B:00.0 Off |                    0 |
| N/A   32C    P0             110W / 700W |    541MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:6C:00.0 Off |                    0 |
| N/A   31C    P0             111W / 700W |    541MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:6D:00.0 Off |                    0 |
| N/A   30C    P0             116W / 700W |    541MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    161185      C   python                                      528MiB |
|    1   N/A  N/A    161185      C   python                                      528MiB |
|    2   N/A  N/A    161185      C   python                                      528MiB |
|    3   N/A  N/A    161185      C   python                                      528MiB |
|    4   N/A  N/A    161185      C   python                                      528MiB |
|    5   N/A  N/A    161185      C   python                                      528MiB |
|    6   N/A  N/A    161185      C   python                                      528MiB |
|    7   N/A  N/A    161185      C   python                                      528MiB |
+---------------------------------------------------------------------------------------+
Rocamonde added the bug label on Sep 17, 2024
sharadmv (Member) commented:

@mattjj can we write a general replication rule for cond?
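
For context, the failure does not depend on optax: any lax.cond inside a shard_map body trips the replication checker. A minimal sketch of the underlying limitation, reconstructed from the error in the report (assuming a one-axis mesh over all local devices):

import numpy as np
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# check_rep has no rule registered for the cond primitive, so this raises
# NotImplementedError: No replication rule for cond.
@partial(shard_map, mesh=mesh, in_specs=P("data"), out_specs=P("data"))
def f(x):
    return jax.lax.cond(x.sum() > 0, lambda v: v + 1.0, lambda v: v - 1.0, x)

f(jnp.ones(jax.device_count()))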

mattjj self-assigned this on Sep 18, 2024
mattjj (Member) commented Sep 18, 2024

Yeah, should be easy! (In the sense of: "We choose to ~~go to the Moon~~ implement cond rules in this decade, not because they are easy, but because ~~they are hard~~ we thought they would be easy")
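
Until a cond rule lands, gradient accumulation can also be hand-rolled so the every-k gate is a jnp.where select rather than a lax.cond branch, which the replication checker already handles. A sketch of that idea (the multi_steps_where helper below is hypothetical, not optax API; it trades extra compute, since the inner update runs every step, for shard_map compatibility):

import jax
import jax.numpy as jnp
import optax
from jax import tree_util

def multi_steps_where(inner: optax.GradientTransformation, every_k: int):
    # Hypothetical stand-in for optax.MultiSteps: accumulate gradients in the
    # state and gate the emitted update with jnp.where instead of lax.cond.
    def init(params):
        return dict(
            inner=inner.init(params),
            acc=tree_util.tree_map(jnp.zeros_like, params),
            count=jnp.zeros([], jnp.int32),
        )

    def update(grads, state, params=None):
        count = state["count"] + 1
        acc = tree_util.tree_map(lambda a, g: a + g, state["acc"], grads)
        emit = count % every_k == 0  # traced scalar bool; no cond needed
        mean_grads = tree_util.tree_map(lambda a: a / every_k, acc)
        # Run the inner update unconditionally, then select with jnp.where.
        updates, new_inner = inner.update(mean_grads, state["inner"], params)
        updates = tree_util.tree_map(
            lambda u: jnp.where(emit, u, jnp.zeros_like(u)), updates)
        new_inner = tree_util.tree_map(
            lambda new, old: jnp.where(emit, new, old),
            new_inner, state["inner"])
        acc = tree_util.tree_map(
            lambda a: jnp.where(emit, jnp.zeros_like(a), a), acc)
        return updates, dict(inner=new_inner, acc=acc, count=count)

    return optax.GradientTransformation(init, update)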
