Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 50 additions & 18 deletions exir/passes/insert_write_back_for_buffers_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,27 @@ def _insert_copy(
mutated_outputs: List[Optional[str]],
input_name_to_node: Dict[str, torch.fx.Node],
):
"""
Find the all the buffers and inputs that were mutated and insert copy_
operators to reflect mutations.
"""
output_node = gm.graph.output_node()

if all(name is None for name in mutated_outputs): # if all mutated outputs are none, do not execute
return []

output_node = None
for node in gm.graph.nodes:
if node.op == "output":
output_node = node
break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: go back to gm.graph.output_node()

assert output_node is not None

outputs = pytree.tree_flatten(output_node.args)[0]
assert len(outputs) == len(mutated_outputs)

user_output_nodes = []
buffer_output_nodes = []
user_returns = []
buffer_copies = []

for return_node, mutated_node_name in zip(outputs, mutated_outputs):
# User output, leave alone

if mutated_node_name is None:
user_output_nodes.append(return_node)
user_returns.append(return_node)
continue

# Mutable buffer grab the node
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of a stupid question since I think I wrote this code, but looking back on it. Are mutated node and return node not the exact same node?

Expand All @@ -50,22 +56,48 @@ def _insert_copy(
raise RuntimeError(
f"Could not find {mutated_node_name} in either buffer or input nodes"
)

# needed to rebuild node index pairs after graph modification
nodes = list(gm.graph.nodes)
node_index = {node: i for i, node in enumerate(nodes)}

# last reader
last_read_idx = node_index[mutated_node]

# insert copy
with gm.graph.inserting_before(output_node):
for user in mutated_node.users:
if user.op != "output":
last_read_idx = max(last_read_idx, node_index[user])


# last consumer
last_use_idx = node_index[return_node]

for user in return_node.users:
if user.op != "output":
last_use_idx = max(last_use_idx, node_index[user])

insert_after = max(last_read_idx, last_use_idx)

if insert_after + 1 < len(nodes):
insert_point = (nodes[insert_after + 1])
else:
insert_point = (output_node)


with gm.graph.inserting_before(insert_point):
buffer_output = gm.graph.call_function(
torch.ops.aten.copy_.default, (mutated_node, return_node)
)
# add output of copy to graph outputs
buffer_output_nodes.append(buffer_output)
buffer_copies.append(buffer_output)

# Re‑wire graph output so that the copy results precede user returns
with gm.graph.inserting_before(output_node):
buffer_output_nodes.extend(user_output_nodes)
# Remove old outputs
new_output = gm.graph.output(tuple(buffer_output_nodes))
output_node.replace_all_uses_with(new_output)
buffer_copies.extend(user_returns)
new_out = gm.graph.output(tuple(buffer_copies))
output_node.replace_all_uses_with(new_out)
gm.graph.erase_node(output_node)
return buffer_output_nodes

return buffer_copies


def _is_inplace_node(node: torch.fx.Node) -> bool:
Expand Down