From c4d37494483926e830aa16e38cafbb5cf3538819 Mon Sep 17 00:00:00 2001
From: Md Fahim Faysal Khan
Date: Sun, 27 Oct 2024 01:22:22 -0700
Subject: [PATCH] exposed cp params to DPA api

Signed-off-by: Md Fahim Faysal Khan
---
 tests/jax/test_distributed_fused_attn.py   |  3 ---
 transformer_engine/jax/attention.py        |  6 ------
 transformer_engine/jax/flax/transformer.py | 15 +++++++++++++++
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index 23a26087d4..d34afb340c 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -133,7 +133,6 @@ def test_self_attn(
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")

@@ -268,7 +267,6 @@ def test_cross_attn(
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")

@@ -438,7 +436,6 @@ def test_contex_parallel_self_attn(
             seqlen,
             hidden,
             None,  # no window
-            cp_size > 1,
         ):
             pytest.skip(f"No FusedAttn backend found")

diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index b3b11bb9dd..218fd174de 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -190,7 +190,6 @@ def is_fused_attn_kernel_available(
     kv_max_seqlen,
     head_dim,
     window_size: Optional[Tuple[int, int]] = None,
-    is_context_parallel: bool = False,
 ):
     """
     To check whether the fused attention kernel is supported
@@ -215,11 +214,6 @@ def make_helper(attn_mask_type):
     if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
         return False

-    # For context parallel need to check additional masking types
-    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
-        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
-            return False
-
     return True


diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py
index b91584219f..cb71188221 100644
--- a/transformer_engine/jax/flax/transformer.py
+++ b/transformer_engine/jax/flax/transformer.py
@@ -262,6 +262,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = False
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""

     @nn.compact
     def __call__(
@@ -308,6 +310,8 @@ def __call__(
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
             """kvpacked format, treat
@@ -331,6 +335,8 @@ def __call__(
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
             if self.transpose_batch_sequence:
@@ -349,6 +355,8 @@ def __call__(
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         else:
             raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -463,6 +471,9 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
     window_size: Optional[Tuple[int, int]], default = None
         Sliding window size. The default value is no sliding window.
+    context_parallel_causal_load_balanced (bool):
+        Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
+    context_parallel_axis (str): The name of the context parallel axis.

     Optimization parameters
     -----------------------
@@ -483,6 +494,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = True
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""

     @nn.compact
     def __call__(
@@ -614,6 +627,8 @@ def __call__(
             transpose_batch_sequence=self.transpose_batch_sequence,
             qkv_layout=qkv_layout,
             window_size=self.window_size,
+            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+            context_parallel_axis=self.context_parallel_axis,
         )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)

         return x
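
Usage note (illustrative, not part of the patch): a minimal sketch of how a caller might set the two newly exposed DotProductAttention fields. Only context_parallel_causal_load_balanced and context_parallel_axis come from this diff; the other constructor fields, the call signature, the tensor shapes, and the mesh axis name "cp" are assumptions for illustration.

    # Minimal sketch, assuming the surrounding DotProductAttention fields
    # (head_dim, num_attention_heads, attention_dropout) and call signature;
    # only the two context-parallel fields are taken from this patch.
    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import DotProductAttention

    attn = DotProductAttention(
        head_dim=64,                     # assumed existing field
        num_attention_heads=16,          # assumed existing field
        attention_dropout=0.0,           # assumed existing field
        transpose_batch_sequence=False,  # inputs as (batch, seqlen, heads, head_dim)
        context_parallel_causal_load_balanced=True,  # new: sequences pre-reordered for causal CP load balancing
        context_parallel_axis="cp",                  # new: mesh axis name used for context parallelism
    )

    q = k = v = jnp.zeros((2, 128, 16, 64), jnp.bfloat16)
    params = attn.init(jax.random.PRNGKey(0), q, k, v, deterministic=True)
    out = attn.apply(params, q, k, v, deterministic=True)

Setting up the device mesh with a "cp" axis and sharding the sequence dimension across it is out of scope of this sketch; the two new fields only tell the fused-attention primitive which axis to use and whether the sequences were reordered for causal load balancing.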