Skip to content

Comments

[water] support batched MMA#927

Merged
ftynse merged 2 commits intomainfrom
users/ftynse/mma
Feb 20, 2026
Merged

[water] support batched MMA#927
ftynse merged 2 commits intomainfrom
users/ftynse/mma

Conversation

@ftynse
Copy link
Contributor

@ftynse ftynse commented Feb 19, 2026

Relax constraints on MMA operation to support leading batch dimensions. These appear in practice and are relatively easy to support. No effect on lowering since it happens after vector types are introduced, the batch dimensions are expected to be tiled/unrolled by preliminary passes.

Relax constraints on MMA operation to support leading batch dimensions.
These appear in practice and are relatively easy to support. No effect
on lowering since it happens after vector types are introduced, the
batch dimensions are expected to be tiled/unrolled by preliminary
passes.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR relaxes wave.mma shape constraints to support leading batch dimensions (batched MMA), and adds tests to validate verification, index-expression inference, and elements-per-thread propagation for the batched case.

Changes:

  • Update MmaOp index-expression initialization/propagation to treat the last two dims as M,N and allow additional leading batch dims.
  • Update MmaOp verifier to accept ranks > 2 with leading batch dims and adjust invalid-case coverage accordingly.
  • Add new MLIR tests covering batched MMA in ops, index-expr inference, and elements-per-thread propagation.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
water/lib/Dialect/Wave/IR/WaveOps.cpp Extends MMA index-expr initialization/verification logic for batched (rank>2) tensors.
water/test/Dialect/Wave/ops.mlir Adds a @batched_mma presence/syntax test for wave.mma with a leading batch dim.
water/test/Dialect/Wave/ops-invalid.mlir Updates invalid MMA test to reflect new behavior (trailing “batch” now triggers a dimension-mismatch error).
water/test/Dialect/Wave/infer-index-exprs.mlir Adds index-expression inference coverage for batched MMA (default batch mapping + standard M/N/K).
water/test/Dialect/Wave/propagate-elements-per-thread.mlir Adds elements-per-thread propagation coverage for batched MMA producing vector types.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Copy link
Contributor

@martin-luecke martin-luecke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only nits

- make sure thread-independent constraints are mixed in for batch dimensions + test
- add a verifier that mma operands are at least 2d

Signed-off-by: Alex Zinenko <git@ozinenko.com>
@ftynse ftynse merged commit f5a6e8c into main Feb 20, 2026
15 checks passed
@ftynse ftynse deleted the users/ftynse/mma branch February 20, 2026 12:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants