OptimSteps not compatible with shard_map due to lax.cond #23693
Labels: bug (Something isn't working)
Comments
@mattjj can we write a general replication rule for cond?
Yeah, should be easy! (In the sense of: "We choose to
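With replication checking on, shard_map tracks, for each intermediate value, the set of mesh axes over which it is known to be identical on every device. Every primitive needs a rule that computes this for its outputs from its inputs, and the comments above are about cond lacking a general one. The sketch below is a toy model of what such a rule could look like. It assumes replication is represented as a frozenset of axis names and takes the conservative choice of intersecting the predicate's set with every branch's output set. `cond_rep_rule` is a hypothetical name; this is not JAX's implementation.

```python
from typing import FrozenSet, List, Sequence

# Toy replication "type": the mesh axis names a value is known to be replicated over.
Rep = FrozenSet[str]


def cond_rep_rule(pred_rep: Rep, branch_out_reps: Sequence[Sequence[Rep]]) -> List[Rep]:
    """Toy replication rule for a cond-like primitive (hypothetical, not JAX's code).

    Output i is replicated over an axis only if the predicate is replicated over
    it (so every device along that axis takes the same branch) and output i of
    every branch is replicated over it.
    """
    if not branch_out_reps:
        raise ValueError("cond needs at least one branch")
    num_outs = len(branch_out_reps[0])
    if any(len(reps) != num_outs for reps in branch_out_reps):
        raise ValueError("all branches must return the same number of outputs")
    out_reps: List[Rep] = []
    for i in range(num_outs):
        rep = frozenset(pred_rep)
        for reps in branch_out_reps:
            rep &= reps[i]  # intersect: keep only axes that every branch agrees on
        out_reps.append(rep)
    return out_reps


# Usage: a step counter replicated over "data" selects the branch, and both
# branches return values replicated over "data", so the result stays replicated.
data = frozenset({"data"})
print(cond_rep_rule(data, [[data], [data]]))        # [frozenset({'data'})]
# A device-varying predicate (replicated over no axes) loses the replication.
print(cond_rep_rule(frozenset(), [[data], [data]]))  # [frozenset()]
```

Intersecting is the safe direction: a rule may under-report replication (forcing extra communication or a check failure) but must never claim replication that does not hold.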
Description
When trying to use optax.MultiSteps on a data-parallel setup with shard_map, I am getting the following error:
A simple example can be found below:
Disabling multi-step makes the error go away.
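Why turning multi-step off helps: optax.MultiSteps accumulates gradients over several mini-steps and only applies the wrapped optimizer every k-th step, so its update has to branch on a step counter. The plain-Python toy below sketches that control flow under stated assumptions. `ToyMultiSteps` and its `k_steps`/`size` arguments are hypothetical names, not optax's API, and the zero updates on intermediate steps are a modelling choice for "leave the parameters unchanged". In a jitted JAX version the counter is a traced array, so the branch needs a primitive such as lax.cond, which is where the issue title locates the incompatibility with shard_map.

```python
from typing import List


class ToyMultiSteps:
    """Toy k-step gradient accumulator (hypothetical; not optax's implementation)."""

    def __init__(self, k_steps: int, size: int) -> None:
        if k_steps < 1:
            raise ValueError("k_steps must be >= 1")
        self.k_steps = k_steps
        self.mini_step = 0
        self.acc = [0.0] * size

    def update(self, grads: List[float]) -> List[float]:
        self.acc = [a + g for a, g in zip(self.acc, grads)]
        self.mini_step += 1
        # In a jitted JAX version the counter would be a traced array, so this
        # Python `if` would have to become a primitive such as lax.cond.
        if self.mini_step < self.k_steps:
            # Intermediate step: keep accumulating and emit zero updates.
            return [0.0] * len(self.acc)
        # Final step: emit the mean of the accumulated gradients, then reset.
        mean = [a / self.k_steps for a in self.acc]
        self.acc = [0.0] * len(self.acc)
        self.mini_step = 0
        return mean


opt = ToyMultiSteps(k_steps=2, size=2)
print(opt.update([1.0, 2.0]))  # [0.0, 0.0]
print(opt.update([3.0, 4.0]))  # [2.0, 3.0]
```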
System info (python version, jaxlib version, accelerator, etc.)