The current JAX primitive for permutation has no partitioning rule or Shardy sharding rule. These need to be defined so that the work can be divided across multiple GPUs.
The likely approach is to shard along the B (batch) axis, and possibly, though less likely, along the M (embedding/token hidden size) axis. The design needs to account for how the data is sharded in the preceding operations (e.g. the router) and the following operations (e.g. FC1, FC2), to avoid re-sharding or excessive communication.
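To illustrate the batch-axis option, here is a minimal sketch in plain JAX. It is not the actual primitive or its partitioning rule: `permute_tokens` is a hypothetical stand-in, and the `[B, T, M]` layout, the per-sequence index array `perm`, and the mesh axis name `"batch"` are all assumptions made for the example. It only shows that when each batch row is permuted independently, the input and output can keep the same B-sharded layout.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def permute_tokens(x, perm):
    """Hypothetical stand-in for the permutation primitive.

    x:    [B, T, M] activations
    perm: [B, T] integer indices, one token permutation per batch row
    """
    # Each batch row is reordered independently of the others.
    return jax.vmap(lambda xi, pi: xi[pi])(x, perm)


# 1-D device mesh; the axis name "batch" is an assumption for this sketch.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

# B is chosen to be divisible by the number of devices.
B, T, M = 2 * devices.size, 8, 16
x = jax.numpy.arange(B * T * M, dtype=jax.numpy.float32).reshape(B, T, M)
keys = jax.random.split(jax.random.PRNGKey(0), B)
perm = jax.vmap(lambda k: jax.random.permutation(k, T))(keys)

# Shard both operands along B only; T and M stay replicated.
x_sharding = NamedSharding(mesh, P("batch", None, None))
perm_sharding = NamedSharding(mesh, P("batch", None))
x_sharded = jax.device_put(x, x_sharding)
perm_sharded = jax.device_put(perm, perm_sharding)

# Requesting the same layout on the output keeps the result B-sharded,
# so a following op that is also B-sharded would not need a re-shard.
sharded_permute = jax.jit(
    permute_tokens,
    in_shardings=(x_sharding, perm_sharding),
    out_shardings=x_sharding,
)
out = sharded_permute(x_sharded, perm_sharded)
```

If the permutation moves tokens across batch rows, as it may when tokens are grouped by expert, this per-row independence no longer holds and some cross-device communication would be required. That is part of what the real rule has to decide.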