Bug description
Summary
After adding logging to the torchtitan expert_parallel wrapper and running fp8 rowwise MoE training (where we set the token group alignment size to 16), I can see the alignment size is set correctly, but the resulting M dimension is not divisible by 16:
```
[rank0]:[titan] 2025-08-28 10:05:31,175 - root - INFO - TOKEN_GROUP_ALIGN_SIZE_M = 16
[rank0]:[titan] 2025-08-28 10:05:31,454 - root - INFO - input_shape = torch.Size([16333, 5120])
```
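For reference, here is a minimal standalone check (not torchtitan code; the alignment value and shape are taken from the log lines above) confirming the mismatch:

```python
import torch

# Values taken from the log lines above
TOKEN_GROUP_ALIGN_SIZE_M = 16
inp = torch.randn(16333, 5120)

m, k = inp.shape
print(m % TOKEN_GROUP_ALIGN_SIZE_M)  # 13 -> M is not a multiple of 16
```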
This causes an error in `scaled_grouped_mm`, which requires the strides of the contracting dimension (which is M for the gemm `grad_weight = grad_output_t @ input`) to be a multiple of 16 bytes:

```
RuntimeError: strides should be multiple of 16 bytes
```
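Since fp8 elements are 1 byte each, the 16-byte stride requirement effectively means the contracted M dimension must be a multiple of 16 elements. A hypothetical workaround sketch (not the actual torchtitan fix, which presumably belongs in the token-group alignment logic; `pad_m_to_multiple` is an illustrative helper) would be to zero-pad M up to the next multiple of 16 before the grouped gemm:

```python
import torch
import torch.nn.functional as F

ALIGN = 16  # required alignment of the contracting dim, in elements (fp8 = 1 byte/elem)

def pad_m_to_multiple(x: torch.Tensor, align: int = ALIGN) -> torch.Tensor:
    """Zero-pad dim 0 (M) up to the next multiple of `align`."""
    m = x.shape[0]
    pad = (-m) % align
    if pad == 0:
        return x
    # pad tuple is (last-dim left, last-dim right, first-dim before, first-dim after)
    return F.pad(x, (0, 0, 0, pad))

x = torch.randn(16333, 5120)
print(pad_m_to_multiple(x).shape)  # torch.Size([16336, 5120])
```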
Git blame shows the last PR that touched this code was #1561.
cc @tianyu-l
Versions
- torchtitan: latest main branch
- torchao: latest main branch