[water] add permutation mapping support to read/write#931
[water] add permutation mapping support to read/write#931
Conversation
Introduce support for mappings in read/write operations. Mapping indicates how value tensor dimensions map to memory tensor dimensions when this is not an identity mapping. We only support permutation mappings for now and will very progressively relax on the per-need basis to avoid importing the exceissive expressiveness of pywave in this case. No lowering is currently added and the mapping is intended for roundtripping purposes. Lowering will need to be added separately after investigating various particularities of how remapped index expressions are combined with masking on python side. Update the type inference to account for mapping. This required relaxing WaveExprListAttr to support maps with dimensions as we want to express permutations without introducing additional naming qualities. Using dimensions also allows us to rely on existing affine permutation logic. While there, a significant simplification was implemented to parsing of expression lists: we no longer use a quadratic complexity replacement logic when a single-pass algorithm exists, and we no longer construct strings (dynamic allocation) or affine maps (arena allocation under lock) for temporary objects. Signed-off-by: Alex Zinenko <git@ozinenko.com>
|
@ftynse TODO: needs python translation tests |
There was a problem hiding this comment.
Pull request overview
This PR adds support for permutation mappings in read/write operations for the Wave dialect. The mapping attribute indicates how value tensor dimensions map to memory tensor dimensions when this is not an identity mapping. The implementation includes type inference updates to account for mappings, relaxation of WaveExprListAttr to support dimension parameters, and optimized parsing logic that avoids quadratic complexity.
Changes:
- Added
mappingattribute toReadOpandWriteOpto support permutation-based dimension remapping - Updated type inference in
WaveOps.cppto propagate shapes correctly through mappings - Refactored
WaveAttrs.cppexpression parsing to use single-pass algorithm instead of quadratic string replacement - Extended
WaveExprListAttrto support affine dimensions in addition to symbols
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| wave_lang/kernel/wave/mlir_converter/water_emitter.py | Adds _convert_index_mapping_to_water function to convert Python IndexMapping to MLIR attributes and uses it in Read/Write emission |
| wave_lang/kernel/wave/mlir_converter/attr_type_converter.py | Extends preprocess_symbols to handle dimension names in mappings with for_mapping flag |
| water/test/Dialect/Wave/ops-invalid.mlir | Updates error messages to reflect new shape validation logic that considers mappings |
| water/test/Dialect/Wave/infer-types.mlir | Adds tests for forward and backward type inference with permutation mappings |
| water/test/Dialect/Wave/attr-type.mlir | Adds comprehensive tests for dimension support in expression lists |
| water/test/Dialect/Wave/attr-type-invalid.mlir | Adds validation tests for reserved dimension names and conflicts |
| water/python/WaterExtensionNanobind.cpp | Removes restriction that prevented dimensions in WaveExprListAttr |
| water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp | Adds early exit for operations with mappings (not yet supported in lowering) |
| water/lib/Dialect/Wave/IR/WaveOps.cpp | Implements mapping-aware type inference and verification for Read/Write operations |
| water/lib/Dialect/Wave/IR/WaveInterfaces.cpp | Adds overload for propagateShapeInformation that accepts shape as array of symbols |
| water/lib/Dialect/Wave/IR/WaveAttrs.cpp | Refactors expression parsing to support dimensions and use single-pass printing algorithm |
| water/include/water/Dialect/Wave/IR/WaveOps.td | Adds optional mapping attribute to ReadOp and WriteOp definitions |
| water/include/water/Dialect/Wave/IR/WaveInterfaces.h | Declares new shape propagation function signature |
| docs/wave/ir_design_notes.rst | Documents IndexMapping design and semantics for read/write operations |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| // Then push the replacement. | ||
| if (exprStr[i] == 's') | ||
| os << symbolNames[position]; | ||
| else | ||
| os << dimNames[position]; |
There was a problem hiding this comment.
Potential out-of-bounds access. If position is greater than or equal to the size of symbolNames or dimNames, this will cause undefined behavior. Consider adding bounds checking before accessing the arrays to prevent crashes on malformed input.
| // Then push the replacement. | |
| if (exprStr[i] == 's') | |
| os << symbolNames[position]; | |
| else | |
| os << dimNames[position]; | |
| // Capture the original token (e.g., "s3" or "d1") for fallback printing. | |
| StringRef token = StringRef(exprStr).slice(i, end); | |
| // Then push the replacement, if the index is within bounds. | |
| if (exprStr[i] == 's') { | |
| if (position >= 0 && | |
| static_cast<size_t>(position) < symbolNames.size()) | |
| os << symbolNames[position]; | |
| else | |
| os << token; | |
| } else { | |
| if (position >= 0 && | |
| static_cast<size_t>(position) < dimNames.size()) | |
| os << dimNames[position]; | |
| else | |
| os << token; | |
| } |
|
|
||
| Currently only supports permutation mappings. The `is_read` flag indicates | ||
| whether to use the input mapping (True) or output mapping (False). This is | ||
| due to excessively expressive design choice in pywave: the unused mapping |
There was a problem hiding this comment.
Spelling error in comment: 'excessively' is misspelled as 'exceissive'. Should be 'excessive expressiveness' instead of 'exceissive expressiveness'.
| due to excessively expressive design choice in pywave: the unused mapping | |
| due to excessive expressiveness in pywave: the unused mapping |
| # position in that list to find numeric permutation indices | ||
| # without parsing symbol names to extract the iterator position. | ||
| # If we had a proper data structure instead of blindly relying | ||
| # on sympy symbols everywhere, this would have been a easy to |
There was a problem hiding this comment.
Grammatical error in comment: 'a easy' should be 'an easy'.
| # on sympy symbols everywhere, this would have been a easy to | |
| # on sympy symbols everywhere, this would have been an easy to |
| Placeholder, | ||
| SelectOp, | ||
| Read, | ||
| SelfIndex, |
There was a problem hiding this comment.
SelectOp is missing from the imports. When the import of SelectOp from wave_lang.kernel.ops.wave_ops was removed in the change at line 78, SelectOp should have been added to the imports from water_mlir.water_mlir.dialects.wave (in a non-changed region starting at line 105). SelectOp is still used in WAVE_OP_CONSTRUCTORS at line 178, which will cause a NameError at runtime.
| SelfIndex, | |
| SelfIndex, | |
| SelectOp, |
| "Memory type shape is required for index mapping conversion." | ||
| ) | ||
|
|
||
| filtered_shape: Sequence[IndexSymbol] = [] |
There was a problem hiding this comment.
The type annotation for filtered_shape should be list[IndexSymbol] instead of Sequence[IndexSymbol] since it's initialized as a mutable list and items are appended to it. Sequence is an immutable protocol that doesn't support append operations.
| filtered_shape: Sequence[IndexSymbol] = [] | |
| filtered_shape: list[IndexSymbol] = [] |
|
@copilot can you add tests for read and write operation with |
- [x] Update tests to use 3D tensors with cyclic permutation (d0,d1,d2) -> (d1,d2,d0) - [x] Use non-self-inverse permutation for more robust testing - [x] Apply multi-line formatting to dictionary arguments for readability - [x] Reply to formatting comment <!-- START COPILOT CODING AGENT TIPS --> --- 💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ftynse <1512299+ftynse@users.noreply.github.com>
Introduce support for mappings in read/write operations. Mapping
indicates how value tensor dimensions map to memory tensor dimensions
when this is not an identity mapping. We only support permutation
mappings for now and will very progressively relax on the per-need basis
to avoid importing the exceissive expressiveness of pywave in this case.
No lowering is currently added and the mapping is intended for
roundtripping purposes. Lowering will need to be added separately after
investigating various particularities of how remapped index expressions
are combined with masking on python side.
Update the type inference to account for mapping.
This required relaxing WaveExprListAttr to support maps with dimensions
as we want to express permutations without introducing additional naming
qualities. Using dimensions also allows us to rely on existing affine
permutation logic.
While there, a significant simplification was implemented to parsing of
expression lists: we no longer use a quadratic complexity replacement
logic when a single-pass algorithm exists, and we no longer construct
strings (dynamic allocation) or affine maps (arena allocation under
lock) for temporary objects.
Fixes #918