Conversation
Relax constraints on MMA operation to support leading batch dimensions. These appear in practice and are relatively easy to support. No effect on lowering since it happens after vector types are introduced, the batch dimensions are expected to be tiled/unrolled by preliminary passes. Signed-off-by: Alex Zinenko <git@ozinenko.com>
Contributor
There was a problem hiding this comment.
Pull request overview
This PR relaxes wave.mma shape constraints to support leading batch dimensions (batched MMA), and adds tests to validate verification, index-expression inference, and elements-per-thread propagation for the batched case.
Changes:
- Update
MmaOpindex-expression initialization/propagation to treat the last two dims asM,Nand allow additional leading batch dims. - Update
MmaOpverifier to accept ranks > 2 with leading batch dims and adjust invalid-case coverage accordingly. - Add new MLIR tests covering batched MMA in ops, index-expr inference, and elements-per-thread propagation.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| water/lib/Dialect/Wave/IR/WaveOps.cpp | Extends MMA index-expr initialization/verification logic for batched (rank>2) tensors. |
| water/test/Dialect/Wave/ops.mlir | Adds a @batched_mma presence/syntax test for wave.mma with a leading batch dim. |
| water/test/Dialect/Wave/ops-invalid.mlir | Updates invalid MMA test to reflect new behavior (trailing “batch” now triggers a dimension-mismatch error). |
| water/test/Dialect/Wave/infer-index-exprs.mlir | Adds index-expression inference coverage for batched MMA (default batch mapping + standard M/N/K). |
| water/test/Dialect/Wave/propagate-elements-per-thread.mlir | Adds elements-per-thread propagation coverage for batched MMA producing vector types. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
- make sure thread-independent constraints are mixed in for batch dimensions + test - add a verifier that mma operands are at least 2d Signed-off-by: Alex Zinenko <git@ozinenko.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Relax constraints on MMA operation to support leading batch dimensions. These appear in practice and are relatively easy to support. No effect on lowering since it happens after vector types are introduced, the batch dimensions are expected to be tiled/unrolled by preliminary passes.