-
Notifications
You must be signed in to change notification settings - Fork 17
/
dynamic_test.py
139 lines (113 loc) · 3.39 KB
/
dynamic_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import numpy as np
import torch
import tvm
import tvm.testing
from tvm import relay
from tvm.relay.frontend.pytorch import from_pytorch
class SimpleIf(torch.nn.Module):
    """Model with a single data-dependent branch.

    Returns ``weight + inp`` when the input sums to a positive value and
    ``weight - inp`` otherwise (including a sum of exactly zero).
    """

    def __init__(self, N: int, M: int):
        super().__init__()
        # Learnable (N, M) matrix drawn from torch.rand, i.e. uniform in [0, 1).
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # Early return; TorchScript still lowers this to one prim::If.
        if inp.sum() > 0.:
            return self.weight + inp
        return self.weight - inp
class NestedIf(torch.nn.Module):
    """Model with two levels of data-dependent branching.

    The outer branch tests ``inp.sum() > 0`` and the inner branch tests
    ``inp.mean() > 0``. The four leaves combine ``weight`` and ``inp`` with
    +, -, * and / respectively.
    """

    def __init__(self, N: int, M: int):
        super().__init__()
        # Learnable (N, M) matrix drawn from torch.rand, i.e. uniform in [0, 1).
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # NOTE(review): for ordinary finite, non-empty inputs the sum and the
        # mean share a sign, so the two "mixed" leaves (- and *) are rarely,
        # if ever, reached in practice.
        if inp.sum() > 0.:
            if inp.mean() > 0.:
                return self.weight + inp
            return self.weight - inp
        if inp.mean() > 0.:
            return self.weight * inp
        return self.weight / inp
class ScalarLoop(torch.nn.Module):
    """Counted loop that accumulates a Python int rather than a tensor.

    Returns ``sum(k * k + 1 for k in range(inp.size(0)))``. The input tensor
    only supplies the trip count.
    """

    def forward(self, inp: torch.Tensor) -> int:
        total = 0
        for k in range(inp.size(0)):
            total += k * k + 1
        return total
class SimpleLoop(torch.nn.Module):
    """Counted loop over ``inp.size(0)`` that updates a tensor in place.

    Mathematically each iteration scales the accumulator by 4, so the
    result is ``inp * 4 ** inp.size(0)``.
    """

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # `a` is the same tensor object as `inp`, so the in-place `+=` below
        # also mutates the caller's tensor; pass a copy if the input must be kept.
        a = inp
        for i in range(inp.size(0)):
            b = a * 2.
            c = a + b  # c == 3 * a
            a += c  # in-place: a becomes 4 * a
        return a
class LoopWithIf(torch.nn.Module):
    """Counted loop whose body branches on a data-dependent condition."""

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # `a` is the same tensor object as `inp`; the in-place updates below
        # therefore also mutate the caller's tensor.
        a = inp
        for i in range(inp.size(0)):
            b = a * 2.
            b = a + b  # b == 3 * a
            if b.sum() > 0.0:
                a += b  # a becomes 4 * a
            else:
                a -= b  # a becomes -2 * a (sign flips)
        return a
class NestedLoop(torch.nn.Module):
    """Two nested counted loops over ``inp.size(0)`` and ``inp.size(1)``.

    Requires ``inp`` to have at least two dimensions, since ``size(1)`` is read.
    """

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # `a` is the same tensor object as `inp`; `+=` below mutates it in place.
        a = inp
        for i in range(inp.size(0)):
            # `*` is out-of-place, so `b` is a fresh tensor that is NOT changed
            # by the in-place updates of `a` in the inner loop.
            b = a * float(i)
            for j in range(inp.size(1)):
                a += b * float(j)
        return a
class SimpleScalarWhileLoop(torch.nn.Module):
    """While loop over Python ints, stepping by 2.

    Returns ``1 + sum(range(0, inp.size(0), 2))``. The input tensor only
    bounds the loop condition.
    """

    def forward(self, inp: torch.Tensor) -> int:
        acc = 1
        step = 0
        # The size is re-read in the condition so the loop stays a while-loop.
        while step < inp.size(0):
            acc += step
            step += 2
        return acc
class SimpleWhileLoop(torch.nn.Module):
    """While loop over ``inp.size(0)`` that updates a tensor in place."""

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # `a` is the same tensor object as `inp`; `+=` below mutates it in place.
        a = inp
        i = 0
        while i < inp.size(0):
            # Mathematically a <- a * (1 + 2 * i).
            a += a * float(i) * 2.0
            i += 1
        return a
# Seed before building the models so both the learned weights and the
# random inputs below are reproducible.
torch.manual_seed(0)

# Each model is scripted, converted to Relay, run on the TVM VM and compared
# against eager PyTorch on several random inputs.
models = [
    SimpleIf(10, 20).eval(),
    NestedIf(10, 20).eval(),
    ScalarLoop().eval(),
    SimpleLoop().eval(),
    LoopWithIf().eval(),
    SimpleScalarWhileLoop().eval(),
    SimpleWhileLoop().eval(),
    NestedLoop().eval(),
]

input_name = "input"
input_shape = (10, 20)
input_shapes = [(input_name, input_shape)]
num_trials = 5

for raw_model in models:
    model_name = type(raw_model).__name__
    script_module = torch.jit.script(raw_model)
    mod, params = from_pytorch(script_module, input_shapes)
    # Pass the device positionally: the parameter is called `ctx` in older
    # TVM releases and `device` in newer ones.
    executor = relay.create_executor("vm", mod, tvm.cpu(0), "llvm")
    evaluator = executor.evaluate()
    for trial in range(num_trials):
        # torch.rand only yields non-negative values, which would make every
        # `sum() > 0` condition true. Flipping the sign on odd trials makes
        # the else-branches get converted and checked too.
        sign = 1.0 if trial % 2 == 0 else -1.0
        inp = sign * torch.rand(input_shape, dtype=torch.float)
        with torch.no_grad():
            # Several models update their input in place, so give PyTorch a
            # copy and keep `inp` intact for the TVM run.
            pt_result = raw_model(inp.clone())
        params[input_name] = inp.numpy()
        tvm_out = evaluator(**params).asnumpy()
        if not isinstance(pt_result, torch.Tensor):
            # Scalar-returning models: np.asscalar was removed in NumPy 1.23,
            # and ndarray.item() is its replacement.
            tvm_res = tvm_out.item()
            print(model_name, abs(pt_result - tvm_res))
            assert pt_result == tvm_res, (model_name, pt_result, tvm_res)
        else:
            pt_out = pt_result.numpy()
            print(model_name, np.max(np.abs(tvm_out - pt_out)))
            tvm.testing.assert_allclose(tvm_out, pt_out, rtol=1e-5, atol=1e-5)