Commit aefa0c8

[Relay][dismantler] Added handling of packed func (apache#8004)
Added handling of CallNode objects created via packed-function invocation, plus test cases.

Change-Id: I5374abc59a3b0f79f27364c45f1a5789536df940
1 parent 6f82e98 commit aefa0c8
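Context, as the new OutOfStack_packed_func test below suggests: a Call produced by invoking a packed function (e.g. from Python) is released through the generic Object deleter rather than through ~Call, so the non-recursive dismantling added for deep call chains was bypassed and destruction could still overflow the C++ stack. The change saves the original deleter, installs a custom one that routes destruction back through ~Call, and adds use_count() guards so sub-expressions shared with other live expressions are not dismantled.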

4 files changed: +121 -8 lines

include/tvm/relay/expr.h (+6 -0)

@@ -227,6 +227,11 @@ class Var : public Expr {
 class Call;
 /*! \brief Call container. */
 class CallNode : public ExprNode {
+ protected:
+  // CallNode uses its own deleter to indirectly call the non-recursive destructor
+  Object::FDeleter saved_deleter_;
+  static void Deleter_(Object* ptr);
+
  public:
   /*!
    * \brief The operator(function) being invoked
@@ -290,6 +295,7 @@ class CallNode : public ExprNode {
 
   static constexpr const char* _type_key = "relay.Call";
   TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
+  friend class Call;
 };
 
 class Call : public Expr {
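For orientation, the deleter-swap pattern these members enable can be sketched in isolation. This is a minimal stand-alone analogue, not TVM's actual object machinery — Node, FDeleter, and both deleter functions are hypothetical stand-ins, and it simplifies one detail: in TVM the restored deleter is invoked later by the reference system rather than called directly.

#include <cstdio>

// Stand-ins for TVM's Object/FDeleter machinery -- all names here are
// hypothetical, for illustration only.
struct Node {
  using FDeleter = void (*)(Node*);
  FDeleter deleter_;        // invoked when the last reference is dropped
  FDeleter saved_deleter_;  // original deleter, restored before real deletion
};

void DefaultDeleter(Node* n) {
  std::puts("default deleter: freeing node");
  delete n;
}

// Intercepting deleter: restore the original, run the special teardown
// (in TVM, this is where ~Call gets re-entered), then delete for real.
void InterceptingDeleter(Node* n) {
  n->deleter_ = n->saved_deleter_;
  std::puts("intercepting deleter: non-recursive teardown runs here");
  n->deleter_(n);
}

int main() {
  Node* n = new Node{};
  n->saved_deleter_ = DefaultDeleter;  // remember the original path
  n->deleter_ = InterceptingDeleter;   // install the interceptor
  n->deleter_(n);                      // simulate the last reference dying
  return 0;
}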

src/relay/ir/expr.cc (+27 -7)

@@ -115,6 +115,8 @@ Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span s
   n->attrs = std::move(attrs);
   n->type_args = std::move(type_args);
   n->span = std::move(span);
+  n->saved_deleter_ = n->deleter_;
+  n->deleter_ = CallNode::Deleter_;
   data_ = std::move(n);
 }
 
@@ -288,16 +290,24 @@ inline void Dismantle(const Expr& expr) {
 
       // special handling
       if (const CallNode* op = node.as<CallNode>()) {
-        for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
-          fpush_to_stack(*it);
+        // do not process args if used elsewhere
+        if (op->args.use_count() < 2) {
+          for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
+            fpush_to_stack(*it);
+          }
         }
-        fpush_to_stack(op->op);
       } else if (const TupleNode* op = node.as<TupleNode>()) {
-        for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
-          fpush_to_stack(*it);
+        // do not process fields if used elsewhere
+        if (op->fields.use_count() < 2) {
+          for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
+            fpush_to_stack(*it);
+          }
         }
       } else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
-        fpush_to_stack(op->tuple);
+        // do not process tuple if used elsewhere
+        if (op->tuple.use_count() < 2) {
+          fpush_to_stack(op->tuple);
+        }
       }
     }
   }
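The use_count() < 2 guards encode the usual shared-ownership rule: a child may be scheduled for teardown only when the dying node is its sole owner; anything referenced elsewhere must stay intact. A rough stand-alone analogue, with std::shared_ptr playing the role of TVM's reference-counted handles — ExprN and DismantleSketch are hypothetical names, not TVM API:

#include <memory>
#include <stack>
#include <utility>
#include <vector>

// Hypothetical stand-in for an expression node; std::shared_ptr plays the
// role of TVM's reference-counted handles.
struct ExprN {
  std::vector<std::shared_ptr<ExprN>> args;
};

// Iterative teardown with an explicit stack instead of destructor recursion.
void DismantleSketch(std::shared_ptr<ExprN> root) {
  std::stack<std::shared_ptr<ExprN>> st;
  st.push(std::move(root));
  while (!st.empty()) {
    std::shared_ptr<ExprN> node = std::move(st.top());
    st.pop();
    for (auto& child : node->args) {
      // Take over a child only if this node is its sole owner; anything
      // referenced elsewhere is still live and must be left alone.
      if (child.use_count() < 2) st.push(std::move(child));
    }
    // 'node' dies here; its args are now empty or shared, so the implicit
    // destructor never recurses deeply.
  }
}

int main() {
  auto root = std::make_shared<ExprN>();
  ExprN* tail = root.get();
  for (int i = 0; i < 1000000; ++i) {
    tail->args.push_back(std::make_shared<ExprN>());
    tail = tail->args.back().get();
  }
  DismantleSketch(std::move(root));  // a plain root.reset() could overflow the stack
  return 0;
}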
@@ -306,7 +316,6 @@ inline void Dismantle(const Expr& expr) {
 /*
  * Non-recursive destructor
  */
-
 Call::~Call() {
   // attempt to dismantle if referenced one or zero times
   if (this->use_count() < 2) {
@@ -316,5 +325,16 @@ Call::~Call() {
   }
 }
 
+/*
+ * CallNode's deleter
+ */
+void CallNode::Deleter_(Object* ptr) {
+  auto p = reinterpret_cast<CallNode*>(ptr);
+  // restore original deleter
+  p->deleter_ = p->saved_deleter_;
+  // create Call reference in order to invoke ~Call
+  auto c = GetRef<Call>(p);
+}
+
 }  // namespace relay
 }  // namespace tvm
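Taken together: when the last reference to such a CallNode is dropped through the generic Object path, Deleter_ runs first, restores saved_deleter_, and materializes a temporary Call handle via GetRef (bringing the use count back to one). When that temporary goes out of scope at the closing brace, ~Call fires, sees a use count below two, dismantles the argument graph iteratively, and only then does the restored original deleter free the node itself.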

tests/cpp/relay_dismantler_test.cc (+76 -1)

@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
 #include <gtest/gtest.h>
 #include <tvm/ir/expr.h>
 #include <tvm/ir/type_functor.h>
@@ -38,6 +37,8 @@
 #include <tvm/topi/broadcast.h>
 #include <tvm/topi/generic/injective.h>
 
+#include <memory>
+
 using namespace tvm;
 using namespace tvm::relay;
 
@@ -69,6 +70,80 @@ TEST(Relay, OutOfStack_cast) {
   ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
 }
 
+TEST(Relay, OutOfStack_packed_func) {
+  constexpr int len = 1e6;
+  auto foo = [] {
+    auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+    auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));
+    auto add_func = tvm::runtime::Registry::Get("relay.op._make.add");
+    auto y = (*add_func)(x, one);
+    for (int i = 0; i < len; ++i) {
+      y = (*add_func)(y, one);
+    }
+
+    // check if still reachable
+    int k = 0;
+    Expr e = y;
+    while (e.defined() && e.as<CallNode>() != nullptr) {
+      e = e.as<CallNode>()->args[0];
+      ++k;
+    }
+    ASSERT_EQ(len + 1, k);
+  };
+  ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
+}
+
+TEST(Relay, CallNodeSharedArgs) {
+  auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+  auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));
+  auto relu_op = relay::Op::Get("nn.relu");
+  Call y = relay::Call(relu_op, {x}, Attrs(), {});
+  y = relay::Call(relu_op, {y}, Attrs(), {});
+  ASSERT_EQ(1, y.get()->args[0].as<CallNode>()->args.size());
+  y = relay::Call(y.get()->op, y.get()->args, y.get()->attrs, y.get()->type_args);
+  ASSERT_EQ(1, y.get()->args[0].as<CallNode>()->args.size());
+}
+
+TEST(Relay, TupleSharedFields) {
+  auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+  auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));
+  auto relu_op = relay::Op::Get("nn.relu");
+  Expr y = relay::Call(relu_op, {x}, Attrs(), {});
+  y = relay::Call(relu_op, {y}, Attrs(), {});
+  {
+    Expr y1 = relay::Tuple(y.as<CallNode>()->args);
+    Expr y2 = relay::Tuple(y.as<CallNode>()->args);
+
+    y1 = relay::Call(relu_op, {y1});
+    y2 = relay::Call(relu_op, {y2});
+    y = y1;
+  }
+  ASSERT_EQ(1, y.as<CallNode>()->args[0].as<TupleNode>()->fields[0].as<CallNode>()->args.size());
+}
+
+TEST(Relay, TupleGetItemSharedTuple) {
+  auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+  auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));
+  auto relu_op = relay::Op::Get("nn.relu");
+  Expr y = relay::Call(relu_op, {x}, Attrs(), {});
+  y = relay::Tuple({y});
+  {
+    Expr y1 = relay::TupleGetItem(y, 0);
+    Expr y2 = relay::TupleGetItem(y, 0);
+
+    y1 = relay::Call(relu_op, {y1});
+    y2 = relay::Call(relu_op, {y2});
+    y = y1;
+  }
+  ASSERT_EQ(1, y.as<CallNode>()
+                   ->args[0]
+                   .as<TupleGetItemNode>()
+                   ->tuple.as<TupleNode>()
+                   ->fields[0]
+                   .as<CallNode>()
+                   ->args.size());
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";

tests/python/relay/test_ir_text_printer.py (+12 -0)

@@ -30,6 +30,7 @@
 
 def astext(program, unify_free_vars=False):
     text = program.astext()
+
     print(text)
     if isinstance(program, Expr):
         roundtrip_program = tvm.parser.parse_expr(text)
@@ -47,6 +48,17 @@ def show(text):
     print(text)
 
 
+def test_large_graph():
+    x = relay.var("x", shape=(3, 2))
+    y = relay.var("y")
+    one = relay.const(10e10, dtype="float32")
+    z = relay.add(x, one)
+    for i in range(int(1e6)):
+        z = relay.add(z, one)
+    f = relay.Function([x, y], z)
+    show(astext(f))
+
+
 def test_func():
     x = relay.var("x", shape=(3, 2))
     y = relay.var("y")
