Commit d4e9c9e

Author: TheMrYang
Commit message: prompt
1 parent: 3b4952a

File tree

7 files changed: 642 additions, 29 deletions


models/ofa/unify_multihead_attention.py

Lines changed: 17 additions & 11 deletions
@@ -127,7 +127,8 @@ def forward(
         self_attn_mask: Optional[Tensor] = None,
         before_softmax: bool = False,
         need_head_weights: bool = False,
-        attn_bias: Optional[Tensor] = None
+        attn_bias: Optional[Tensor] = None,
+        prompt_kv: Optional[Tensor] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """Input shape: Time x Batch x Channel
@@ -314,7 +315,7 @@ def forward(

         if key_padding_mask is not None:
             assert key_padding_mask.size(0) == bsz
-            assert key_padding_mask.size(1) == src_len
+            assert key_padding_mask.size(1) == k.size(1)

         if self.add_zero_attn:
             assert v is not None
@@ -335,14 +336,19 @@ def forward(
                     ],
                     dim=1,
                 )
-
+        if prompt_kv is not None:
+            prompt_k, prompt_v = prompt_kv.split(1)
+            prompt_k = prompt_k.squeeze(0).reshape(k.size(0), -1, k.size(2))
+            prompt_v = prompt_v.squeeze(0).reshape(v.size(0), -1, v.size(2))
+            k = torch.cat([prompt_k, k], dim=1)
+            v = torch.cat([prompt_v, v], dim=1)
         attn_weights = torch.bmm(q, k.transpose(1, 2))
-        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, k.size(1), bsz)

-        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, k.size(1)]

         if attn_bias is not None:
-            attn_weights += attn_bias
+            attn_weights[:, :, -src_len:] += attn_bias[:, :, -src_len:]

         if attn_mask is not None:
             attn_mask = attn_mask.unsqueeze(0)
@@ -351,12 +357,12 @@ def forward(
             attn_weights += attn_mask

         if self_attn_mask is not None:
-            self_attn_mask = self_attn_mask.unsqueeze(1).expand(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights += self_attn_mask.contiguous().view(bsz * self.num_heads, tgt_len, src_len)
+            self_attn_mask = self_attn_mask.unsqueeze(1).expand(bsz, self.num_heads, tgt_len, k.size(1))
+            attn_weights += self_attn_mask.contiguous().view(bsz * self.num_heads, tgt_len, k.size(1))

         if key_padding_mask is not None:
             # don't attend to padding symbols
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, k.size(1))
             if not is_tpu:
                 attn_weights = attn_weights.masked_fill(
                     key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
@@ -366,7 +372,7 @@ def forward(
                 attn_weights = attn_weights.transpose(0, 2)
                 attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                 attn_weights = attn_weights.transpose(0, 2)
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, k.size(1))

         if before_softmax:
             return attn_weights, v
@@ -394,7 +400,7 @@ def forward(
         attn_weights: Optional[Tensor] = None
         if need_weights:
             attn_weights = attn_weights_float.view(
-                bsz, self.num_heads, tgt_len, src_len
+                bsz, self.num_heads, tgt_len, k.size(1)
             ).transpose(1, 0)
             if not need_head_weights:
                 # average attention weights over heads
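In short, the patch threads an optional prompt_kv tensor through forward(): when present, its key and value halves are prepended to k and v, so every downstream shape check switches from src_len to k.size(1) (prompt length + src_len), and attn_bias is added only to the trailing src_len key positions so the bias never lands on the prompt slots. Below is a minimal sketch of how a caller might pack prompt_kv; the (2, bsz * num_heads, prompt_len, head_dim) layout and all variable names are assumptions inferred from the split(1)/squeeze(0)/reshape(...) calls in the diff, not something this commit defines here.

import torch

# Hypothetical sizes; every name below is an assumption for illustration.
bsz, num_heads, head_dim = 4, 8, 64
prompt_len, src_len = 5, 20

# Learned prompt keys/values, already expanded across batch and heads.
prompt_k = torch.randn(bsz * num_heads, prompt_len, head_dim)
prompt_v = torch.randn(bsz * num_heads, prompt_len, head_dim)

# Stack along a new leading dim so the attention layer can recover the
# pair with prompt_kv.split(1) followed by squeeze(0), as the diff does.
prompt_kv = torch.stack([prompt_k, prompt_v], dim=0)

# Mirror the patched forward(): prepend the prompts, growing the key
# length from src_len to prompt_len + src_len.
k = torch.randn(bsz * num_heads, src_len, head_dim)
v = torch.randn(bsz * num_heads, src_len, head_dim)
pk, pv = prompt_kv.split(1)
k = torch.cat([pk.squeeze(0).reshape(k.size(0), -1, k.size(2)), k], dim=1)
v = torch.cat([pv.squeeze(0).reshape(v.size(0), -1, v.size(2)), v], dim=1)
assert k.size(1) == prompt_len + src_len  # downstream checks now use k.size(1)

Under this layout the reshape is a shape no-op; what it buys the caller is the freedom to pass the prompts in any layout with the same element count per half. Note also that the key_padding_mask assert now expects k.size(1) columns, so a caller passing a padding mask would need to prepend prompt_len non-padding columns to cover the prompt positions.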

0 commit comments
