@@ -127,7 +127,8 @@ def forward(
         self_attn_mask: Optional[Tensor] = None,
         before_softmax: bool = False,
         need_head_weights: bool = False,
-        attn_bias: Optional[Tensor] = None
+        attn_bias: Optional[Tensor] = None,
+        prompt_kv: Optional[Tensor] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """Input shape: Time x Batch x Channel

@@ -314,7 +315,7 @@ def forward(

         if key_padding_mask is not None:
             assert key_padding_mask.size(0) == bsz
-            assert key_padding_mask.size(1) == src_len
+            assert key_padding_mask.size(1) == k.size(1)

         if self.add_zero_attn:
             assert v is not None
@@ -335,14 +336,19 @@ def forward(
                     ],
                     dim=1,
                 )
-
+        if prompt_kv is not None:
+            prompt_k, prompt_v = prompt_kv.split(1)
+            prompt_k = prompt_k.squeeze(0).reshape(k.size(0), -1, k.size(2))
+            prompt_v = prompt_v.squeeze(0).reshape(v.size(0), -1, v.size(2))
+            k = torch.cat([prompt_k, k], dim=1)
+            v = torch.cat([prompt_v, v], dim=1)
         attn_weights = torch.bmm(q, k.transpose(1, 2))
-        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, k.size(1), bsz)

-        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, k.size(1)]

         if attn_bias is not None:
-            attn_weights += attn_bias
+            attn_weights[:, :, -src_len:] += attn_bias[:, :, -src_len:]

         if attn_mask is not None:
             attn_mask = attn_mask.unsqueeze(0)
@@ -351,12 +357,12 @@ def forward(
             attn_weights += attn_mask

         if self_attn_mask is not None:
-            self_attn_mask = self_attn_mask.unsqueeze(1).expand(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights += self_attn_mask.contiguous().view(bsz * self.num_heads, tgt_len, src_len)
+            self_attn_mask = self_attn_mask.unsqueeze(1).expand(bsz, self.num_heads, tgt_len, k.size(1))
+            attn_weights += self_attn_mask.contiguous().view(bsz * self.num_heads, tgt_len, k.size(1))

         if key_padding_mask is not None:
             # don't attend to padding symbols
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, k.size(1))
             if not is_tpu:
                 attn_weights = attn_weights.masked_fill(
                     key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
@@ -366,7 +372,7 @@ def forward(
                 attn_weights = attn_weights.transpose(0, 2)
                 attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                 attn_weights = attn_weights.transpose(0, 2)
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, k.size(1))

         if before_softmax:
             return attn_weights, v
@@ -394,7 +400,7 @@ def forward(
         attn_weights: Optional[Tensor] = None
         if need_weights:
             attn_weights = attn_weights_float.view(
-                bsz, self.num_heads, tgt_len, src_len
+                bsz, self.num_heads, tgt_len, k.size(1)
             ).transpose(1, 0)
             if not need_head_weights:
                 # average attention weights over heads
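
The hunks above replace src_len with k.size(1) in every shape check and mask reshape so that keys and values extended by prompt_kv (a prefix prepended along the time axis) stay consistent, while attn_bias is still only added to the trailing src_len key positions. Below is a minimal stand-alone sketch, not from this repo, of a prompt_kv layout compatible with the split(1)/squeeze(0)/reshape calls in the diff; the helper name make_prompt_kv and the (2, bsz * num_heads, prompt_len, head_dim) shape are assumptions inferred from that code, not a confirmed API.

import torch

def make_prompt_kv(prompt_len: int, bsz: int, num_heads: int, head_dim: int) -> torch.Tensor:
    # Hypothetical helper: stack prompt keys and values (random here, learned in
    # practice) into one tensor of shape (2, bsz * num_heads, prompt_len, head_dim).
    # Index 0 holds keys, index 1 holds values, matching prompt_kv.split(1) above.
    prompt_k = torch.randn(bsz * num_heads, prompt_len, head_dim)
    prompt_v = torch.randn(bsz * num_heads, prompt_len, head_dim)
    return torch.stack([prompt_k, prompt_v], dim=0)

# Stand-alone check that the reshape/concat from the diff prepends prompt_len
# extra positions to the key/value time axis.
bsz, num_heads, src_len, head_dim, prompt_len = 2, 4, 5, 16, 3
k = torch.randn(bsz * num_heads, src_len, head_dim)
v = torch.randn(bsz * num_heads, src_len, head_dim)
prompt_kv = make_prompt_kv(prompt_len, bsz, num_heads, head_dim)

prompt_k, prompt_v = prompt_kv.split(1)
prompt_k = prompt_k.squeeze(0).reshape(k.size(0), -1, k.size(2))
prompt_v = prompt_v.squeeze(0).reshape(v.size(0), -1, v.size(2))
k = torch.cat([prompt_k, k], dim=1)
v = torch.cat([prompt_v, v], dim=1)
assert k.size(1) == v.size(1) == prompt_len + src_len

Under this layout, any key_padding_mask passed alongside a prompt would also need k.size(1) columns, which is what the relaxed assert in the second hunk allows.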