Skip to content

Commit 52456b2

Browse files
suo authored and facebook-github-bot committed
add hasattr() (pytorch#29332)
Summary: Pull Request resolved: pytorch#29332 Even though we're statically typed, this can be useful, e.g. as shorthand when iterating through a module list. Test Plan: Imported from OSS Differential Revision: D18393097 Pulled By: suo fbshipit-source-id: aa42e955f88d1b8a876d0727055eb596453b9839
1 parent 7a63728 commit 52456b2

File tree

4 files changed

+112
-5
lines changed

4 files changed

+112
-5
lines changed

aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ namespace c10 {
112112
_(prim, CreateObject) \
113113
_(prim, SetAttr) \
114114
_(prim, GetAttr) \
115+
_(prim, HasAttr) \
115116
_(prim, profile) \
116117
_(prim, AddStatValue) \
117118
_(prim, TimePoint) \

test/jit/test_builtins.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import os
2+
import sys
3+
from typing import List
4+
5+
import torch
6+
7+
# Make the helper files in test/ importable
8+
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
9+
sys.path.append(pytorch_test_dir)
10+
from jit_utils import JitTestCase
11+
12+
if __name__ == '__main__':
13+
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
14+
"\tpython test/test_jit.py TESTNAME\n\n"
15+
"instead.")
16+
17+
class TestBuiltins(JitTestCase):
18+
"""
19+
Tests for TorchScript support of Python builtin functions.
20+
"""
21+
def test_has_attr(self):
22+
class HasA(torch.nn.Module):
23+
def __init__(self):
24+
super(HasA, self).__init__()
25+
self.a = 0
26+
27+
class HasB(torch.nn.Module):
28+
def __init__(self):
29+
super(HasB, self).__init__()
30+
self.b = 1
31+
32+
class Mod(torch.nn.Module):
33+
def __init__(self):
34+
super(Mod, self).__init__()
35+
self.mods = torch.nn.ModuleList([HasA(), HasB()])
36+
37+
def forward(self):
38+
# use a list to encode hasattr results
39+
l = torch.jit.annotate(List[int], [])
40+
for mod in self.mods:
41+
l.append(int(hasattr(mod, "a")))
42+
l.append(int(hasattr(mod, "b")))
43+
# actually retrieve the attr to test static refinement
44+
if hasattr(mod, "a"):
45+
l.append(mod.a)
46+
if hasattr(mod, "b"):
47+
l.append(mod.b)
48+
return l
49+
50+
self.checkModule(Mod(), ())
51+
52+
def test_has_attr_invalid_args(self):
53+
class Mod(torch.nn.Module):
54+
def __init__(self):
55+
super(Mod, self).__init__()
56+
self.mod = torch.nn.Linear(1, 1)
57+
58+
def forward(self, name):
59+
# not allowed, `name` must be static.
60+
return hasattr(self.mod, name)
61+
62+
with self.assertRaisesRegex(RuntimeError, "hasattr"):
63+
torch.jit.script(Mod())
64+
65+
class Mod(torch.nn.Module):
66+
def __init__(self):
67+
super(Mod, self).__init__()
68+
69+
def forward(self, name):
70+
# not allowed, `torch.rand` is not a class type
71+
return hasattr(torch.rand(2, 3), name)
72+
73+
with self.assertRaisesRegex(RuntimeError, "hasattr"):
74+
torch.jit.script(Mod())

test/test_jit.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from jit.test_custom_operators import TestCustomOperators # noqa: F401
2121
from jit.test_export_modes import TestExportModes # noqa: F401
2222
from jit.test_class_type import TestClassType # noqa: F401
23+
from jit.test_builtins import TestBuiltins # noqa: F401
2324

2425
# Torch
2526
from torch import Tensor
@@ -15943,6 +15944,7 @@ def foo(a):
1594315944

1594415945
with self.assertRaisesRegex(RuntimeError, "Inferred \'a\' to be of type \'Tensor"):
1594515946
foo(1)
15947+
1594615948
# known to be failing in tracer
1594715949
EXCLUDE_TRACED = {
1594815950
# The following fail due to #12024.

torch/csrc/jit/script/compiler.cpp

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ struct Environment {
439439
"__str__",
440440
std::make_shared<CastValue>(StringType::get(), aten::str))},
441441
{"getattr", SpecialFormValue::create(prim::GetAttr)},
442+
{"hasattr", SpecialFormValue::create(prim::HasAttr)},
442443
{"isinstance", SpecialFormValue::create(prim::isinstance)},
443444
// todo(zach): remove when we can correctly export torch.full via ONNX
444445
// or we have implicit conversion that can convert numbers to tensors
@@ -1050,10 +1051,15 @@ struct to_ir {
10501051
if (expr.kind() == TK_APPLY) {
10511052
auto apply = Apply(expr);
10521053
auto callee = Apply(expr).callee();
1053-
if (callee.kind() == TK_VAR &&
1054-
Var(callee).name().name() == "isinstance") {
1055-
checkApplyNumInputs(apply, 2);
1056-
return emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
1054+
if (callee.kind() == TK_VAR) {
1055+
if (Var(callee).name().name() == "isinstance") {
1056+
checkApplyNumInputs(apply, 2);
1057+
return emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
1058+
}
1059+
if (Var(callee).name().name() == "hasattr") {
1060+
checkApplyNumInputs(apply, 2);
1061+
return emitHasAttr(apply.inputs()[0], apply.inputs()[1]);
1062+
}
10571063
}
10581064
}
10591065
return CondValue(
@@ -1401,7 +1407,26 @@ struct to_ir {
14011407
}
14021408
}
14031409

1404-
CondValue emitIsInstance(Expr obj, Expr classinfo) {
1410+
CondValue emitHasAttr(const Expr& objExpr, const Expr& attrExpr) {
1411+
auto obj = emitExpr(objExpr);
1412+
const auto& type = obj->type();
1413+
if (attrExpr.kind() != TK_STRINGLITERAL) {
1414+
throw ErrorReport(attrExpr)
1415+
<< "hasattr's second argument must be a string literal";
1416+
}
1417+
auto cls = type->cast<ClassType>();
1418+
if (!cls) {
1419+
throw ErrorReport(objExpr)
1420+
<< "hasattr's first argument must be an object, got "
1421+
<< type->python_str() << " instead";
1422+
}
1423+
1424+
const std::string& name = StringLiteral(attrExpr).text();
1425+
const bool hasAttr = cls->hasAttribute(name);
1426+
return CondValue(*graph, objExpr.range(), hasAttr, {});
1427+
}
1428+
1429+
CondValue emitIsInstance(const Expr& obj, const Expr& classinfo) {
14051430
// turn (float, (int, tuple)) into a flat list of types and type kind
14061431
// category checks: tuple_check = true, types = {float, int}
14071432
struct GatheredTypes {
@@ -2401,6 +2426,11 @@ struct to_ir {
24012426
auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
24022427
return std::make_shared<SimpleValue>(result.value());
24032428
}
2429+
case prim::HasAttr: {
2430+
checkApplyNumInputs(apply, 2);
2431+
const auto result = emitHasAttr(apply.inputs()[0], apply.inputs()[1]);
2432+
return std::make_shared<SimpleValue>(result.value());
2433+
} break;
24042434
// This represents the "__new__" method on classes
24052435
// because it takes a ClassValue as input.
24062436
// So if we see:

0 commit comments

Comments
 (0)