Skip to content

Commit 5c64b10

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
let backends set info on the lowered_backend.meta from preprocess (#13856)
Summary: We want some delegates to be able to test experimental features that interact with the core runtime. Typically when doing this there is an expectation of tight coupling in the full lowering process so we arent super concerned about UX. The backend api is pretty hard to change because the code is mostly owned by 3rd parties, so we want to be really sure when we change it that its correct. By exposing this module.meta hookup we can test things in a controled maner without changing the blessed api before we are ready. Reviewed By: mergennachin, hsharma35 Differential Revision: D81466391
1 parent b2a8550 commit 5c64b10

File tree

5 files changed

+79
-2
lines changed

5 files changed

+79
-2
lines changed

exir/backend/backend_api.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextlib import contextmanager, nullcontext
1111
from dataclasses import dataclass
1212
from functools import singledispatch
13-
from typing import Dict, Generator, List, Mapping
13+
from typing import Any, Dict, Generator, List, Mapping
1414

1515
import torch
1616

@@ -126,6 +126,10 @@ def to_backend(
126126
lowered_module.meta = {
127127
"debug_handle_map": preprocess_result.debug_handle_map
128128
}
129+
if preprocess_result._delegate_info_meta is not None:
130+
lowered_module.meta["_delegate_info_meta"] = (
131+
preprocess_result._delegate_info_meta
132+
)
129133
return lowered_module
130134
raise NotImplementedError(f"Backend {backend_id} was not found.")
131135

@@ -610,6 +614,11 @@ def lower_all_submodules_to_backend(
610614
lowered_module.meta = {
611615
"debug_handle_map": preprocess_result.debug_handle_map,
612616
}
617+
if preprocess_result._delegate_info_meta is not None:
618+
assert lowered_module.meta is not None # for pyre
619+
lowered_module.meta["_delegate_info_meta"] = (
620+
preprocess_result._delegate_info_meta
621+
)
613622
is_submodule = call_submodule_node.meta["is_submodule"]
614623
toplevel_input_specs_to_delete = call_submodule_node.meta[
615624
"toplevel_input_specs_to_delete"

exir/backend/backend_details.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from abc import ABC, abstractmethod
88
from dataclasses import dataclass
99

10-
from typing import Dict, List, Optional, Tuple, Union
10+
from typing import Any, Dict, List, Optional, Tuple, Union
1111

1212
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
1313

@@ -32,6 +32,11 @@ class PreprocessResult:
3232
# but retrieveable by delegates via the NamedDataMap at runtime.
3333
data_store_output: Optional[NamedDataStoreOutput] = None
3434

35+
# Optional delegate-specific information that will be added to the
36+
# lowered_module.meta field in the graph, but not directly serialized
37+
# into the PTE file.
38+
_delegate_info_meta: Optional[Any] = None
39+
3540

3641
"""
3742
How to create a backend (for example, BackendWithCompilerDemo):

exir/backend/test/backend_with_compiler_demo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,5 @@ def preprocess(
138138
encoding="utf8",
139139
),
140140
debug_handle_map=debug_handle_map,
141+
_delegate_info_meta="test",
141142
)

exir/backend/test/test_backends_lifted.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,3 +1264,63 @@ def forward(self, x: List[torch.Tensor]):
12641264

12651265
gm = to_edge(export(ComposedM(), inputs, strict=True))
12661266
gm.exported_program().module()(*inputs)
1267+
1268+
def test_delegate_info_full_delegate(self):
1269+
"""
1270+
Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata
1271+
when using full delegation (to_backend directly).
1272+
"""
1273+
1274+
class SinModule(torch.nn.Module):
1275+
def __init__(self):
1276+
super().__init__()
1277+
1278+
def forward(self, x):
1279+
return torch.sin(x)
1280+
1281+
sin_module = SinModule()
1282+
model_inputs = (torch.ones(1),)
1283+
edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
1284+
max_value = model_inputs[0].shape[0]
1285+
compile_specs = [CompileSpec("max_value", bytes([max_value]))]
1286+
lowered_sin_module = to_backend(
1287+
"BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
1288+
)
1289+
1290+
# Check that the lowered module has _delegate_info_meta in its meta
1291+
self.assertIn("_delegate_info_meta", lowered_sin_module.meta.keys())
1292+
self.assertEqual(lowered_sin_module.meta["_delegate_info_meta"], "test")
1293+
1294+
def test_delegate_info_partitioner(self):
1295+
"""
1296+
Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata
1297+
when using partitioner-based delegation.
1298+
"""
1299+
1300+
class SinModule(torch.nn.Module):
1301+
def __init__(self):
1302+
super().__init__()
1303+
1304+
def forward(self, x):
1305+
return torch.sin(x)
1306+
1307+
sin_module = SinModule()
1308+
model_inputs = (torch.ones(1),)
1309+
max_value = model_inputs[0].shape[0]
1310+
1311+
partitioner = AllNodePartitioner(
1312+
"BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]
1313+
)
1314+
1315+
edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
1316+
lowered_m = edgeir_m.to_backend(partitioner)
1317+
1318+
# Check that the lowered submodule has _delegate_info_meta in its meta
1319+
lowered_submodules = get_lowered_submodules(
1320+
lowered_m.exported_program().graph_module
1321+
)
1322+
self.assertEqual(len(lowered_submodules), 1)
1323+
1324+
lowered_module = lowered_submodules[0][1]
1325+
self.assertIn("_delegate_info_meta", lowered_module.meta)
1326+
self.assertEqual(lowered_module.meta["_delegate_info_meta"], "test")

exir/lowered_backend_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class LoweredBackendModule(torch.nn.Module):
6666
_named_data_store_output: Optional[
6767
NamedDataStoreOutput
6868
] # Named Data serialized by the backend
69+
meta: Optional[Dict[str, Any]] # Metadata for the lowered module
6970

7071
def __init__(
7172
self,
@@ -81,6 +82,7 @@ def __init__(
8182
self._processed_bytes = processed_bytes
8283
self._compile_specs = compile_specs
8384
self._named_data_store_output = named_data_store_output
85+
self.meta = None
8486

8587
# pyre-ignore
8688
def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":

0 commit comments

Comments
 (0)