Skip to content

Comments

fix reshape construction#941

Draft
ftynse wants to merge 1 commit intomainfrom
users/ftynse/fix-reshape
Draft

fix reshape construction#941
ftynse wants to merge 1 commit intomainfrom
users/ftynse/fix-reshape

Conversation

@ftynse
Copy link
Contributor

@ftynse ftynse commented Feb 20, 2026

Signed-off-by: Alex Zinenko git@ozinenko.com

Signed-off-by: Alex Zinenko <git@ozinenko.com>
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adjusts how Reshape ops are created when chaining MMA operations with differing vector_shapes, aiming to correct reshape construction for operand layout transitions.

Changes:

  • Updates add_reshape_if_needed to pass different vector_shapes dictionaries into Reshape and to assign custom_reshape.vector_shapes differently.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

arg = get_custom(arg)
if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes):
reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph(
reshape = Reshape(arg.fx_node, mma.vector_shapes).add_to_graph(
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reshape is being constructed with args=arg.fx_node (a single fx.Node). Downstream codegen (handle_reshape) treats args as a sequence (calls len(args) and indexes args[0]), so passing a single node here will raise at runtime. Wrap the argument in a 1-element sequence (e.g., [mma.lhs] / [mma.rhs] or [arg.fx_node]) so args is always a list/tuple of nodes.

Suggested change
reshape = Reshape(arg.fx_node, mma.vector_shapes).add_to_graph(
reshape = Reshape([arg.fx_node], mma.vector_shapes).add_to_graph(

Copilot uses AI. Check for mistakes.
Comment on lines +162 to +166
reshape = Reshape(arg.fx_node, mma.vector_shapes).add_to_graph(
mma.graph, loc=mma.location
)
custom_reshape = get_custom(reshape)
custom_reshape.vector_shapes = mma.vector_shapes
custom_reshape.vector_shapes = prev_mma.vector_shapes
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Reshape construction appears to swap the meaning of target_vector_shape vs vector_shapes compared to other reshape sites in the codebase. Elsewhere (e.g. wave_lang/kernel/wave/decompose_vmma_ops.py:136-153), target_vector_shape is set to the source (pre-reshape) vector shape and reshape.vector_shapes is set to the destination (post-reshape) vector shape so expansion/fixup can derive num_slices correctly. Here target_vector_shape is set to mma.vector_shapes and custom_reshape.vector_shapes to prev_mma.vector_shapes, which would make the reshape compute slice/concat factors in the wrong direction when chaining MMAs. Consider restoring the prior direction: set target_vector_shape from prev_mma.vector_shapes and set custom_reshape.vector_shapes to mma.vector_shapes (the layout the current MMA consumes).

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant