 import operator
 import unittest
 import sys
+import math
 from typing import Callable, Dict, Union, List
 from torch.fx.symbolic_trace import symbolic_trace
 from torch.fx.graph_module import GraphModule
 from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
 from torch.fx.passes.split_module import split_module
 from torch.fx.experimental.partitioner_utils import (
     NodeLatency,
 )
 import torch.fx.experimental.optimization as optimization
 from torch.fx.experimental import merge_matmul
-from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators
+from torch.fx.experimental.normalize import NormalizeOperators, NormalizeArgs
 from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
 from torch.testing._internal.common_nn import module_tests, new_module_tests
-from torch.fx.passes.shape_prop import extract_tensor_metadata
+from torch.fx.operator_schemas import _torchscript_type_to_python_type, normalize_function, normalize_module
+from torch.fx.passes.shape_prop import extract_tensor_metadata, ShapeProp
 
 try:
     from torchvision.models import resnet18
@@ -826,24 +830,26 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
         input = torch.randn(5, 3, 224, 224)
         ref_outs = traced(input)
 
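+        # Run shape propagation first so each node carries tensor metadata;
+        # NormalizeArgs uses that type information to resolve operator schemas.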
+        ShapeProp(traced).propagate(input)
         traced = NormalizeArgs(traced).transform()
 
-        test_outs = traced(input)
-        self.assertEqual(test_outs, ref_outs)
 
         modules = dict(traced.named_modules())
+
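+        # After normalization, call_function nodes (other than the bare
+        # operator.add) and standard torch.nn module calls take kwargs only.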
         for node in traced.graph.nodes:
-            if node.op == 'call_function' and node.target.__module__ == 'torch.nn.functional':
+            if node.op == 'call_function' and node.target != operator.add:
                 self.assertEqual(len(node.args), 0)
-            if node.op == 'call_module':
+            elif node.op == 'call_module':
                 submod_class = modules[node.target].__class__
                 nn_class = getattr(torch.nn, submod_class.__name__)
                 if submod_class == nn_class:
                     self.assertEqual(len(node.args), 0)
+        self.assertEqual(traced(input), ref_outs)
 
     def test_normalize_modules_exhaustive(self):
         """
-        Exhaustively test `NormalizeArgs` on all standard
+        Exhaustively test `Node.normalized_arguments` on all standard
         torch.nn Module classes
         """
         for test_params in module_tests + new_module_tests:
@@ -892,8 +898,23 @@ def forward(self, {params}):
             test_instance = gbls[test_classname](mod)
             traced = symbolic_trace(test_instance)
 
-            # Now actually test arg normalization!
-            traced = NormalizeArgs(traced).transform()
+            # Use `Node.normalized_arguments` to get a new set of arguments to
+            # feed to the Module, then rewrite the node to take only those
+            # arguments, as kwargs
+            modules = dict(traced.named_modules())
+            for node in traced.graph.nodes:
+                if node.op == 'call_module':
+                    submod_class = modules[node.target].__class__
+                    nn_class = getattr(torch.nn, submod_class.__name__)
+                    if submod_class == nn_class:
+                        normalized_args = node.normalized_arguments(traced)
+                        normalized_args2 = normalize_module(traced, node.target, node.args, node.kwargs)
+                        assert normalized_args == normalized_args2
+                        assert normalized_args
+                        node.args = ()
+                        node.kwargs = normalized_args
+
+            traced.recompile()
 
             # These Modules have an RNG in their forward, so testing
             # correctness by comparing outputs is not correct. Skip that
@@ -904,7 +925,7 @@ def forward(self, {params}):
             if mod.__class__.__name__ not in stochastic_modules:
                 self.assertEqual(traced(*inputs), mod(*inputs))
 
-            # Ensure all args/kwargs are normalized into kwargs
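+            # Also run the NormalizeArgs pass on a fresh trace and ensure it
+            # likewise moves all args/kwargs into kwargs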
+            traced = NormalizeArgs(symbolic_trace(test_instance)).transform()
             modules = dict(traced.named_modules())
             for node in traced.graph.nodes:
                 if node.op == 'call_module':
@@ -913,6 +934,8 @@ def forward(self, {params}):
                     if submod_class == nn_class:
                         self.assertEqual(len(node.args), 0)
 
+
+
     @skipIfNoTorchVision
     def test_annotate_returns_with_schema(self):
         m = resnet18()
@@ -1236,6 +1259,106 @@ def test_prepare_for_inference_cpu_torchvision(self):
         torch.testing.assert_allclose(orig_out, new_out)
 
 
+class TestNormalizeOperators(JitTestCase):
+    @onlyCPU
+    @ops(op_db, allowed_dtypes=(torch.float,))
+    def test_normalize_operator_exhaustive(self, device, dtype, op):
+        # Unsupported input types
+        if op.name in {'index_put', '__getitem__', 'unfold', 'repeat', 'polygamma'}:
+            return
+        # These ops currently don't trace in FX for various reasons (e.g. they take a list of tensors)
+        fx_fail = {'stack', 'hstack', 'vstack', 'dstack',
+                   'linalg.multi_dot'}
+        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
+        for sample_input in sample_inputs_itr:
+            unsupported_arg_type = False
+            arg_values = [sample_input.input] + list(sample_input.args)
+            kwarg_values = sample_input.kwargs
+            arg_types = []
+            kwarg_types = {}
+
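+            # Use TorchScript's type inference to recover a Python type
+            # annotation for each non-tensor argument value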
+            def jit_infer_type(v):
+                inferred_arg_type = torch._C._jit_try_infer_type(v)
+                assert inferred_arg_type.success()
+                t = _torchscript_type_to_python_type(inferred_arg_type.type())
+                return t
+
+            for v in arg_values:
+                if isinstance(v, torch.Tensor):
+                    arg_types.append(type(v))
+                else:
+                    if isinstance(v, complex):
+                        # Complex type not supported in FX
+                        unsupported_arg_type = True
+                    arg_types.append(jit_infer_type(v))
+
+            for k, v in kwarg_values.items():
+                if isinstance(v, torch.Tensor):
+                    kwarg_types[k] = type(v)
+                else:
+                    if isinstance(v, complex):
+                        # Complex type not supported in FX
+                        unsupported_arg_type = True
+                    kwarg_types[k] = jit_infer_type(v)
+
+            if unsupported_arg_type:
+                continue
+            # Test normalize_function by itself
+            ref_out = op.op(*arg_values, **kwarg_values)
+            norm_kwargs = normalize_function(op.op, arg_values, kwarg_values, arg_types, kwarg_types)
+            test_out = op.op(**norm_kwargs)
+            self.assertEqual(test_out, ref_out)
+
+            # Test normalized_arguments as part of FX
+            if op.name in fx_fail:
+                continue
+            param_names = []
+            param_values = []
+            fx_args = []
+            for idx, v in enumerate(arg_values):
+                if isinstance(v, torch.Tensor):
+                    param_names.append(f"arg_{idx}")
+                    param_values.append(v)
+                    fx_args.append(param_names[-1])
+                else:
+                    fx_args.append(f'{repr(v)}')
+
+            for k, v in kwarg_values.items():
+                if isinstance(v, torch.Tensor):
+                    param_names.append(k)
+                    param_values.append(v)
+                    fx_args.append(k)
+                else:
+                    fx_args.append(f'{k} = {repr(v)}')
+
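+            # Generate a small module whose forward calls the op: tensors
+            # become placeholder parameters, everything else is inlined as a
+            # constant in the generated source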
+            code = f"""
+class TestModule(torch.nn.Module):
+    def forward(self, {', '.join(param_names)}):
+        return torch.{op.name}({', '.join(fx_args)})
+"""
+
+            g = {'torch': torch, 'inf': math.inf}
+            exec(code, g)
+            TestModule = g['TestModule']
+
+
+            m = TestModule()
+            traced = torch.fx.symbolic_trace(m)
+            ref_out = traced(*param_values)
+
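+            # Rewrite each call_function node into its normalized kwargs-only
+            # form, then recompile and check the output is unchanged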
+            for node in traced.graph.nodes:
+                if node.op == 'call_function':
+                    normalized_args = node.normalized_arguments(traced, arg_types, kwarg_types)
+                    assert normalized_args
+                    node.args = ()
+                    node.kwargs = normalized_args
+            traced.recompile()
+
+            test_out = traced(*param_values)
+            self.assertEqual(test_out, ref_out)
+
+instantiate_device_type_tests(TestNormalizeOperators, globals())
 
 if __name__ == "__main__":
     run_tests()