Skip to content

Commit 28c5d90

Browse files
Meghan Lelefacebook-github-bot
authored andcommitted
[JIT] Allow implicit boolean conversion of containers (pytorch#51683)
Summary: Pull Request resolved: pytorch#51683 **Summary** This commit enables implicit boolean conversion of lists, strings, and dictionaries in conditional expressions. Like Python, empty lists, strings and dictionaries evaluate to `False` and their non-empty counterparts evaluate to `True`. This allows users to write code like ``` torch.jit.script def fn(l: List[int]): if l: ... else: ... ``` This has been requested by some users and would be a good usability improvement. **Test Plan** This commit adds unit tests to `TestList`, `TestDict` and `test_jit_string.py` to test this new feature. Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26264410 Pulled By: SplitInfinity fbshipit-source-id: b764c18fd766cfc128ea98a02b7c6c3fa49f8632
1 parent d3023d8 commit 28c5d90

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

test/jit/test_list_dict.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,35 @@
2020
"instead.")
2121

2222
class TestList(JitTestCase):
23+
def test_list_bool_conversion(self):
24+
def if_predicate(l: List[int]):
25+
if l:
26+
s = 0
27+
for n in l:
28+
s += n
29+
30+
return s
31+
else:
32+
return -1
33+
34+
self.checkScript(if_predicate, ([1, 2, 3],))
35+
self.checkScript(if_predicate, ([],))
36+
37+
def while_predicate(l: List[int]):
38+
s = 0
39+
40+
while l:
41+
s += l.pop()
42+
43+
self.checkScript(while_predicate, ([1, 2, 3],))
44+
self.checkScript(while_predicate, ([],))
45+
46+
def ternary_predicate(l: List[int]):
47+
return "non-empty" if l else "empty"
48+
49+
self.checkScript(ternary_predicate, ([1, 2, 3],))
50+
self.checkScript(ternary_predicate, ([],))
51+
2352
def test_in_check(self):
2453
def int_in(x: List[int]) -> bool:
2554
return 2 in x
@@ -1175,6 +1204,34 @@ def dict2(self):
11751204
def dict_bool(self):
11761205
return {True: 1}
11771206

1207+
def test_dict_bool_conversion(self):
1208+
def if_predicate(d: Dict[int, int]):
1209+
if d:
1210+
s, t = 0, 0
1211+
for k, v in d.items():
1212+
s += k
1213+
t += v
1214+
1215+
return s, t
1216+
else:
1217+
return -1, -1
1218+
1219+
self.checkScript(if_predicate, ({1: 2, 3: 5},))
1220+
self.checkScript(if_predicate, ({},))
1221+
1222+
def while_predicate(d: Dict[int, int]):
1223+
while d:
1224+
d.clear()
1225+
1226+
self.checkScript(while_predicate, ({1: 2, 3: 5},))
1227+
self.checkScript(while_predicate, ({},))
1228+
1229+
def ternary_predicate(d: Dict[int, int]):
1230+
return "non-empty" if d else "empty"
1231+
1232+
self.checkScript(ternary_predicate, ({1: 2, 3: 5},))
1233+
self.checkScript(ternary_predicate, ({},))
1234+
11781235
def test_del(self):
11791236
def inputs():
11801237
return {'hi': 2, 'bye': 3}

test/test_jit_string.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,5 +306,14 @@ def test_str_join():
306306
)
307307
self.checkScript(test_str_join, ())
308308

309+
def test_bool_conversion(a: str):
310+
if a:
311+
return a
312+
else:
313+
return "default"
314+
315+
self.checkScript(test_bool_conversion, ("nonempty",))
316+
self.checkScript(test_bool_conversion, ("",))
317+
309318
if __name__ == '__main__':
310319
run_tests()

torch/csrc/jit/frontend/sugared_value.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,9 +511,20 @@ struct TORCH_API CastValue : public BuiltinFunction {
511511
at::ArrayRef<NamedValue> kwargs,
512512
size_t n_binders) override {
513513
if (args.size() == 1 && kwargs.size() == 0) {
514+
auto len_op = std::make_shared<BuiltinFunction>(aten::len, at::nullopt);
515+
auto gt_op = std::make_shared<BuiltinFunction>(aten::gt, at::nullopt);
516+
auto zero = m.graph()->insertConstant(0);
517+
514518
auto v = args[0].value(*m.graph());
515519
if (v->type()->isSubtypeOf(type_)) {
516520
return std::make_shared<SimpleValue>(v);
521+
} else if (
522+
*type_ == *BoolType::get() &&
523+
(v->type()->isSubtypeOf(AnyListType::get()) ||
524+
v->type()->isSubtypeOf(StringType::get()) ||
525+
v->type()->cast<DictType>())) {
526+
auto len = len_op->call(loc, m, {v}, {}, 1);
527+
return gt_op->call(loc, m, {len->asValue(loc, m), zero}, {}, 1);
517528
}
518529
}
519530
return BuiltinFunction::call(loc, m, args, kwargs, n_binders);

0 commit comments

Comments
 (0)