@@ -17,6 +17,10 @@ def forward(self, values, incr: Optional[List[int]]):
1717
1818class TestNativeFunctions (TestCase ):
1919
20+ #
21+ # optional float list
22+ #
23+
2024 def do_test_optional_floatlist_with_module (self , module ):
2125 values = torch .tensor ([1.5 , 2.5 ], dtype = torch .float )
2226
@@ -66,6 +70,9 @@ def test_optional_floatlist_invalid(self):
6670 with self .assertRaisesRegex (RuntimeError , "value of type .* instead found type" ):
6771 torch .jit .script (FloatListWrapperModule ())(torch .zeros (1 ), torch .zeros (1 ))
6872
73+ #
74+ # optional int list
75+ #
6976
7077 def do_test_optional_intlist_with_module (self , module ):
7178 values = torch .tensor ([1 , 2 ], dtype = torch .int )
@@ -116,6 +123,59 @@ def test_optional_intlist_invalid(self):
116123 with self .assertRaisesRegex (RuntimeError , "value of type .* instead found type" ):
117124 torch .jit .script (IntListWrapperModule ())(torch .zeros (1 ), torch .zeros (1 ))
118125
126+ #
127+ # optional filled int list
128+ #
129+
130+ def do_test_optional_filled_intlist_with_module (self , module ):
131+ values = torch .tensor ([1 , 2 ], dtype = torch .int )
132+
133+ returned = module (values , None )
134+ self .assertEqual (values , returned )
135+ # Make sure that it's an alias, indicating that the operator saw a nullopt.
136+ values [0 ] = 3
137+ self .assertEqual (values , returned )
138+
139+ returned = module (values , 10 )
140+ self .assertEqual (values , torch .tensor ([3 , 2 ], dtype = torch .int ))
141+ self .assertEqual (returned , torch .tensor ([13 , 12 ], dtype = torch .int ))
142+
143+ def trace_optional_filled_intlist (self , const ):
144+ def wrapper (values ):
145+ return torch ._C ._nn ._test_optional_filled_intlist (values , const )
146+ return torch .jit .trace (wrapper , torch .tensor ([1 , 2 ], dtype = torch .int ))
147+
148+ def test_optional_filled_intlist (self ):
149+
150+ def f (n : int ):
151+ x = torch ._C ._nn ._test_optional_filled_intlist (torch .tensor ([1 , 1 ], dtype = torch .int ), (n , n ))
152+ y = torch ._C ._nn ._test_optional_filled_intlist (torch .tensor ([1 , 1 ], dtype = torch .int ), n )
153+ return x , y
154+
155+ # eager
156+ returned = f (10 )
157+ self .assertEqual (returned [0 ], returned [1 ])
158+
159+ # scripted
160+ s = torch .jit .script (f )
161+ returned = s (10 )
162+ self .assertEqual (returned [0 ], returned [1 ])
163+
164+ # traced
165+ traced_none = self .trace_optional_filled_intlist (None )
166+ traced_int = self .trace_optional_filled_intlist (10 )
167+
168+ # Not really a module, just lets us use our two traced functions to handle
169+ # the specific cases of passing None and 10.
170+ def fake_module (values , const ):
171+ if const is None :
172+ return traced_none (values )
173+ if const == 10 :
174+ return traced_int (values )
175+ raise Exception ("Invalid argument" )
176+
177+ self .do_test_optional_filled_intlist_with_module (fake_module )
178+
119179
120180if __name__ == '__main__' :
121181 run_tests ()
0 commit comments