@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
 #include <gtest/gtest.h>
 #include <tvm/ir/expr.h>
 #include <tvm/ir/type_functor.h>
@@ -38,6 +37,8 @@
 #include <tvm/topi/broadcast.h>
 #include <tvm/topi/generic/injective.h>
 
+#include <memory>
+
 using namespace tvm;
 using namespace tvm::relay;
 
@@ -69,6 +70,85 @@ TEST(Relay, OutOfStack_cast) {
   ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
 }
 
+TEST(Relay, OutOfStack_packed_func) {
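+  // Build a chain of 1e6 nested add() calls through the packed-func API and
+  // verify that constructing and destroying it does not overflow the stack.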
+  constexpr int len = 1e6;
+  auto foo = [] {
+    auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+    auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));
+    auto add_func = tvm::runtime::Registry::Get("relay.op._make.add");
+    ICHECK(add_func != nullptr) << "relay.op._make.add is not registered";
+    auto y = (*add_func)(x, one);
+    for (int i = 0; i < len; ++i) {
+      y = (*add_func)(y, one);
+    }
+
+    // Walk the chain back to the input to check every node is still reachable.
+    int k = 0;
+    Expr e = y;
+    while (e.defined() && e.as<CallNode>() != nullptr) {
+      e = e.as<CallNode>()->args[0];
+      ++k;
+    }
+    ASSERT_EQ(len + 1, k);
+  };
+  ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
+}
+
+TEST(Relay, CallNodeSharedArgs) {
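+  // Rebuilding a Call from the fields of an existing Call must leave the
+  // shared args intact.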
+  auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+  auto relu_op = relay::Op::Get("nn.relu");
+  Call y = relay::Call(relu_op, {x}, Attrs(), {});
+  y = relay::Call(relu_op, {y}, Attrs(), {});
+  ASSERT_EQ(1, y.get()->args[0].as<CallNode>()->args.size());
+  y = relay::Call(y.get()->op, y.get()->args, y.get()->attrs, y.get()->type_args);
+  ASSERT_EQ(1, y.get()->args[0].as<CallNode>()->args.size());
+}
+
+TEST(Relay, TupleSharedFields) {
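+  // Tuples built from a shared field array must leave the shared fields
+  // intact.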
+  auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+  auto relu_op = relay::Op::Get("nn.relu");
+  Expr y = relay::Call(relu_op, {x}, Attrs(), {});
+  y = relay::Call(relu_op, {y}, Attrs(), {});
+  {
+    Expr y1 = relay::Tuple(y.as<CallNode>()->args);
+    Expr y2 = relay::Tuple(y.as<CallNode>()->args);
+
+    y1 = relay::Call(relu_op, {y1});
+    y2 = relay::Call(relu_op, {y2});
+    y = y1;
+  }
+  ASSERT_EQ(1, y.as<CallNode>()->args[0].as<TupleNode>()->fields[0].as<CallNode>()->args.size());
+}
+
+TEST(Relay, TupleGetItemSharedTuple) {
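+  // TupleGetItem nodes referring to the same Tuple must leave the shared tuple intact.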
+  auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+  auto relu_op = relay::Op::Get("nn.relu");
+  Expr y = relay::Call(relu_op, {x}, Attrs(), {});
+  y = relay::Tuple({y});
+  {
+    Expr y1 = relay::TupleGetItem(y, 0);
+    Expr y2 = relay::TupleGetItem(y, 0);
+
+    y1 = relay::Call(relu_op, {y1});
+    y2 = relay::Call(relu_op, {y2});
+    y = y1;
+  }
+  ASSERT_EQ(1, y.as<CallNode>()
+                   ->args[0]
+                   .as<TupleGetItemNode>()
+                   ->tuple.as<TupleNode>()
+                   ->fields[0]
+                   .as<CallNode>()
+                   ->args.size());
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";