exposed cp params to DPA api
Signed-off-by: Md Fahim Faysal Khan <[email protected]>
kocchop committed Oct 27, 2024
1 parent 7fb22c3 commit c4d3749
Showing 3 changed files with 15 additions and 9 deletions.
3 changes: 0 additions & 3 deletions tests/jax/test_distributed_fused_attn.py
@@ -133,7 +133,6 @@ def test_self_attn(
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")

@@ -268,7 +267,6 @@ def test_cross_attn(
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")

@@ -438,7 +436,6 @@ def test_contex_parallel_self_attn(
             seqlen,
             hidden,
             None,  # no window
-            cp_size > 1,
         ):
             pytest.skip(f"No FusedAttn backend found")

6 changes: 0 additions & 6 deletions transformer_engine/jax/attention.py
@@ -190,7 +190,6 @@ def is_fused_attn_kernel_available(
     kv_max_seqlen,
     head_dim,
     window_size: Optional[Tuple[int, int]] = None,
-    is_context_parallel: bool = False,
 ):
     """
     To check whether the fused attention kernel is supported
@@ -215,11 +214,6 @@ def make_helper(attn_mask_type):
     if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
         return False
 
-    # For context parallel need to check additional masking types
-    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
-        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
-            return False
-
     return True


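A minimal sketch, not part of this commit: with the is_context_parallel flag removed from is_fused_attn_kernel_available, a caller that still needs the old behaviour can perform the extra bottom-right causal mask check itself. Only AttnMaskType, its two mask members, and is_fused_attn_kernel_available come from the diff; the helper name and the assumption that the remaining arguments can all be forwarded as keywords are hypothetical.

from transformer_engine.jax.attention import AttnMaskType, is_fused_attn_kernel_available


def is_kernel_available_with_cp(cp_size, attn_mask_type, **kwargs):
    # Hypothetical wrapper: `kwargs` stands in for the function's remaining
    # parameters (dtypes, layout, sequence lengths, head_dim, window_size, ...).
    if not is_fused_attn_kernel_available(attn_mask_type=attn_mask_type, **kwargs):
        return False
    # Mirror of the removed branch: under context parallelism a causal mask
    # also requires the bottom-right causal variant to be supported.
    if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
        return is_fused_attn_kernel_available(
            attn_mask_type=AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, **kwargs
        )
    return True
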
15 changes: 15 additions & 0 deletions transformer_engine/jax/flax/transformer.py
@@ -262,6 +262,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = False
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""
 
     @nn.compact
     def __call__(
@@ -308,6 +310,8 @@ def __call__(
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
             """kvpacked format, treat
@@ -331,6 +335,8 @@
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
             if self.transpose_batch_sequence:
@@ -349,6 +355,8 @@
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         else:
             raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -463,6 +471,9 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
     window_size: Optional[Tuple[int, int]], default = None
         Sliding window size. The default value is no sliding window.
+    context_parallel_causal_load_balanced (bool):
+        Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
+    context_parallel_axis (str): The name of the context parallel axis.
 
     Optimization parameters
     -----------------------
@@ -483,6 +494,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = True
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""
 
     @nn.compact
     def __call__(
@@ -614,6 +627,8 @@ def __call__(
             transpose_batch_sequence=self.transpose_batch_sequence,
             qkv_layout=qkv_layout,
             window_size=self.window_size,
+            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+            context_parallel_axis=self.context_parallel_axis,
         )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
 
         return x
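A minimal usage sketch for the newly exposed fields, assuming a JAX device mesh whose context-parallel axis is named "cp". Only transpose_batch_sequence, window_size, and the two context_parallel_* fields appear in this diff; head_dim, num_attention_heads, and attn_mask_type are assumed constructor fields shown for completeness.

from transformer_engine.jax.flax.transformer import DotProductAttention

attn = DotProductAttention(
    head_dim=128,                    # assumed field name
    num_attention_heads=16,          # assumed field name
    attn_mask_type="causal",         # assumed field name and value
    transpose_batch_sequence=False,
    window_size=None,                # no sliding window
    # New in this commit: sequences are reordered for causal-mask load
    # balancing, and "cp" is the mesh axis used for context parallelism.
    context_parallel_causal_load_balanced=True,
    context_parallel_axis="cp",
)

The module is then initialized and applied like any other Flax module; the value given to context_parallel_axis must match the name of the mesh axis over which the sequence dimension is sharded.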
