revamp to get full mask
Signed-off-by: Charlene Yang <[email protected]>
cyanguwa committed Dec 14, 2024
1 parent 3ed8322 commit 38460c3
Showing 1 changed file with 64 additions and 36 deletions: transformer_engine/pytorch/attention.py
@@ -1024,59 +1024,82 @@ def swap_key_value_dict(self, batch_indices):


@torch.no_grad()
- def get_swa_mask(
- window_size: Tuple[int, int],
+ def get_full_mask(
max_seqlen_q: int,
max_seqlen_kv: int,
attn_mask_type: str = "no_mask",
- attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
+ attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
+ window_size: Tuple[int, int] = None,
+ attention_type: str = "self",
+ bottom_right_alignment: bool = True,
) -> torch.Tensor:
"""
- Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. Requirements for the
- shapes of `attention_mask` given an `attn_mask_type` are the same as in DotProductAttention.
- For "`causal`" and "`padding_causal`" mask types, the sliding window diagonal is aligned to the
- top left corner of the softmax matrix; for others, the bottom right corner. Note that when padding
- is applied, the bottom right corner comes from the [actual_seqlen_q[i], actual_seqlen_kv[i]] matrix,
- for each batch i, not the [max_seqlen_q, max_seqlen_kv] matrix.::
+ Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
+ `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
+ on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::
  attn_mask_type                 output shape                                    diagonal alignment
  --------------------------------------------------------------------------------------------------
- no_mask                        [1, 1, max_seqlen_q, max_seqlen_kv]             bottom right
- causal                         [1, 1, max_seqlen_q, max_seqlen_kv]             top left
- causal_bottom_right            [1, 1, max_seqlen_q, max_seqlen_kv]             bottom right
- padding                        [batch_size, 1, max_seqlen_q, max_seqlen_kv]    bottom right, based on
-                                                                                actual sequence lengths
- padding_causal                 [batch_size, 1, max_seqlen_q, max_seqlen_kv]    top left
- padding_causal_bottom_right    [batch_size, 1, max_seqlen_q, max_seqlen_kv]    bottom right, based on
-                                                                                actual sequence lengths
- arbitrary                      same as attention_mask                          bottom right
+ no_mask                        [1, 1, max_seqlen_q, max_seqlen_kv]             follow bottom_right_alignment
+ causal                         [1, 1, max_seqlen_q, max_seqlen_kv]             always top left
+ causal_bottom_right            [1, 1, max_seqlen_q, max_seqlen_kv]             always bottom right
+ padding                        [batch_size, 1, max_seqlen_q, max_seqlen_kv]    follow bottom_right_alignment
+ padding_causal                 [batch_size, 1, max_seqlen_q, max_seqlen_kv]    always top left
+ padding_causal_bottom_right    [batch_size, 1, max_seqlen_q, max_seqlen_kv]    always bottom right
+ arbitrary                      same as attention_mask                          follow bottom_right_alignment
+ .. note::
+ For the "padding_causal_bottom_right" mask, or the "padding" mask with `bottom_right_alignment` = True, the bottom right
+ diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix,
+ i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4,
+ max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
+ [[False, False, True, True], [False, False, False, False]],
+ [[False, False, False, True], [False, True, True, True]]), the returned full attention mask has
+ shape [2, 4, 4] and is::
+ [[[False, False, False,  True],
+   [False, False, False,  True],
+   [ True,  True,  True,  True],
+   [ True,  True,  True,  True]],
+  [[False,  True,  True,  True],
+   [False,  True,  True,  True],
+   [False,  True,  True,  True],
+   [False,  True,  True,  True]]]
Parameters
----------
- window_size: Tuple[int, int]
- Sliding window size for local attention, where query at position i attends to keys
- in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
- + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
- window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
- map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
- `attn_mask_type`.
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
attn_mask_type: str, default = `no_mask`
Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
"`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
- attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
+ attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
default = `None`
- Boolean tensor(s) used to mask out attention softmax input.
+ Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
+ for the requirements of `attention_mask` for different `attn_mask_type`s.
+ window_size: Tuple[int, int], default = `None`
+ Sliding window size for local attention, where query at position i attends to keys
+ in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+ window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
+ map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
+ `attn_mask_type`.
+ attention_type: str, default = "self"
+ Attention type, {"self", "cross"}
+ bottom_right_alignment: bool, default = `True`
+ Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
+ or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
+ specifies "causal" or "causal_bottom_right".
Returns
----------
- attn_mask_type: str, default = `no_mask`
- New attention mask type "arbitrary".
+ attn_mask_type: str
+ For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`.
attention_mask: torch.Tensor
- Result after combining input mask and sliding window mask.
+ The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`.
actual_seqlens_q: torch.Tensor
For padding masks, the actual sequence lengths for queries, in shape [batch_size].
For other masks, `None`.
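
To sanity-check the cross-attention example in the note above, here is a minimal, self-contained sketch in plain PyTorch (not the Transformer Engine code path; tensor names are illustrative) that rebuilds the [2, 4, 4] mask by OR-ing the query-side and key/value-side padding masks:

```python
import torch

# Per-sequence padding masks from the docstring example; True marks a padded position.
mask_q = torch.tensor([[False, False, True, True],
                       [False, False, False, False]])   # [batch=2, max_seqlen_q=4]
mask_kv = torch.tensor([[False, False, False, True],
                        [False, True, True, True]])     # [batch=2, max_seqlen_kv=4]

# Broadcast the query mask over rows and the key/value mask over columns, then OR them:
# an entry is masked if either the query or the key at that position is padding.
full_mask = torch.logical_or(mask_q.unsqueeze(2), mask_kv.unsqueeze(1))  # [2, 4, 4]
print(full_mask)  # matches the [2, 4, 4] matrix shown in the note
```

Every row of a padded query and every column of a padded key ends up True, which is exactly the bottom-right-aligned padding behaviour the table describes.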
@@ -1101,7 +1124,7 @@ def get_swa_mask(
actual_seqlens_q = None
actual_seqlens_kv = None
if "padding" in attn_mask_type:
- if max_seqlen_q == max_seqlen_kv:
+ if attention_type == "self":
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
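
For reference, a small sketch of what the self-attention branch above does, assuming the [batch_size, 1, 1, max_seqlen] boolean mask shape that DotProductAttention documents for "padding" masks (the shapes and values here are illustrative, not taken from the commit):

```python
import torch

batch_size, max_seqlen = 2, 4
# True marks padded positions; one mask per sequence, shared by queries and keys in self-attention.
attention_mask = torch.tensor([[False, False, True, True],
                               [False, True, True, True]]).view(batch_size, 1, 1, max_seqlen)

# Same expression as in the diff: OR the mask against its own "transpose" so that
# both the rows (padded queries) and the columns (padded keys) are masked.
full_mask = torch.logical_or(attention_mask.squeeze(1).unsqueeze(3), attention_mask)
print(full_mask.shape)  # torch.Size([2, 1, 4, 4])
```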
@@ -1119,13 +1142,16 @@
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
swa_left = None
swa_right = None
- if attn_mask_type in ["no_mask", "causal_bottom_right", "arbitrary"]:
+ if attn_mask_type == "causal_bottom_right" or (
+ attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment):
swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
- elif attn_mask_type in ["causal", "padding_causal"]:
+ elif attn_mask_type in ["causal", "padding_causal"] or (
+ attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment):
swa_left = mask - window_size[0]
swa_right = mask + window_size[1]
- elif attn_mask_type in ["padding", "padding_causal_bottom_right"]:
+ elif attn_mask_type == "padding_causal_bottom_right" or (
+ attn_mask_type == "padding" and bottom_right_alignment):
batch_size = attention_mask.shape[0]
swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q - window_size[0]
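
The hunk ends before `swa_left` and `swa_right` are consumed, but the intent of the offsets can be illustrated with a plausible follow-up step (an assumption for illustration, not the committed code): a query/key pair lies inside the window exactly when `swa_left <= 0` and `swa_right >= 0`, so positions outside the window are those with `swa_left > 0` or `swa_right < 0`.

```python
import torch

max_seqlen_q, max_seqlen_kv = 4, 6
window_size = (2, 0)  # look back at most 2 keys, none ahead (bottom-right aligned)

# mask[i, j] = i - j, as in the diff (computed on CPU here for illustration)
mask = (torch.arange(max_seqlen_q, dtype=torch.int32).view(1, 1, max_seqlen_q, 1)
        - torch.arange(max_seqlen_kv, dtype=torch.int32).view(1, 1, 1, max_seqlen_kv))

# "causal_bottom_right"-style offsets from the first branch above
swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]

# Outside the sliding window <=> left bound violated or right bound violated.
swa_mask = torch.logical_or(swa_left > 0, swa_right < 0)  # True = masked out
print(swa_mask[0, 0].int())
```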
@@ -4821,8 +4847,10 @@ def forward(
key_layer.shape[0],
)

- attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_swa_mask(
- window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
+ attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
+ max_seqlen_q, max_seqlen_kv, attn_mask_type=attn_mask_type,
+ attention_mask=attention_mask, window_size=window_size,
+ attention_type=self.attention_type,
)

batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
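
Downstream, the returned boolean mask (True = mask out) is applied to the attention scores before the softmax. A reference-semantics illustration, not TE's fused kernels (the tensor shapes and the large negative fill value are assumptions):

```python
import torch

batch_size, num_heads, max_seqlen_q, max_seqlen_kv = 2, 8, 4, 4
scores = torch.randn(batch_size, num_heads, max_seqlen_q, max_seqlen_kv)
full_mask = torch.rand(batch_size, 1, max_seqlen_q, max_seqlen_kv) > 0.5  # stand-in for get_full_mask output

# Fill masked positions with a large negative value (rather than -inf) so that
# fully masked rows still produce finite softmax outputs in this toy example.
probs = torch.softmax(scores.masked_fill(full_mask, -1e4), dim=-1)
```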