Skip to content

Commit 7eb50c7

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
let backends set info on the lowered_backend.meta from preprocess (pytorch#13856)
Summary: We want some delegates to be able to test experimental features that interact with the core runtime. Typically when doing this there is an expectation of tight coupling in the full lowering process so we arent super concerned about UX. The backend api is pretty hard to change because the code is mostly owned by 3rd parties, so we want to be really sure when we change it that its correct. By exposing this module.meta hookup we can test things in a controled maner without changing the blessed api before we are ready. Reviewed By: hsharma35 Differential Revision: D81466391
1 parent 176800e commit 7eb50c7

File tree

4 files changed

+69
-1
lines changed

4 files changed

+69
-1
lines changed

exir/backend/backend_api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ def to_backend(
126126
lowered_module.meta = {
127127
"debug_handle_map": preprocess_result.debug_handle_map
128128
}
129+
if preprocess_result.delegate_info is not None:
130+
lowered_module.meta["delegate_info"] = preprocess_result.delegate_info
129131
return lowered_module
130132
raise NotImplementedError(f"Backend {backend_id} was not found.")
131133

@@ -610,6 +612,8 @@ def lower_all_submodules_to_backend(
610612
lowered_module.meta = {
611613
"debug_handle_map": preprocess_result.debug_handle_map,
612614
}
615+
if preprocess_result.delegate_info is not None:
616+
lowered_module.meta["delegate_info"] = preprocess_result.delegate_info
613617
is_submodule = call_submodule_node.meta["is_submodule"]
614618
toplevel_input_specs_to_delete = call_submodule_node.meta[
615619
"toplevel_input_specs_to_delete"

exir/backend/backend_details.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from abc import ABC, abstractmethod
88
from dataclasses import dataclass
99

10-
from typing import Dict, List, Optional, Tuple, Union
10+
from typing import Any, Dict, List, Optional, Tuple, Union
1111

1212
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
1313

@@ -32,6 +32,9 @@ class PreprocessResult:
3232
# but retrieveable by delegates via the NamedDataMap at runtime.
3333
data_store_output: Optional[NamedDataStoreOutput] = None
3434

35+
# Optional delegate-specific information that will be added to the lowered_module's metadata.
36+
delegate_info: Optional[Any] = None
37+
3538

3639
"""
3740
How to create a backend (for example, BackendWithCompilerDemo):

exir/backend/test/backend_with_compiler_demo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,5 @@ def preprocess(
138138
encoding="utf8",
139139
),
140140
debug_handle_map=debug_handle_map,
141+
delegate_info="test",
141142
)

exir/backend/test/test_backends_lifted.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,3 +1264,63 @@ def forward(self, x: List[torch.Tensor]):
12641264

12651265
gm = to_edge(export(ComposedM(), inputs, strict=True))
12661266
gm.exported_program().module()(*inputs)
1267+
1268+
def test_delegate_info_full_delegate(self):
1269+
"""
1270+
Test that delegate_info from BackendWithCompilerDemo ends up in the call_delegate node metadata
1271+
when using full delegation (to_backend directly).
1272+
"""
1273+
1274+
class SinModule(torch.nn.Module):
1275+
def __init__(self):
1276+
super().__init__()
1277+
1278+
def forward(self, x):
1279+
return torch.sin(x)
1280+
1281+
sin_module = SinModule()
1282+
model_inputs = (torch.ones(1),)
1283+
edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
1284+
max_value = model_inputs[0].shape[0]
1285+
compile_specs = [CompileSpec("max_value", bytes([max_value]))]
1286+
lowered_sin_module = to_backend(
1287+
"BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
1288+
)
1289+
1290+
# Check that the lowered module has delegate_info in its meta
1291+
self.assertIn("delegate_info", lowered_sin_module.meta.keys())
1292+
self.assertEqual(lowered_sin_module.meta["delegate_info"], "test")
1293+
1294+
def test_delegate_info_partitioner(self):
1295+
"""
1296+
Test that delegate_info from BackendWithCompilerDemo ends up in the call_delegate node metadata
1297+
when using partitioner-based delegation.
1298+
"""
1299+
1300+
class SinModule(torch.nn.Module):
1301+
def __init__(self):
1302+
super().__init__()
1303+
1304+
def forward(self, x):
1305+
return torch.sin(x)
1306+
1307+
sin_module = SinModule()
1308+
model_inputs = (torch.ones(1),)
1309+
max_value = model_inputs[0].shape[0]
1310+
1311+
partitioner = AllNodePartitioner(
1312+
"BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]
1313+
)
1314+
1315+
edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
1316+
lowered_m = edgeir_m.to_backend(partitioner)
1317+
1318+
# Check that the lowered submodule has delegate_info in its meta
1319+
lowered_submodules = get_lowered_submodules(
1320+
lowered_m.exported_program().graph_module
1321+
)
1322+
self.assertEqual(len(lowered_submodules), 1)
1323+
1324+
lowered_module = lowered_submodules[0][1]
1325+
self.assertIn("delegate_info", lowered_module.meta)
1326+
self.assertEqual(lowered_module.meta["delegate_info"], "test")

0 commit comments

Comments
 (0)