
Commit 0df239e

Chillee authored and facebook-github-bot committed
[FX] Make arg normalization a method on Node and not a pass (also augment tests to be exhaustive) (pytorch#55992)
Summary:
Commandeered from pytorch#54563

Primary changes from first PR:
1. Refactored primary `normalize_function` logic into `operator_schemas.py` so that non-FX users can use it.
2. Refactored tests a bit, and added a path to call `normalize_function` directly.
3. Moved check for `boolean_dispatch` so that `torch.lu` also gets properly handled.

Pull Request resolved: pytorch#55992

Reviewed By: mruberry

Differential Revision: D27774396

Pulled By: Chillee

fbshipit-source-id: 7f65632e1d608e4abd55aec5ccbfdc3f67f52b8e
1 parent 81b5921 commit 0df239e
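
In rough terms, the refactor exposes two entry points: `normalize_function`, usable outside of FX, and the new `Node.normalized_arguments` method, usable per node inside a traced graph. A minimal sketch, assuming the argument order shown in the new tests below; `torch.add` is only a stand-in example op, not part of this change:

    import torch
    import torch.fx
    from torch.fx.operator_schemas import normalize_function

    # Standalone use, outside of any FX pass: fold positional args into schema kwargs.
    # Passing arg_types/kwarg_types helps disambiguate overloaded schemas.
    x, y = torch.randn(3), torch.randn(3)
    new_kwargs = normalize_function(torch.add, (x, y), {}, (torch.Tensor, torch.Tensor), {})
    if new_kwargs is not None:
        assert torch.equal(torch.add(**new_kwargs), torch.add(x, y))

    # Per-node use via the new Node.normalized_arguments method.
    class M(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    traced = torch.fx.symbolic_trace(M())
    for node in traced.graph.nodes:
        if node.op == 'call_function':
            kwargs = node.normalized_arguments(traced, (torch.Tensor, torch.Tensor), {})
            if kwargs:
                node.args, node.kwargs = (), kwargs
    traced.recompile()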

File tree

7 files changed: +414 -108 lines changed


test/test_fx_experimental.py

Lines changed: 133 additions & 10 deletions
@@ -2,6 +2,7 @@
 import operator
 import unittest
 import sys
+import math
 from typing import Callable, Dict, Union, List
 from torch.fx.symbolic_trace import symbolic_trace
 from torch.fx.graph_module import GraphModule
@@ -12,6 +13,8 @@
 from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
 from torch.fx.passes.split_module import split_module
 from torch.fx.experimental.partitioner_utils import (
     NodeLatency,
@@ -23,10 +26,11 @@
 )
 import torch.fx.experimental.optimization as optimization
 from torch.fx.experimental import merge_matmul
-from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators
+from torch.fx.experimental.normalize import NormalizeOperators, NormalizeArgs
 from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
 from torch.testing._internal.common_nn import module_tests, new_module_tests
-from torch.fx.passes.shape_prop import extract_tensor_metadata
+from torch.fx.operator_schemas import _torchscript_type_to_python_type, normalize_function, normalize_module
+from torch.fx.passes.shape_prop import extract_tensor_metadata, ShapeProp
 
 try:
     from torchvision.models import resnet18
@@ -826,24 +830,26 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo
         input = torch.randn(5, 3, 224, 224)
         ref_outs = traced(input)
 
+        ShapeProp(traced).propagate(input)
         traced = NormalizeArgs(traced).transform()
 
-        test_outs = traced(input)
-        self.assertEqual(test_outs, ref_outs)
 
         modules = dict(traced.named_modules())
+
         for node in traced.graph.nodes:
-            if node.op == 'call_function' and node.target.__module__ == 'torch.nn.functional':
+            if node.op == 'call_function' and node.target != operator.add:
                 self.assertEqual(len(node.args), 0)
-            if node.op == 'call_module':
+            elif node.op == 'call_module':
                 submod_class = modules[node.target].__class__
                 nn_class = getattr(torch.nn, submod_class.__name__)
                 if submod_class == nn_class:
                     self.assertEqual(len(node.args), 0)
+        traced(input)
+        self.assertEqual(traced(input), ref_outs)
 
     def test_normalize_modules_exhaustive(self):
         """
-        Exhaustively test `NormalizeArgs` on all standard
+        Exhaustively test `Node.normalized_arguments` on all standard
        torch.nn Module classes
        """
        for test_params in module_tests + new_module_tests:
@@ -892,8 +898,23 @@ def forward(self, {params}):
             test_instance = gbls[test_classname](mod)
             traced = symbolic_trace(test_instance)
 
-            # Now actually test arg normalization!
-            traced = NormalizeArgs(traced).transform()
+            # Use `Node.normalized_arguments` to get a new set of arguments
+            # to feed to the Module. Then, rewrite the node to only take
+            # in those arguments as kwargs
+            modules = dict(traced.named_modules())
+            for node in traced.graph.nodes:
+                if node.op == 'call_module':
+                    submod_class = modules[node.target].__class__
+                    nn_class = getattr(torch.nn, submod_class.__name__)
+                    if submod_class == nn_class:
+                        normalized_args = node.normalized_arguments(traced)
+                        normalized_args2 = normalize_module(traced, node.target, node.args, node.kwargs)
+                        assert(normalized_args == normalized_args2)
+                        assert normalized_args
+                        node.args = ()
+                        node.kwargs = normalized_args
+
+            traced.recompile()
 
             # These Modules have an RNG in their forward, so testing
             # correctness by comparing outputs is not correct. Skip that
@@ -904,7 +925,7 @@ def forward(self, {params}):
             if mod.__class__.__name__ not in stochastic_modules:
                 self.assertEqual(traced(*inputs), mod(*inputs))
 
-            # Ensure all args/kwargs are normalized into kwargs
+            traced = NormalizeArgs(symbolic_trace(test_instance)).transform()
             modules = dict(traced.named_modules())
             for node in traced.graph.nodes:
                 if node.op == 'call_module':
@@ -913,6 +934,8 @@ def forward(self, {params}):
                 if submod_class == nn_class:
                     self.assertEqual(len(node.args), 0)
 
+
+
     @skipIfNoTorchVision
     def test_annotate_returns_with_schema(self):
         m = resnet18()
@@ -1236,6 +1259,106 @@ def test_prepare_for_inference_cpu_torchvision(self):
         torch.testing.assert_allclose(orig_out, new_out)
 
 
+class TestNormalizeOperators(JitTestCase):
+    @onlyCPU
+    @ops(op_db, allowed_dtypes=(torch.float,))
+    def test_normalize_operator_exhaustive(self, device, dtype, op):
+        # Unsupported input types
+        if op.name in {'index_put', '__getitem__', 'unfold', 'repeat', 'polygamma'}:
+            return
+        # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors)
+        fx_fail = {'stack', 'hstack', 'vstack', 'dstack',
+                   'linalg.multi_dot'}
+        print(op.name)
+        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
+        for sample_input in sample_inputs_itr:
+            unsupported_arg_type = False
+            arg_values = [sample_input.input] + list(sample_input.args)
+            kwarg_values = sample_input.kwargs
+            arg_types = []
+            kwarg_types = {}
+
+            def jit_infer_type(v):
+                inferred_arg_type = torch._C._jit_try_infer_type(v)
+                assert inferred_arg_type.success()
+                t = _torchscript_type_to_python_type(inferred_arg_type.type())
+                return t
+
+            for v in arg_values:
+                if isinstance(v, torch.Tensor):
+                    arg_types.append(type(v))
+                else:
+                    if isinstance(v, complex):
+                        # Complex type not supported in FX
+                        unsupported_arg_type = True
+                    arg_types.append(jit_infer_type(v))
+
+            for k, v in kwarg_values.items():
+                if isinstance(v, torch.Tensor):
+                    kwarg_types[k] = type(v)
+                else:
+                    if isinstance(v, complex):
+                        # Complex type not supported in FX
+                        unsupported_arg_type = True
+                    kwarg_types[k] = jit_infer_type(v)
+
+            if unsupported_arg_type:
+                continue
+            # Test normalize_function by itself
+            ref_out = op.op(*arg_values, **kwarg_values)
+            norm_kwargs = normalize_function(op.op, arg_values, kwarg_values, arg_types, kwarg_types)
+            test_out = op.op(**norm_kwargs)
+            self.assertEqual(test_out, ref_out)
+
+            # Test normalized_arguments as part of FX
+            if op.name in fx_fail:
+                continue
+            param_names = []
+            param_values = []
+            fx_args = []
+            for idx, v in enumerate(arg_values):
+                if isinstance(v, torch.Tensor):
+                    param_names.append(f"arg_{idx}")
+                    param_values.append(v)
+                    fx_args.append(param_names[-1])
+                else:
+                    fx_args.append(f'{repr(v)}')
+
+            for k, v in kwarg_values.items():
+                if isinstance(v, torch.Tensor):
+                    param_names.append(k)
+                    param_values.append(v)
+                    fx_args.append(k)
+                else:
+                    fx_args.append(f'{k} = {repr(v)}')
+
+            code = f"""
+class TestModule(torch.nn.Module):
+    def forward(self, {', '.join(param_names)}):
+        return torch.{op.name}({', '.join(fx_args)})
+"""
+
+            g = {'torch': torch, 'inf' : math.inf}
+            exec(code, g)
+            TestModule = g['TestModule']
+
+
+            m = TestModule()
+            traced = torch.fx.symbolic_trace(m)
+            ref_out = traced(*param_values)
+
+            for node in traced.graph.nodes:
+                if node.op == 'call_function':
+                    normalized_args = node.normalized_arguments(traced, arg_types, kwarg_types)
+                    assert normalized_args
+                    node.args = ()
+                    node.kwargs = normalized_args
+            traced.recompile()
+
+            test_out = traced(*param_values)
+            self.assertEqual(test_out, ref_out)
+
+instantiate_device_type_tests(TestNormalizeOperators, globals())
 
 if __name__ == "__main__":
     run_tests()
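
For reference, the pass-based flow this test file now exercises (running ShapeProp first, so nodes carry the metadata NormalizeArgs uses to disambiguate overloads) looks roughly like the sketch below; it assumes torchvision's resnet18 is available, as in the test above:

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.shape_prop import ShapeProp
    from torch.fx.experimental.normalize import NormalizeArgs
    from torchvision.models import resnet18

    model = resnet18().eval()
    traced = symbolic_trace(model)
    sample = torch.randn(5, 3, 224, 224)

    ShapeProp(traced).propagate(sample)          # run a sample through to record per-node metadata
    traced = NormalizeArgs(traced).transform()   # rewrite calls so arguments become kwargs
    print(torch.allclose(traced(sample), model(sample)))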

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
@@ -907,6 +907,7 @@ class DictType(JitType):
 
 class TupleType(JitType):
     def __init__(self, a: List[JitType]) -> None: ...
+    def elements(self) -> List[JitType]: ...
 
 class ClassType(JitType):
     def __init__(self, qualified_name: str) -> None: ...

torch/fx/experimental/normalize.py

Lines changed: 37 additions & 93 deletions
@@ -1,13 +1,12 @@
 import torch
 import torch.fx
-import inspect
+import torch.fx as fx
 import operator
-from typing import Any, Callable, Dict, Optional, Tuple
-from torch.fx.node import Argument, Target
-from torch._jit_internal import boolean_dispatched
+from typing import Any, Callable, Dict, Tuple, Optional
+from torch.fx.node import Argument, Target, Node
+from torch.fx.operator_schemas import normalize_module, normalize_function, create_type_hint
 
-from torch.fx import Transformer
-from torch.fx.operator_schemas import get_signature_for_torch_op
+from torch.fx import Transformer, Proxy
 from .schema_type_annotation import AnnotateTypesWithSchema
 
 class NormalizeArgs(Transformer):
@@ -18,108 +17,53 @@ class NormalizeArgs(Transformer):
     Also populates default values. Does not support positional-only
     parameters or varargs parameters (*args, **kwargs).
 
-    Example usage:
+    If the nodes have 'type' metadata, it will use it to disambiguate
+    overloads. Otherwise, it will throw an error.
 
+    Example usage:
         m = torchvision.models.resnet18()
-
         traced = torch.fx.symbolic_trace(m)
-
         traced = NormalizeArgs(traced).transform()
     """
-    def __init__(self, module : torch.nn.Module, normalize_functionals : bool = True,
-                 normalize_modules : bool = True):
+    def __init__(self, module : torch.nn.Module):
         super().__init__(module)
-        self.normalize_functionals = normalize_functionals
-        self.normalize_modules = normalize_modules
+        self.node_map: Dict[Proxy, Node] = {}
 
-    def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
-        new_kwargs = None
-
-        if self.normalize_functionals and target.__module__ == 'torch.nn.functional':
-            target_for_analysis = target
-            if target in boolean_dispatched:
-                # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
-                # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
-                # branches of the dispatch have exactly the same signature. If they do, use the `true`
-                # branch signature for analysis. Otherwise, leave this un-normalized
-                assert not isinstance(target, str)
-                dispatched = boolean_dispatched[target]
-                if_true, if_false = dispatched['if_true'], dispatched['if_false']
-                if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters:
-                    return super().call_function(target, args, kwargs)
-                target_for_analysis = if_true
-
-            assert callable(target_for_analysis)
-            sig = inspect.signature(inspect.unwrap(target_for_analysis))
-            new_kwargs = self._args_kwargs_to_normalized_kwargs(sig, args, kwargs)
+    def run_node(self, n: Node) -> Any:
+        args, kwargs = self.fetch_args_kwargs_from_env(n)
+
+        def get_type(arg):
+            if isinstance(arg, fx.Proxy):
+                old_meta = self.node_map[arg].meta
+                return old_meta['type'] if 'type' in old_meta else None
+            return create_type_hint(arg)
+
+        arg_types = tuple([get_type(arg) for arg in args])
+        kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
+        if n.op == 'call_function':
+            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
         else:
-            assert callable(target)
-            torch_op_schemas = get_signature_for_torch_op(target)
-            if torch_op_schemas:
-                # Iterate through all of the schema until we find one that matches
-                # If one matches, populate `new_kwargs` with the combined args/kwargs
-                # values. If none matches, `new_kwargs` will be None
-                for candidate_signature in torch_op_schemas:
-                    try:
-                        candidate_signature.bind(args, kwargs)
-                        new_kwargs = self._args_kwargs_to_normalized_kwargs(candidate_signature, args, kwargs)
-                        break
-                    except TypeError:
-                        continue
+            out = super().run_node(n)
+        self.node_map[out] = n
+        return out
+
+    def call_function(
+            self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any],
+            arg_types: Optional[Tuple[Any, ...]] = None, kwarg_types : Optional[Dict[str, Any]] = None):
+        assert callable(target)
+        new_kwargs = normalize_function(target, args, kwargs, arg_types, kwarg_types) # type: ignore
         if new_kwargs:
-            # FIXME: `target(**kwargs)` doesn't keep things specified as kwargs
-            # in kwargs
             return self.tracer.create_proxy('call_function', target, (), new_kwargs)
         else:
             return super().call_function(target, args, kwargs)
 
     def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
         assert isinstance(target, str)
-        submod = self.fetch_attr(target)
-        if self.normalize_modules and hasattr(submod.__class__, '__name__'):
-            classname = submod.__class__.__name__
-            if getattr(torch.nn, classname, None) == submod.__class__:
-                sig = inspect.signature(inspect.unwrap(submod.forward))
-                new_kwargs = self._args_kwargs_to_normalized_kwargs(sig, args, kwargs)
-                if new_kwargs:
-                    return super().call_module(target, (), new_kwargs)
-        return super().call_module(target, args, kwargs)
-
-    def _args_kwargs_to_normalized_kwargs(self, sig : inspect.Signature, args : Tuple[Argument, ...],
-                                          kwargs : Dict[str, Any]) -> Optional[Dict[str, Any]]:
-        """
-        Given a call target, args, and kwargs, return the arguments normalized into
-        a single kwargs dict, or None if the type signature is not supported by
-        this normalization.
-
-        Args:
-
-            target (inspect.Signature): Signature object for the target
-            args (Tuple): Arguments that appear at the callsite for `target`
-            kwargs (Dict): Keyword arugments that appear at the callsite for `target`
-
-        Returns:
-
-            Optional[Dict]: Normalized kwargs for `target`, or `None` if this target is not
-            supported
-        """
-
-        # Don't currently support positional-only
-        # or varargs (*args, **kwargs) signatures
-        supported_parameter_types = {
-            inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
-        if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
-            return None
-
-        bound_args = sig.bind(*args, **kwargs)
-        bound_args.apply_defaults()
-
-        new_kwargs : Dict[str, Any] = {}
-        for param in sig.parameters:
-            new_kwargs[param] = bound_args.arguments[param]
-
-        return new_kwargs
-
+        new_kwargs = normalize_module(self.module, target, args, kwargs) # type: ignore
+        if new_kwargs:
+            return super().call_module(target, (), new_kwargs)
+        else:
+            return super().call_module(target, args, kwargs)
 
 class NormalizeOperators(AnnotateTypesWithSchema):
     """
