Skip to content

Commit a5a6a3e

Browse files
Basil Hosmerfacebook-github-bot
authored andcommitted
add support for optional int list with scalar fill (pytorch#43262)
Summary: Pull Request resolved: pytorch#43262 Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D23212049 Pulled By: bhosmer fbshipit-source-id: c7ceb2318645c07d36c3f932c981c9ee3c414f82
1 parent f269fb8 commit a5a6a3e

File tree

6 files changed

+100
-21
lines changed

6 files changed

+100
-21
lines changed

aten/src/ATen/core/function_schema.h

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -357,43 +357,44 @@ inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
357357
// print out Argument, which is compatible with FunctionSchema parser
358358
// full format: Type(alias)? name=default_value
359359
inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
360-
bool optional_type = arg.type()->kind() == OptionalType::Kind;
360+
361361
// for adjusting the ? position.
362362
// in schema, we have Tensor?(a!) input, and t(a!)?.
363363
// however, t?(a!) doesn't work with schema parser.
364364
// so we always use Type(alias)? format
365-
std::stringstream oss;
366-
if (auto list = arg.type()->cast<c10::ListType>()) {
367-
oss << list->getElementType()->str();
368-
oss << "[";
369-
if (arg.N()) {
370-
oss << *arg.N();
371-
}
372-
oss << "]";
365+
auto type = arg.type();
366+
bool is_opt = type->kind() == OptionalType::Kind;
367+
auto unopt_type = is_opt ? type->cast<OptionalType>()->getElementType() : type;
368+
369+
if (unopt_type->kind() == ListType::Kind && arg.N()) {
370+
// sized lists get size N from arg, not type
371+
auto list = unopt_type->cast<c10::ListType>();
372+
out << list->getElementType()->str() << "[" << *arg.N() << "]";
373373
} else {
374-
oss << arg.type()->str();
375-
}
376-
if (optional_type) {
377-
oss.seekp(oss.str().size() - 1);
374+
out << unopt_type->str();
378375
}
376+
379377
if (arg.alias_info()) {
380-
oss << arg.alias_info().value();
378+
out << arg.alias_info().value();
381379
}
382-
if (optional_type) {
383-
oss << "?";
380+
381+
if (is_opt) {
382+
out << "?";
384383
}
385-
out << oss.str();
384+
386385
if (!arg.name().empty()) {
387386
out << " " << arg.name();
388387
}
388+
389389
if (arg.default_value()) {
390390
out << "=";
391-
if (arg.type()->kind() == c10::TypeKind::StringType) {
392-
printQuotedString(out, arg.default_value().value().toStringRef());
391+
if (type->kind() == c10::TypeKind::StringType) {
392+
printQuotedString(out, arg.default_value().value().toStringRef());
393393
} else {
394394
out << arg.default_value().value();
395395
}
396396
}
397+
397398
return out;
398399
}
399400

aten/src/ATen/native/native_functions.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7291,6 +7291,13 @@
72917291
dispatch:
72927292
CPU: _test_optional_intlist
72937293

7294+
# Note: this function is only for testing.
7295+
- func: _test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor
7296+
use_c10_dispatcher: full
7297+
python_module: nn
7298+
dispatch:
7299+
CPU: _test_optional_intlist
7300+
72947301
# Note: this function is only for testing.
72957302
- func: _test_optional_floatlist(Tensor values, float[]? addends) -> Tensor
72967303
use_c10_dispatcher: full

aten/src/ATen/native_parse.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ def type_argument_translations(arg):
8989
raise RuntimeError("Please use float and not double. "
9090
"See [temp translations] for details.")
9191
# Enables int[x] by translating to legacy IntArrayRef[x]. See [temp translations]
92+
elif re.match(r'int\[(\d+)\]\?', t):
93+
match = re.match(r'int\[(\d+)\]\?', t)
94+
t = 'IntArrayRef'
95+
size = int(match.group(1))
9296
elif re.match(r'int\[(\d+)\]', t):
9397
match = re.match(r'int\[(\d+)\]', t)
9498
t = 'IntArrayRef'

test/test_native_functions.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ def forward(self, values, incr: Optional[List[int]]):
1717

1818
class TestNativeFunctions(TestCase):
1919

20+
#
21+
# optional float list
22+
#
23+
2024
def do_test_optional_floatlist_with_module(self, module):
2125
values = torch.tensor([1.5, 2.5], dtype=torch.float)
2226

@@ -66,6 +70,9 @@ def test_optional_floatlist_invalid(self):
6670
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
6771
torch.jit.script(FloatListWrapperModule())(torch.zeros(1), torch.zeros(1))
6872

73+
#
74+
# optional int list
75+
#
6976

7077
def do_test_optional_intlist_with_module(self, module):
7178
values = torch.tensor([1, 2], dtype=torch.int)
@@ -116,6 +123,59 @@ def test_optional_intlist_invalid(self):
116123
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
117124
torch.jit.script(IntListWrapperModule())(torch.zeros(1), torch.zeros(1))
118125

126+
#
127+
# optional filled int list
128+
#
129+
130+
def do_test_optional_filled_intlist_with_module(self, module):
131+
values = torch.tensor([1, 2], dtype=torch.int)
132+
133+
returned = module(values, None)
134+
self.assertEqual(values, returned)
135+
# Make sure that it's an alias, indicating that the operator saw a nullopt.
136+
values[0] = 3
137+
self.assertEqual(values, returned)
138+
139+
returned = module(values, 10)
140+
self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int))
141+
self.assertEqual(returned, torch.tensor([13, 12], dtype=torch.int))
142+
143+
def trace_optional_filled_intlist(self, const):
144+
def wrapper(values):
145+
return torch._C._nn._test_optional_filled_intlist(values, const)
146+
return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
147+
148+
def test_optional_filled_intlist(self):
149+
150+
def f(n: int):
151+
x = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), (n, n))
152+
y = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), n)
153+
return x, y
154+
155+
# eager
156+
returned = f(10)
157+
self.assertEqual(returned[0], returned[1])
158+
159+
# scripted
160+
s = torch.jit.script(f)
161+
returned = s(10)
162+
self.assertEqual(returned[0], returned[1])
163+
164+
# traced
165+
traced_none = self.trace_optional_filled_intlist(None)
166+
traced_int = self.trace_optional_filled_intlist(10)
167+
168+
# Not really a module, just lets us use our two traced functions to handle
169+
# the specific cases of passing None and 10.
170+
def fake_module(values, const):
171+
if const is None:
172+
return traced_none(values)
173+
if const == 10:
174+
return traced_int(values)
175+
raise Exception("Invalid argument")
176+
177+
self.do_test_optional_filled_intlist_with_module(fake_module)
178+
119179

120180
if __name__ == '__main__':
121181
run_tests()

tools/autograd/gen_python_functions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ def create_python_bindings(python_functions, is_python_method, module):
302302
'TensorList': 'tensorlist_n<{}>',
303303
'DimnameList': 'dimnamelist',
304304
'IntArrayRef': 'intlist',
305+
'c10::optional<IntArrayRef>': 'intlistOptional',
305306
}
306307

307308
UNPACK_WITH_DEFAULT_METHODS = {
@@ -1228,7 +1229,10 @@ def get_schema_formal(arg, is_python_method):
12281229

12291230
size = arg.get('size')
12301231
if size is not None:
1231-
typename = '{}[{}]'.format(typename, size)
1232+
if typename.endswith('?'):
1233+
typename = '{}[{}]?'.format(typename[:-1], size)
1234+
else:
1235+
typename = '{}[{}]'.format(typename, size)
12321236

12331237
# default
12341238
default = arg.get('default')

torch/csrc/jit/frontend/function_schema_parser.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ using c10::ListType;
1818
using c10::make_left;
1919
using c10::make_right;
2020
using c10::OperatorName;
21+
using c10::OptionalType;
2122

2223
namespace torch {
2324
namespace jit {
@@ -109,7 +110,6 @@ struct SchemaParser {
109110
}
110111

111112
Argument parseArgument(size_t idx, bool is_return, bool kwarg_only) {
112-
Argument result;
113113
auto p = type_parser.parseType();
114114
auto type = std::move(p.first);
115115
auto alias_info = std::move(p.second);
@@ -127,6 +127,9 @@ struct SchemaParser {
127127
container->addContainedType(std::move(*alias_info));
128128
}
129129
alias_info = std::move(container);
130+
if (L.nextIf('?')) {
131+
type = OptionalType::create(type);
132+
}
130133
}
131134
if (is_return) {
132135
// optionally field names in return values

0 commit comments

Comments
 (0)