From 12814b19aa66e019e190e1be25fdceaa29ec5038 Mon Sep 17 00:00:00 2001 From: RihaanBH-1810 Date: Mon, 1 Sep 2025 21:28:21 +0530 Subject: [PATCH] Logic to insert copy node on last use, not applicable to aliases tho --- .../insert_write_back_for_buffers_pass.py | 68 ++++++++++++++----- 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/exir/passes/insert_write_back_for_buffers_pass.py b/exir/passes/insert_write_back_for_buffers_pass.py index 4dce40ae57c..d90371659be 100644 --- a/exir/passes/insert_write_back_for_buffers_pass.py +++ b/exir/passes/insert_write_back_for_buffers_pass.py @@ -26,21 +26,27 @@ def _insert_copy( mutated_outputs: List[Optional[str]], input_name_to_node: Dict[str, torch.fx.Node], ): - """ - Find the all the buffers and inputs that were mutated and insert copy_ - operators to reflect mutations. - """ - output_node = gm.graph.output_node() + + if all(name is None for name in mutated_outputs): # if all mutated outputs are none, do not execute + return [] + + output_node = None + for node in gm.graph.nodes: + if node.op == "output": + output_node = node + break assert output_node is not None + outputs = pytree.tree_flatten(output_node.args)[0] assert len(outputs) == len(mutated_outputs) - user_output_nodes = [] - buffer_output_nodes = [] + user_returns = [] + buffer_copies = [] + for return_node, mutated_node_name in zip(outputs, mutated_outputs): - # User output, leave alone + if mutated_node_name is None: - user_output_nodes.append(return_node) + user_returns.append(return_node) continue # Mutable buffer grab the node @@ -50,22 +56,48 @@ def _insert_copy( raise RuntimeError( f"Could not find {mutated_node_name} in either buffer or input nodes" ) + + # needed to rebuild node index pairs after graph modification + nodes = list(gm.graph.nodes) + node_index = {node: i for i, node in enumerate(nodes)} + + # last reader + last_read_idx = node_index[mutated_node] - # insert copy - with gm.graph.inserting_before(output_node): + for user in mutated_node.users: + if user.op != "output": + last_read_idx = max(last_read_idx, node_index[user]) + + + # last consumer + last_use_idx = node_index[return_node] + + for user in return_node.users: + if user.op != "output": + last_use_idx = max(last_use_idx, node_index[user]) + + insert_after = max(last_read_idx, last_use_idx) + + if insert_after + 1 < len(nodes): + insert_point = (nodes[insert_after + 1]) + else: + insert_point = (output_node) + + + with gm.graph.inserting_before(insert_point): buffer_output = gm.graph.call_function( torch.ops.aten.copy_.default, (mutated_node, return_node) ) - # add output of copy to graph outputs - buffer_output_nodes.append(buffer_output) + buffer_copies.append(buffer_output) + # Re‑wire graph output so that the copy results precede user returns with gm.graph.inserting_before(output_node): - buffer_output_nodes.extend(user_output_nodes) - # Remove old outputs - new_output = gm.graph.output(tuple(buffer_output_nodes)) - output_node.replace_all_uses_with(new_output) + buffer_copies.extend(user_returns) + new_out = gm.graph.output(tuple(buffer_copies)) + output_node.replace_all_uses_with(new_out) gm.graph.erase_node(output_node) - return buffer_output_nodes + + return buffer_copies def _is_inplace_node(node: torch.fx.Node) -> bool: