forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_static_runtime.py
207 lines (174 loc) · 7.24 KB
/
test_static_runtime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
class StaticRuntime:
    """Small Python facade over ``torch._C._jit_to_static_runtime``.

    The constructor accepts either a scripted module (an object exposing
    ``_c``) or a scripted function (an object exposing ``graph``) and keeps
    the resulting runtime object in ``self.static_runtime``.
    """

    def __init__(self, scripted):
        # Scripted nn.Modules expose their C++ module via `_c`; anything else
        # (e.g. a scripted free function) is converted from its `graph`.
        source = scripted._c if hasattr(scripted, "_c") else scripted.graph
        self.static_runtime = torch._C._jit_to_static_runtime(source)

    def __call__(self, *args, **kwargs):
        """Run the runtime on ``args``; ``kwargs`` are forwarded only if non-empty."""
        runner = self.static_runtime.run
        if kwargs:
            return runner(args, kwargs)
        return runner(args)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        """Forward to the runtime's ``benchmark``; its return value is discarded."""
        bench_args = (args, kwargs, warmup_runs, main_runs)
        self.static_runtime.benchmark(*bench_args)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        """Forward to the runtime's ``benchmark_individual_ops`` and return its result."""
        bench_args = (args, kwargs, warmup_runs, main_runs)
        return self.static_runtime.benchmark_individual_ops(*bench_args)
def linear_shim(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    """Drop-in replacement for ``torch.nn.functional.linear``.

    Computes ``input.matmul(weight.t())`` and, when ``bias`` is given, adds
    it in place to that freshly allocated product. The ``# type:`` comment
    above is kept so TorchScript can compile this function.
    """
    transposed_weight = weight.t()
    result = input.matmul(transposed_weight)
    if bias is not None:
        result += bias
    return result
# Replace torch.nn.functional.linear process-wide with `linear_shim`, so any
# code that looks up `torch.nn.functional.linear` after this point gets the
# matmul + in-place add version. Presumably this is meant to make nn.Linear
# layers scripted by the tests below lower to matmul/add rather than a fused
# linear op -- TODO confirm against nn.Linear.forward.
# NOTE(review): the patch is never undone, so it also affects any other code
# running in the same process after this module is imported.
torch.nn.functional.linear = linear_shim
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention used as a Static Runtime test model.

    Args:
        hid_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: accepted but not applied (no dropout module is created).
        device: device that the ``scale`` tensor is moved to.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        # Projections for queries, keys, values and the merged output
        # (created in this order, which fixes their parameter-init order).
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)

        # sqrt(head_dim), stored as a 1-element float32 tensor rather than a
        # Python float.
        head_dim_tensor = torch.FloatTensor([self.head_dim])
        self.scale = torch.sqrt(head_dim_tensor).to(device)

    def forward(self, query, key, value, mask):
        """Return ``(output, attention)``.

        ``query``/``key``/``value`` have the batch on dim 0 and ``hid_dim`` as
        the last dim. ``output`` is viewed to (batch, -1, hid_dim) and
        ``attention`` is (batch, n_heads, q_len, k_len). ``mask`` is accepted
        but not used: no masking is applied.
        """
        n_batch = query.shape[0]

        q_proj = self.fc_q(query)
        k_proj = self.fc_k(key)
        v_proj = self.fc_v(value)

        # Split the hidden dim into heads: (batch, n_heads, seq, head_dim).
        q_heads = q_proj.view(n_batch, -1, self.n_heads, self.head_dim).permute(
            0, 2, 1, 3
        )
        k_heads = k_proj.view(n_batch, -1, self.n_heads, self.head_dim).permute(
            0, 2, 1, 3
        )
        v_heads = v_proj.view(n_batch, -1, self.n_heads, self.head_dim).permute(
            0, 2, 1, 3
        )

        # Scaled scores against the transposed keys, then softmax over keys.
        scores = torch.matmul(q_heads, k_heads.permute(0, 1, 3, 2)) / self.scale
        weights = torch.softmax(scores, dim=-1)

        # Weighted values, heads merged back into the hidden dim.
        merged = torch.matmul(weights, v_heads).permute(0, 2, 1, 3).contiguous()
        out = self.fc_o(merged.view(n_batch, -1, self.hid_dim))
        return out, weights
# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    """Build a scripted, eval-mode MLP from a sequence of layer widths.

    Args:
        ln: layer widths; one ``nn.Linear(ln[i], ln[i + 1])`` is created for
            every consecutive pair.
        sigmoid_layer: index of the Linear layer that is followed by
            ``nn.Sigmoid``; every other Linear is followed by ``nn.ReLU``
            (an index matching no layer, e.g. -1, means ReLU everywhere).

    Weights are drawn from a normal distribution with std
    ``sqrt(2 / (fan_in + fan_out))`` and biases with std ``sqrt(1 / fan_out)``,
    both via NumPy's global RNG (weight first, then bias, per layer), so
    seeding ``np.random`` makes the parameters reproducible.

    Returns:
        ``torch.jit.script`` of the resulting ``nn.Sequential``, in eval mode.
    """
    modules = nn.ModuleList()
    for idx, (n_in, n_out) in enumerate(zip(ln[:-1], ln[1:])):
        layer = nn.Linear(int(n_in), int(n_out), bias=True)

        weight_std = np.sqrt(2 / (n_out + n_in))
        weights = np.random.normal(0.0, weight_std, size=(n_out, n_in)).astype(
            np.float32
        )
        bias_std = np.sqrt(1 / n_out)
        bias = np.random.normal(0.0, bias_std, size=n_out).astype(np.float32)

        # Overwrite the default init with the NumPy-drawn values.
        layer.weight.data = torch.tensor(weights, requires_grad=True)
        layer.bias.data = torch.tensor(bias, requires_grad=True)

        modules.append(layer)
        modules.append(nn.Sigmoid() if idx == sigmoid_layer else nn.ReLU())

    with torch.no_grad():
        scripted = torch.jit.script(torch.nn.Sequential(*modules))
    scripted.eval()
    return scripted
def trivial_graph(a, b, c):
    """Return ``a + b * c`` plus a constant 2x2 tensor filled with 3."""
    product = b * c
    offset = torch.tensor([[3, 3], [3, 3]])
    return a + product + offset
class TestStaticRuntime(TestCase):
    """Compare Static Runtime outputs against TorchScript/eager reference outputs."""

    def test_multihead_attention_layer(self):
        # Static Runtime must match the scripted module for both positional
        # and keyword invocation.
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticRuntime(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

        # zip() stops at the shorter sequence, so check output counts first;
        # otherwise a missing output from Static Runtime would pass silently.
        self.assertEqual(len(o_test), len(o_ref))
        self.assertEqual(len(o_test_kw), len(o_ref))
        for a, b in zip(o_ref, o_test):
            torch.testing.assert_allclose(a, b)
        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_allclose(a, b)

    def test_multihead_attention_layer_benchmark(self):
        # Smoke test: only checks that both benchmark entry points run.
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention_a = StaticRuntime(attention)

        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        attention_a.benchmark_individual_ops([src, src, src, src_mask], {}, 2, 2)

    def test_mlp(self):
        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticRuntime(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticRuntime(top_l)

        # Six runs with fresh inputs: the first run and repeated runs on the
        # same StaticRuntime instances must all match the reference.
        for _ in range(6):
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)[0]
            torch.testing.assert_allclose(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)[0]
            torch.testing.assert_allclose(acc_top, ref_top)

    def test_trivial_graph(self):
        # A scripted free function (converted via its `graph`).
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticRuntime(tg)
        o_test = tg_a(s, s, s)[0]
        torch.testing.assert_allclose(o_ref, o_test)

    def test_leaky_relu(self):
        # A scripted library module (converted via its `_c`).
        s = torch.randn(5, 5)
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticRuntime(tg)
        o_test = tg_a(s)[0]
        torch.testing.assert_allclose(o_ref, o_test)
# Entry point when executed directly: hand control to PyTorch's test runner
# (imported from torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()