
EP: token alignment not working as expected #1651

@danielvegamyhre


Bug description

Summary

After adding logging to the torchtitan expert_parallel wrapper and running fp8 rowwise MoE training (where the token group alignment size is set to 16), I see that the alignment size is set correctly, but the resulting M dimension is not divisible by 16:


[rank0]:[titan] 2025-08-28 10:05:31,175 - root - INFO - TOKEN_GROUP_ALIGN_SIZE_M = 16
[rank0]:[titan] 2025-08-28 10:05:31,454 - root - INFO - input_shape = torch.Size([16333, 5120])
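
For reference, a minimal repro of the mismatch (the constant name mirrors the log above; the check helper is illustrative, not torchtitan's actual code):

  TOKEN_GROUP_ALIGN_SIZE_M = 16

  def check_token_group_alignment(m: int, align: int = TOKEN_GROUP_ALIGN_SIZE_M) -> None:
      # Illustrative check: the token dim M fed to the grouped gemm should be
      # padded so it is a multiple of the alignment size.
      if m % align != 0:
          raise ValueError(f"M={m} is not a multiple of {align} (remainder {m % align})")

  check_token_group_alignment(16333)  # raises: 16333 % 16 == 13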

This causes an error in scaled_grouped_mm, which requires the contracting dimension (which is M for the gemm grad_weight = grad_output_t @ input) to be aligned so that the strides are a multiple of 16 bytes:

  RuntimeError: strides should be multiple of 16 bytes
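
To illustrate where the requirement comes from, here is a rough sketch (the padding helper and shapes are mine, assuming fp8 elements are 1 byte each, so a 16-byte stride requirement on the contracting dim amounts to M % 16 == 0):

  import torch
  import torch.nn.functional as F

  # grad_weight = grad_output_t @ input contracts over the token dim M, so for
  # 1-byte fp8 elements the 16-byte stride requirement means M % 16 == 0.
  def pad_m_to_alignment(x: torch.Tensor, align: int = 16) -> torch.Tensor:
      # Zero-pad the token (M) dimension up to the next multiple of `align`.
      pad = (-x.shape[0]) % align
      return F.pad(x, (0, 0, 0, pad)) if pad else x

  x = torch.randn(16333, 5120)
  print(pad_m_to_alignment(x).shape)  # torch.Size([16336, 5120])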

git blame shows the last PR that touched this code was #1561.

cc @tianyu-l

Versions

  • torchtitan: latest main branch
  • torchao: latest main branch
