Skip to content

Commit

Permalink
Clean up checks in unit test.
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Goldfarb <[email protected]>
  • Loading branch information
mgoldfarb-nvidia committed Oct 30, 2024
1 parent c4d3749 commit b0c5c06
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ def test_contex_parallel_self_attn(
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)

if not is_fused_attn_kernel_available(
def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
Expand All @@ -435,9 +436,15 @@ def test_contex_parallel_self_attn(
seqlen,
seqlen,
hidden,
None, # no window
):
pytest.skip(f"No FusedAttn backend found")
None) # no SWA for CP

# For causal masking we depend on having bottom right support also.
has_backend = check_has_backend_for_mask(attn_mask_type)
if mask == AttnMaskType.CAUSAL_MASK_MASK:
has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

if not has_backend
pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

if dp_size > 1 and batch % dp_size != 0:
pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
Expand Down

0 comments on commit b0c5c06

Please sign in to comment.