Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions exir/backend/backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def to_backend(
lowered_module.meta = {
"debug_handle_map": preprocess_result.debug_handle_map
}
if preprocess_result._delegate_info_meta is not None:
lowered_module.meta["_delegate_info_meta"] = (
preprocess_result._delegate_info_meta
)
return lowered_module
raise NotImplementedError(f"Backend {backend_id} was not found.")

Expand Down Expand Up @@ -610,6 +614,11 @@ def lower_all_submodules_to_backend(
lowered_module.meta = {
"debug_handle_map": preprocess_result.debug_handle_map,
}
if preprocess_result._delegate_info_meta is not None:
assert lowered_module.meta is not None
lowered_module.meta["_delegate_info_meta"] = (
preprocess_result._delegate_info_meta
)
is_submodule = call_submodule_node.meta["is_submodule"]
toplevel_input_specs_to_delete = call_submodule_node.meta[
"toplevel_input_specs_to_delete"
Expand Down
7 changes: 6 additions & 1 deletion exir/backend/backend_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from executorch.exir._serialize._named_data_store import NamedDataStoreOutput

Expand All @@ -32,6 +32,11 @@ class PreprocessResult:
# but retrieveable by delegates via the NamedDataMap at runtime.
data_store_output: Optional[NamedDataStoreOutput] = None

# Optional delegate-specific information that will be added to the
# lowered_module.meta field in the graph, but not directly serialized
# into the PTE file.
_delegate_info_meta: Optional[Any] = None


"""
How to create a backend (for example, BackendWithCompilerDemo):
Expand Down
1 change: 1 addition & 0 deletions exir/backend/test/backend_with_compiler_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,5 @@ def preprocess(
encoding="utf8",
),
debug_handle_map=debug_handle_map,
_delegate_info_meta="test",
)
60 changes: 60 additions & 0 deletions exir/backend/test/test_backends_lifted.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,3 +1264,63 @@ def forward(self, x: List[torch.Tensor]):

gm = to_edge(export(ComposedM(), inputs, strict=True))
gm.exported_program().module()(*inputs)

def test_delegate_info_full_delegate(self):
"""
Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata
when using full delegation (to_backend directly).
"""

class SinModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sin(x)

sin_module = SinModule()
model_inputs = (torch.ones(1),)
edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
max_value = model_inputs[0].shape[0]
compile_specs = [CompileSpec("max_value", bytes([max_value]))]
lowered_sin_module = to_backend(
"BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
)

# Check that the lowered module has _delegate_info_meta in its meta
self.assertIn("_delegate_info_meta", lowered_sin_module.meta.keys())
self.assertEqual(lowered_sin_module.meta["_delegate_info_meta"], "test")

def test_delegate_info_partitioner(self):
"""
Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata
when using partitioner-based delegation.
"""

class SinModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sin(x)

sin_module = SinModule()
model_inputs = (torch.ones(1),)
max_value = model_inputs[0].shape[0]

partitioner = AllNodePartitioner(
"BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]
)

edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
lowered_m = edgeir_m.to_backend(partitioner)

# Check that the lowered submodule has _delegate_info_meta in its meta
lowered_submodules = get_lowered_submodules(
lowered_m.exported_program().graph_module
)
self.assertEqual(len(lowered_submodules), 1)

lowered_module = lowered_submodules[0][1]
self.assertIn("_delegate_info_meta", lowered_module.meta)
self.assertEqual(lowered_module.meta["_delegate_info_meta"], "test")
2 changes: 1 addition & 1 deletion exir/emit/_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ def _add_delegate_map(
code, module hierarchy etc.
"""
delegate_map = {}
if hasattr(lowered_module, "meta"):
if lowered_module.meta is not None:
delegate_map = lowered_module.meta.get("debug_handle_map", {})

self.instr_id_to_delegate_debug_id_map[delegate_instruction_id] = {
Expand Down
3 changes: 2 additions & 1 deletion exir/lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class LoweredBackendModule(torch.nn.Module):
_named_data_store_output: Optional[
NamedDataStoreOutput
] # Named Data serialized by the backend
meta: Optional[Dict[str, Any]] # Metadata for the lowered module

def __init__(
self,
Expand All @@ -81,6 +82,7 @@ def __init__(
self._processed_bytes = processed_bytes
self._compile_specs = compile_specs
self._named_data_store_output = named_data_store_output
self.meta = None

# pyre-ignore
def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
Expand Down Expand Up @@ -109,7 +111,6 @@ def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule"
compile_specs=copy.deepcopy(self._compile_specs, memo),
named_data_store_output=self._named_data_store_output,
)
# pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`.
res.meta = copy.copy(getattr(self, "meta", {}))
return res

Expand Down
Loading