
[JAX] Extend permutation primitive to multi-GPU #2536

@tdophung

Description


The current JAX permutation primitive has no partitioning rule or Shardy sharding rule. Both need to be defined so the work can be divided across multiple GPUs.
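For reference, the single-device computation that any partitioning rule has to reproduce can be sketched as follows. This is a minimal sketch that assumes top-1 MoE token permutation, meaning tokens are grouped by the expert they are routed to. `permute_tokens` and `unpermute_tokens` are hypothetical names and do not represent the existing primitive's API.

```python
import jax.numpy as jnp


def permute_tokens(tokens, expert_ids):
    """Group tokens by expert (hypothetical top-1 token permutation).

    tokens:     [T, M] token hidden states.
    expert_ids: [T] int array, the expert each token is routed to.
    Returns the permuted tokens and the indices needed to undo the permutation.
    """
    # jnp.argsort is stable by default, so tokens routed to the same expert
    # keep their original relative order.
    sorted_indices = jnp.argsort(expert_ids)
    return jnp.take(tokens, sorted_indices, axis=0), sorted_indices


def unpermute_tokens(permuted, sorted_indices):
    """Scatter permuted rows back to their original token positions."""
    return jnp.zeros_like(permuted).at[sorted_indices].set(permuted)


tokens = jnp.arange(6 * 4, dtype=jnp.float32).reshape(6, 4)  # T=6 tokens, M=4
expert_ids = jnp.array([2, 0, 1, 0, 2, 1])
permuted, idx = permute_tokens(tokens, expert_ids)
restored = unpermute_tokens(permuted, idx)
```

In this example, `idx` is `[1, 3, 2, 5, 0, 4]`. The tokens for expert 0 come first, followed by those for expert 1 and then expert 2. `restored` equals `tokens`.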

The likely approach is to shard along the B (batch) axis and possibly, though less likely, along the M (embedding/token hidden size) axis. The choice must account for how data is sharded in the preceding operations (e.g., the router) and the following ones (e.g., FC1 and FC2), to avoid resharding or excessive communication.
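The following minimal sketch illustrates why B and M are natural candidates. It simulates two shards on a single device and checks that permuting each shard locally gives the same result as the unsharded computation. It assumes tokens are laid out as `[B, S, M]`, routing is top-1, and each sequence is permuted independently. `permute_tokens` and the sizes are hypothetical.

```python
import jax
import jax.numpy as jnp


def permute_tokens(tokens, expert_ids):
    # Hypothetical top-1 token permutation: stable sort of tokens by expert id.
    idx = jnp.argsort(expert_ids)
    return jnp.take(tokens, idx, axis=0)


# Hypothetical sizes: B=4 sequences, S=6 tokens each, M=8 hidden, E=3 experts.
B, S, M, E = 4, 6, 8, 3
tokens = jax.random.normal(jax.random.PRNGKey(0), (B, S, M))
expert_ids = jax.random.randint(jax.random.PRNGKey(1), (B, S), 0, E)

# Unsharded reference: permute each sequence, vectorised over the batch axis.
batched_permute = jax.vmap(permute_tokens)
reference = batched_permute(tokens, expert_ids)

# B sharding: each simulated device holds B/2 sequences and permutes them locally.
b_shards = [
    batched_permute(t, e)
    for t, e in zip(jnp.split(tokens, 2, axis=0), jnp.split(expert_ids, 2, axis=0))
]
b_sharded = jnp.concatenate(b_shards, axis=0)

# M sharding: each simulated device holds M/2 hidden features of every token.
# The routing (expert_ids) is replicated, so every shard applies the same row order.
m_shards = [batched_permute(t, expert_ids) for t in jnp.split(tokens, 2, axis=2)]
m_sharded = jnp.concatenate(m_shards, axis=2)
```

Under these assumptions, neither split requires communication within the permutation itself. M sharding only requires the routing indices to be replicated, and B sharding keeps each sequence on a single device. The catch is that a B-sharded output is grouped by expert only within each shard. If FC1 and FC2 expect tokens grouped by expert across the whole batch, some exchange (for example, an all-to-all) would likely be needed. This is the interaction with neighbouring operations described above.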
