transformers.patch
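This patch against Hugging Face transformers teaches generate() to run a torch.jit.trace'd model and to report per-token latency, replaces the hand-written NewGELU tanh approximation with PyTorch's fused kernel, and adjusts GPT-J's rotary embedding and forward signature so the model traces cleanly. Short illustrative sketches follow each file's diff below.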
diff --git a/src/transformers/activations.py b/src/transformers/activations.py
index d9caf8763..a3e3abd34 100644
--- a/src/transformers/activations.py
+++ b/src/transformers/activations.py
@@ -32,8 +32,7 @@ class NewGELUActivation(nn.Module):
"""
def forward(self, input: Tensor) -> Tensor:
- return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
-
+ return nn.functional.gelu(input, approximate='tanh')
class GELUActivation(nn.Module):
"""
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 7650276c5..7d14236c9 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -16,6 +16,8 @@
import copy
import inspect
+import re
+import time
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -683,6 +685,9 @@ def _expand_inputs_for_generation(
def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
past_key_values = None
+ # When the model has been traced with torch.jit.trace, outputs is a tuple rather than a dict; outputs[1] holds the past_key_values
+ if self.jit:
+ past_key_values = outputs[1]
if "past_key_values" in outputs:
past_key_values = outputs.past_key_values
elif "mems" in outputs:
@@ -1175,6 +1180,10 @@ def generate(
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
+ self.jit = kwargs.pop("jit", False)
+ self.tp_number = kwargs.pop("TP_number", 1)
+ self.token_latency = kwargs.pop("token_latency", None)
+ self.use_tpp = kwargs.pop("use_tpp", False)
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
@@ -2118,6 +2127,7 @@ def greedy_search(
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# init values
+ latency_list = []
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
@@ -2162,6 +2172,7 @@ def greedy_search(
this_peer_finished = False # used by synced_gpus only
while True:
+ tic = time.time()
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
@@ -2229,6 +2240,7 @@ def greedy_search(
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
# stop when each sentence is finished, or if we exceed the maximum length
+ latency_list.append(time.time() - tic)
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
if not synced_gpus:
break
@@ -2237,7 +2249,7 @@ def greedy_search(
if return_dict_in_generate:
if self.config.is_encoder_decoder:
- return GreedySearchEncoderDecoderOutput(
+ output_result = GreedySearchEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
encoder_attentions=encoder_attentions,
@@ -2247,14 +2259,19 @@ def greedy_search(
decoder_hidden_states=decoder_hidden_states,
)
else:
- return GreedySearchDecoderOnlyOutput(
+ output_result = GreedySearchDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
- return input_ids
+ output_result = input_ids
+
+ if self.token_latency is not None:
+ return (output_result, latency_list)
+ else:
+ return output_result
def sample(
self,
@@ -2645,6 +2662,7 @@ def beam_search(
['Wie alt bist du?']
```"""
# init values
+ latency_list = []
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
@@ -2707,6 +2725,7 @@ def beam_search(
this_peer_finished = False # used by synced_gpus only
while True:
+ tic = time.time()
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
@@ -2718,19 +2737,70 @@ def beam_search(
break
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
-
- if synced_gpus and this_peer_finished:
- cur_len = cur_len + 1
- continue # don't waste resources running the code we don't need
-
- next_token_logits = outputs.logits[:, -1, :]
+ if not self.jit:
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ if synced_gpus and this_peer_finished:
+ cur_len = cur_len + 1
+ continue # don't waste resources running the code we don't need
+ next_token_logits = outputs.logits[:, -1, :]
+ else:
+ first_token = False
+ input_bs = input_ids.size()[0]
+ if model_inputs["past_key_values"] is None:
+ first_token = True
+ if first_token:
+ seq_len = input_ids.size()[1]
+ if self.use_tpp:
+ model_inputs["past_key_values"] = tuple([(torch.zeros([1,1,int(self.config.n_head/self.tp_number)*int(self.config.n_embd/self.config.n_head)]), torch.zeros([1,1,int(self.config.n_head/self.tp_number)*int(self.config.n_embd/self.config.n_head)])) for i in range(self.config.n_layer)])
+ else:
+ model_inputs["past_key_values"] = tuple([(torch.zeros([1,int(self.config.n_head/self.tp_number),1,int(self.config.n_embd/self.config.n_head)]), torch.zeros([1,int(self.config.n_head/self.tp_number),1,int(self.config.n_embd/self.config.n_head)])) for i in range(self.config.n_layer)])
+ model_inputs["attention_mask"] = model_inputs["attention_mask"][:1,:]
+ model_inputs["input_ids"] = model_inputs["input_ids"][:1,:]
+ model_inputs["position_ids"] = model_inputs["position_ids"][:1,:]
+ model_inputs["attention_mask"] = torch.cat([torch.zeros(1, 1), model_inputs["attention_mask"]], dim=-1)
+ else:
+ model_inputs["attention_mask"] = torch.cat([torch.zeros(input_bs, 1), model_inputs["attention_mask"]], dim=-1)
+ model_inputs.pop("use_cache", None)
+ model_inputs.pop("token_type_ids", None)
+
+ if not hasattr(self, "trace_graph") and self.jit:
+ example_inputs = []
+ for k, v in model_inputs.items():
+ if v is not None and not isinstance(v, bool):
+ example_inputs.append(v)
+ example_inputs = tuple(example_inputs)
+ self_jit = torch.jit.trace(self, example_inputs, strict=False)
+ self_jit = torch.jit.freeze(self_jit.eval())
+ setattr(self, "trace_graph", self_jit)
+ outputs = self.trace_graph(**model_inputs)
+ if synced_gpus and this_peer_finished:
+ cur_len = cur_len + 1
+ continue # don't waste resources running the code we don't need
+ if first_token:
+ outputs = list(outputs)
+ outputs[0] = outputs[0].expand(input_bs, -1, -1)
+ past_key_values = []
+ for key, value in outputs[1]:
+ key_dim = key.dim()
+ value_dim = value.dim()
+ key = key.expand(input_bs, -1, -1, -1).contiguous()
+ value = value.expand(input_bs, -1, -1, -1).contiguous()
+ if key_dim == 3:
+ key = key.view(key.size(1) * key.size(0), key.size(2), key.size(3))
+ if value_dim == 3:
+ value = value.view(value.size(1) * value.size(0), value.size(2), value.size(3))
+ past_key_values.append(tuple([key, value]))
+ outputs[1] = tuple(past_key_values)
+ outputs = tuple(outputs)
+ if synced_gpus and this_peer_finished:
+ cur_len = cur_len + 1
+ continue # don't waste resources running the code we don't need
+ next_token_logits = outputs[0][:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
@@ -2799,6 +2869,7 @@ def beam_search(
# increase cur_len
cur_len = cur_len + 1
+ latency_list.append(time.time() - tic)
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
if not synced_gpus:
@@ -2822,7 +2893,7 @@ def beam_search(
sequence_outputs["sequence_scores"] = None
if self.config.is_encoder_decoder:
- return BeamSearchEncoderDecoderOutput(
+ output_result = BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
@@ -2834,7 +2905,7 @@ def beam_search(
decoder_hidden_states=decoder_hidden_states,
)
else:
- return BeamSearchDecoderOnlyOutput(
+ output_result = BeamSearchDecoderOnlyOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
@@ -2843,7 +2914,12 @@ def beam_search(
hidden_states=decoder_hidden_states,
)
else:
- return sequence_outputs["sequences"]
+ output_result = sequence_outputs["sequences"]
+ # return the per-token latency list alongside the result when token_latency is set
+ if self.token_latency is not None:
+ return (output_result, latency_list)
+ else:
+ return output_result
def beam_sample(
self,
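Taken together, the generation/utils.py changes thread four new keyword arguments through generate(): jit traces the model with torch.jit.trace on the first decoding step and caches the result as self.trace_graph, TP_number sizes the dummy past_key_values for a given tensor-parallel degree, token_latency makes decoding return a (sequences, latency_list) tuple with one wall-clock timing per generated step (in both greedy_search and beam_search), and use_tpp selects the flattened TPP key/value cache layout. On the first step the jit path runs at batch size 1 with a dummy key/value cache and then expands the outputs to the beam batch, which is why a zero column is concatenated onto the attention mask for the dummy cache entry. A minimal usage sketch under those assumptions; the checkpoint name is illustrative, and torchscript=True is assumed so the model returns tuples, matching the outputs[1] indexing above:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torchscript=True)
model.eval()

inputs = tokenizer("What is per-token latency?", return_tensors="pt")
with torch.no_grad():
    sequences, latency_list = model.generate(
        **inputs,
        max_new_tokens=32,
        num_beams=4,
        jit=True,            # trace and freeze the model on the first step
        token_latency=True,  # also return one timing per decoding step
    )

print("first-token latency: %.4fs" % latency_list[0])
print("average next-token latency: %.4fs" % (sum(latency_list[1:]) / max(len(latency_list) - 1, 1)))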
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 84282fb07..c78245f3d 100755
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -77,7 +77,7 @@ def duplicate_interleave(m):
def apply_rotary_pos_emb(x, sincos, offset=0):
- sin, cos = map(lambda t: duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :], sincos)
+ sin, cos = map(lambda t: duplicate_interleave(t)[None, offset : torch.tensor(x.shape[1]) + torch.tensor(offset), None, :], sincos)
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
return (x * cos) + (rotate_every_two(x) * sin)
@@ -791,9 +791,9 @@ def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
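Both modeling_gptj.py edits serve the tracing path above: wrapping the slice bound in torch.tensor is meant to keep the sequence-length arithmetic as traced tensor ops rather than Python ints that torch.jit.trace would bake in as trace-time constants, and moving position_ids ahead of attention_mask aligns the forward signature with the order in which example_inputs is collected from model_inputs (input_ids, past_key_values, position_ids, attention_mask once use_cache and token_type_ids are popped). The ordering matters because torch.jit.trace binds the example inputs positionally; a minimal sketch with a toy module (names are illustrative):

import torch
from torch import nn

class Toy(nn.Module):
    # Parameter order must match the order of the traced example inputs,
    # just as the GPT-J signature is reordered to match the model_inputs dict.
    def forward(self, input_ids, position_ids, attention_mask):
        return input_ids + position_ids * attention_mask

model_inputs = {
    "input_ids": torch.ones(1, 4),
    "position_ids": torch.arange(4).reshape(1, 4).float(),
    "attention_mask": torch.ones(1, 4),
}
# Mirrors how the patch assembles example_inputs before calling torch.jit.trace.
example_inputs = tuple(v for v in model_inputs.values() if v is not None and not isinstance(v, bool))
traced = torch.jit.trace(Toy().eval(), example_inputs, strict=False)
out = traced(**model_inputs)  # later calls can use keywords, as the patch does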