Conversation
Signed-off-by: Alex Zinenko <git@ozinenko.com>
There was a problem hiding this comment.
Pull request overview
This PR adjusts how Reshape ops are created when chaining MMA operations with differing vector_shapes, aiming to correct reshape construction for operand layout transitions.
Changes:
- Updates
add_reshape_if_neededto pass differentvector_shapesdictionaries intoReshapeand to assigncustom_reshape.vector_shapesdifferently.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| arg = get_custom(arg) | ||
| if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes): | ||
| reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph( | ||
| reshape = Reshape(arg.fx_node, mma.vector_shapes).add_to_graph( |
There was a problem hiding this comment.
Reshape is being constructed with args=arg.fx_node (a single fx.Node). Downstream codegen (handle_reshape) treats args as a sequence (calls len(args) and indexes args[0]), so passing a single node here will raise at runtime. Wrap the argument in a 1-element sequence (e.g., [mma.lhs] / [mma.rhs] or [arg.fx_node]) so args is always a list/tuple of nodes.
| reshape = Reshape(arg.fx_node, mma.vector_shapes).add_to_graph( | |
| reshape = Reshape([arg.fx_node], mma.vector_shapes).add_to_graph( |
| reshape = Reshape(arg.fx_node, mma.vector_shapes).add_to_graph( | ||
| mma.graph, loc=mma.location | ||
| ) | ||
| custom_reshape = get_custom(reshape) | ||
| custom_reshape.vector_shapes = mma.vector_shapes | ||
| custom_reshape.vector_shapes = prev_mma.vector_shapes |
There was a problem hiding this comment.
The Reshape construction appears to swap the meaning of target_vector_shape vs vector_shapes compared to other reshape sites in the codebase. Elsewhere (e.g. wave_lang/kernel/wave/decompose_vmma_ops.py:136-153), target_vector_shape is set to the source (pre-reshape) vector shape and reshape.vector_shapes is set to the destination (post-reshape) vector shape so expansion/fixup can derive num_slices correctly. Here target_vector_shape is set to mma.vector_shapes and custom_reshape.vector_shapes to prev_mma.vector_shapes, which would make the reshape compute slice/concat factors in the wrong direction when chaining MMAs. Consider restoring the prior direction: set target_vector_shape from prev_mma.vector_shapes and set custom_reshape.vector_shapes to mma.vector_shapes (the layout the current MMA consumes).
Signed-off-by: Alex Zinenko git@ozinenko.com