From 38460c3ab8d893d5f21d908ec3fd1ab4c88b8554 Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Sat, 14 Dec 2024 01:22:59 -0800
Subject: [PATCH] revamp to get full mask

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 transformer_engine/pytorch/attention.py | 100 +++++++++++++++---------
 1 file changed, 64 insertions(+), 36 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 6cbd260ab5..15117b5a42 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1024,42 +1024,51 @@ def swap_key_value_dict(self, batch_indices):
 
 
 @torch.no_grad()
-def get_swa_mask(
-    window_size: Tuple[int, int],
+def get_full_mask(
     max_seqlen_q: int,
     max_seqlen_kv: int,
     attn_mask_type: str = "no_mask",
-    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
+    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
+    window_size: Tuple[int, int] = None,
+    attention_type: str = "self",
+    bottom_right_alignment: bool = True,
 ) -> torch.Tensor:
     """
-    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. Requirements for the
-    shapes of `attention_mask` given an `attn_mask_type` are the same as in DotProductAttention.
-    For "`causal`" and "`padding_causal`" mask types, the sliding window diagonal is aligned to the
-    top left corner of the softmax matrix; for others, the bottom right corner. Note that when padding
-    is applied, the bottom right corner comes from the [actual_seqlen_q[i], actual_seqlen_kv[i]] matrix,
-    for each batch i, not the [max_seqlen_q, max_seqlen_kv] matrix.::
+    Get the full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
+    `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
+    on both `attn_mask_type` and `bottom_right_alignment`, as detailed below::
 
     attn_mask_type              output shape                                 diagonal alignment
     --------------------------------------------------------------------------------------------
-    no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]          bottom right
-    causal                      [1, 1, max_seqlen_q, max_seqlen_kv]          top left
-    causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]          bottom right
-    padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv] bottom right, based on
-                                                                             actual sequence lengths
-    padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv] top left
-    padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] bottom right, based on
-                                                                             actual sequence lengths
-    arbitrary                   same as attention_mask                       bottom right
+    no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]          follow bottom_right_alignment
+    causal                      [1, 1, max_seqlen_q, max_seqlen_kv]          always top left
+    causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]          always bottom right
+    padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
+    padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left
+    padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right
+    arbitrary                   same as attention_mask                       follow bottom_right_alignment
+
+    .. note::
+
+        For the "padding_causal_bottom_right" mask, or the "padding" mask with `bottom_right_alignment` = True,
+        the bottom right diagonal comes from the bottom right corner of the
+        [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix, i = 0,...,batch_size-1, not the
+        [max_seqlen_q, max_seqlen_kv] matrix.
+        For example, with max_seqlen_q = 4, max_seqlen_kv = 4, attn_mask_type = "padding",
+        attention_type = "cross", and attention_mask = (
+        [[False, False, True, True], [False, False, False, False]],
+        [[False, False, False, True], [False, True, True, True]]), the returned full attention mask
+        is in [2, 1, 4, 4] shape and, with the singleton dimension squeezed, is::
+
+        [[[False, False, False, True],
+          [False, False, False, True],
+          [ True,  True,  True,  True],
+          [ True,  True,  True,  True]],
+         [[False,  True,  True,  True],
+          [False,  True,  True,  True],
+          [False,  True,  True,  True],
+          [False,  True,  True,  True]]]
 
     Parameters
     ----------
-    window_size: Tuple[int, int]
-        Sliding window size for local attention, where query at position i attends to keys
-        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
-        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
-        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
-        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
-        `attn_mask_type`.
     max_seqlen_q: int
         Maximum sequence length for queries.
     max_seqlen_kv: int
@@ -1067,16 +1076,30 @@ def get_swa_mask(
     attn_mask_type: str, default = `no_mask`
         Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
         "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
-    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
+    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         default = `None`
-        Boolean tensor(s) used to mask out attention softmax input.
+        Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
+        for the requirements of `attention_mask` for different `attn_mask_type`s.
+    window_size: Tuple[int, int], default = `None`
+        Sliding window size for local attention, where query at position i attends to keys
+        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
+        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
+        `attn_mask_type`.
+    attention_type: str, default = "self"
+        Attention type, {"self", "cross"}
+    bottom_right_alignment: bool, default = `True`
+        Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
+        or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
+        specifies "causal" or "causal_bottom_right".
 
     Returns
     ----------
-    attn_mask_type: str, default = `no_mask`
-        New attention mask type "arbitrary".
+    attn_mask_type: str
+        For sliding window attention (i.e. `window_size` is not `(-1, -1)` or `(-1, 0)`),
+        "arbitrary"; otherwise, the same as the input `attn_mask_type`.
     attention_mask: torch.Tensor
-        Result after combining input mask and sliding window mask.
+        The full attention mask, based on `attn_mask_type`, `attention_mask`, and `window_size`.
     actual_seqlens_q: torch.Tensor
         For padding masks, the actual sequence lengths for queries, in shape [batch_size].
         For other masks, `None`.
@@ -1101,7 +1124,7 @@ def get_swa_mask(
     actual_seqlens_q = None
     actual_seqlens_kv = None
     if "padding" in attn_mask_type:
-        if max_seqlen_q == max_seqlen_kv:
+        if attention_type == "self":
             attention_mask = torch.logical_or(
                 attention_mask.squeeze(1).unsqueeze(3), attention_mask
             )
@@ -1119,13 +1142,16 @@ def get_swa_mask(
     ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
     swa_left = None
     swa_right = None
-    if attn_mask_type in ["no_mask", "causal_bottom_right", "arbitrary"]:
+    if attn_mask_type == "causal_bottom_right" or (
+        attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment):
         swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
         swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
-    elif attn_mask_type in ["causal", "padding_causal"]:
+    elif attn_mask_type in ["causal", "padding_causal"] or (
+        attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment):
         swa_left = mask - window_size[0]
         swa_right = mask + window_size[1]
-    elif attn_mask_type in ["padding", "padding_causal_bottom_right"]:
+    elif attn_mask_type == "padding_causal_bottom_right" or (
+        attn_mask_type == "padding" and bottom_right_alignment):
         batch_size = attention_mask.shape[0]
         swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
             actual_seqlens_kv - actual_seqlens_q - window_size[0]
@@ -4821,8 +4847,10 @@ def forward(
             key_layer.shape[0],
         )
-        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_swa_mask(
-            window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
+        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
+            max_seqlen_q, max_seqlen_kv, attn_mask_type=attn_mask_type,
+            attention_mask=attention_mask, window_size=window_size,
+            attention_type=self.attention_type,
         )
         batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
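Below is a minimal usage sketch (not part of the patch) that exercises the cross-attention padding
example from the get_full_mask docstring above. It assumes the patched
transformer_engine.pytorch.attention module is importable, that a CUDA device is available (the
helper builds its index tensors on "cuda"), and that the padding masks follow a
[batch_size, 1, 1, seqlen] layout with True marking padded positions, which is one reading of
DotProductAttention's mask requirements::

    import torch
    from transformer_engine.pytorch.attention import get_full_mask

    # Padding masks for batch_size = 2, max_seqlen_q = max_seqlen_kv = 4;
    # True marks padded (masked-out) positions, as in the docstring example.
    mask_q = torch.tensor(
        [[False, False, True, True], [False, False, False, False]], device="cuda"
    ).view(2, 1, 1, 4)
    mask_kv = torch.tensor(
        [[False, False, False, True], [False, True, True, True]], device="cuda"
    ).view(2, 1, 1, 4)

    attn_mask_type, full_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
        max_seqlen_q=4,
        max_seqlen_kv=4,
        attn_mask_type="padding",
        attention_mask=(mask_q, mask_kv),
        window_size=None,  # no sliding window requested
        attention_type="cross",
    )
    # Expected, per the docstring example above: attn_mask_type stays "padding",
    # actual_seqlens_q == [2, 4], actual_seqlens_kv == [3, 1], and full_mask
    # reproduces the two 4x4 boolean blocks shown in the docstring.
    print(attn_mask_type, tuple(full_mask.shape))
    print(actual_seqlens_q, actual_seqlens_kv)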