
Commit

[common/PyTorch] Add cuDNN SWA (left, 0) + padding + bottom right causal (#1378)

* add SWA (left, 0) + padding + bottom right causal mask (brcm) support

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* final fixes

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* upgrade to FE 1.9-rc

Signed-off-by: Charlene Yang <[email protected]>

* fix jax tests

Signed-off-by: Charlene Yang <[email protected]>

* skip thd + CP + fused attn tests for cuDNN 9.6+ due to different stats shapes

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] authored Dec 20, 2024
1 parent a3b32ec commit 838345e
Showing 8 changed files with 195 additions and 101 deletions.
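
Before the file-by-file diff, a quick aside (not part of the commit): a minimal PyTorch sketch of the mask semantics the title describes, assuming the usual convention that a (left, 0) sliding window is applied on top of a bottom-right-aligned causal diagonal, with padded positions masked out.

import torch

def swa_brcm_padding_mask(seqlen_q, seqlen_kv, max_q, max_kv, left):
    # True marks positions that are masked out.
    q = torch.arange(max_q).unsqueeze(1)   # query positions, shape [max_q, 1]
    k = torch.arange(max_kv).unsqueeze(0)  # key positions, shape [1, max_kv]
    diag = k - (seqlen_kv - seqlen_q)      # diagonal aligned to the bottom-right corner
    causal = diag > q                      # keys "in the future" of the query
    window = diag < q - left               # keys older than the left window
    padding = (q >= seqlen_q) | (k >= seqlen_kv)
    return causal | window | padding

# e.g. a 4-token query against a 6-token key/value sequence, window (2, 0)
mask = swa_brcm_padding_mask(seqlen_q=4, seqlen_kv=6, max_q=8, max_kv=8, left=2)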
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 43 files
+1 −1 CMakeLists.txt
+10 −0 docs/operations/Attention.md
+3 −2 include/cudnn_backend_base.h
+1 −0 include/cudnn_frontend.h
+24 −2 include/cudnn_frontend/graph_helpers.h
+28 −0 include/cudnn_frontend/graph_interface.h
+32 −1 include/cudnn_frontend/graph_properties.h
+6 −0 include/cudnn_frontend/node/paged_cache_load.h
+3 −0 include/cudnn_frontend/node/resample.h
+372 −481 include/cudnn_frontend/node/scaled_dot_product_flash_attention.h
+4 −1 include/cudnn_frontend/node/sdpa_fp8.h
+5 −1 include/cudnn_frontend/node/sdpa_fp8_bwd.h
+7 −3 include/cudnn_frontend/plans.h
+387 −0 include/cudnn_frontend/utils/attn_score_modifiers.h
+3 −3 include/cudnn_frontend_EngineFallbackList.h
+3 −3 include/cudnn_frontend_ExecutionPlan.h
+3 −4 include/cudnn_frontend_Operation.h
+1 −1 include/cudnn_frontend_OperationGraph.h
+3 −4 include/cudnn_frontend_get_plan.h
+2 −0 include/cudnn_frontend_shim.h
+1 −1 include/cudnn_frontend_utils.h
+1 −1 include/cudnn_frontend_version.h
+2 −2 pyproject.toml
+1 −1 python/cudnn/__init__.py
+16 −0 python/pygraph/pygraph.cpp
+3 −0 python/pygraph/pygraph.h
+2 −2 python/pygraph/sdpa.cpp
+3 −0 samples/cpp/CMakeLists.txt
+205 −0 samples/cpp/convolution/conv_dynamic_shape_benchmark.cpp
+2 −1 samples/cpp/convolution/fp8_fprop.cpp
+4 −0 samples/cpp/convolution/fprop.cpp
+5 −1 samples/cpp/convolution/wgrads.cpp
+144 −0 samples/cpp/norm/layernorm.cpp
+207 −0 samples/cpp/sdpa/fp16_bwd_with_flexible_graphs.cpp
+198 −0 samples/cpp/sdpa/fp16_fwd_with_flexible_graphs.cpp
+1 −1 samples/cpp/utils/helpers.h
+5 −3 samples/legacy_samples/fp16_emu.cpp
+1 −1 samples/legacy_samples/helpers.cpp
+5 −0 samples/legacy_samples/test_list.cpp
+3 −1 samples/python/50_scaled_dot_product_attention.ipynb
+5 −3 samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb
+7 −0 test/python/test_conv_bias.py
+112 −60 test/python/test_mhas.py
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
@@ -13,7 +13,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
@@ -22,3 +21,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
18 changes: 15 additions & 3 deletions tests/jax/test_fused_attn.py
@@ -170,8 +170,7 @@ def make_mask(
max_seqlen_kv = inv_mask.shape[-1]
inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
# In inv_swa_mask and inv_mask 0 is masked out
inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)
inv_mask = combine_masks(inv_mask, inv_swa_mask)

mask = jnp.logical_not(inv_mask)
return mask
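
For reference, a small sketch of the combination the new line performs (assuming combine_masks reduces its inputs with an element-wise logical AND, where nonzero means the position may be attended):

import jax.numpy as jnp

inv_mask = jnp.array([[1, 1, 0], [1, 1, 1]])      # 0 is masked out
inv_swa_mask = jnp.array([[1, 0, 0], [1, 1, 0]])  # sliding-window allowance
combined = jnp.logical_and(inv_mask != 0, inv_swa_mask != 0)
# combined: [[ True, False, False], [ True,  True, False]]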
@@ -315,6 +314,13 @@ def _get_max_segments_per_sequence(self):
return self.num_segments_per_seq + 1

def _check_configs(self):
# TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference
if (
self.qkv_layout.is_thd()
and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK
and self.window_size is not None
):
pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.")
# TODO(rewang): probably adds this in is_fused_attn_available
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")
@@ -504,7 +510,13 @@ def generate_random_segment_ids(
if self.qkv_layout.is_thd():
self.mask_for_customcall = None # THD format doesn't support mask
else:
self.mask_for_customcall = self.mask
self.mask_for_customcall = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)

self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
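
A hedged sketch of how a mask can be derived from segment ids, assuming the common convention that id 0 marks padding and tokens attend only within the same segment (the test's make_mask may differ in details):

import jax.numpy as jnp

segment_ids_q = jnp.array([[1, 1, 2, 0]])   # [batch, seqlen_q]; 0 = padding
segment_ids_kv = jnp.array([[1, 2, 2, 0]])  # [batch, seqlen_kv]
same_segment = segment_ids_q[..., :, None] == segment_ids_kv[..., None, :]
not_padding = (segment_ids_q[..., :, None] != 0) & (segment_ids_kv[..., None, :] != 0)
inv_mask = same_segment & not_padding       # True where attention is allowed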
186 changes: 124 additions & 62 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -237,19 +237,18 @@ def test_dot_product_attention(
tols = dict(atol=1.5e-2, rtol=1.5e-2)
config = model_configs[model]
is_mla = config.head_dim_qk != config.head_dim_v
is_mqa_gqa = config.num_heads != config.num_gqa_groups
if qkv_layout is None:
if config.attn_type == "self":
qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd"
qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
else:
qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd"
qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip("No need to test this layout for cross attention")

# Test backend availability
window_size = (-1, -1)
if swa:
window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, window_size)
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
available_backends, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
@@ -334,16 +333,16 @@ def test_dot_product_attention(
is_training,
)

if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
@@ -399,30 +398,41 @@ def test_dpa_mla(dtype, model_configs, model):

model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
"mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"mask_7_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_7_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"mask_8_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
"mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"mask_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_5_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
"mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
"mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"),
"mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"),
"mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_10_0": ModelConfig(
2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_8_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
"mask_10_1": ModelConfig(
2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
"mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
}
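
The positional arguments follow the "# test: b, h, hg, d, sq, skv, p, mask, bias" comment above; a hypothetical dataclass (field names inferred from how the configs are read elsewhere in the test, not the actual ModelConfig definition) makes the mapping explicit:

from dataclasses import dataclass
from typing import Tuple

@dataclass
class ModelConfigSketch:
    batch_size: int            # b
    num_heads: int             # h
    num_gqa_groups: int        # hg (1 => MQA, < h => GQA)
    head_dim_qk: int           # d
    max_seqlen_q: int          # sq
    max_seqlen_kv: int         # skv
    dropout_p: float           # p
    attn_mask_type: str        # mask
    attn_bias_type: str        # bias
    window_size: Tuple[int, int] = (-1, -1)  # optional sliding window

# "mask_1_1": 2 sequences, 24 query heads sharing 1 KV head, causal mask
example = ModelConfigSketch(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias")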


@@ -531,20 +541,28 @@ def test_dpa_bias_shapes(dtype, model_configs, model):

model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_0": ModelConfig(
4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_1": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_6_2": ModelConfig(
2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_6_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
}
@@ -623,18 +641,57 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
"layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"layout_2_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_2_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_2_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_3_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_3_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_3_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_4_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
),
"layout_4_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
),
"layout_4_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
),
"layout_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
),
"layout_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
),
"layout_5_2": ModelConfig(
2,
24,
24,
128,
2048,
4096,
0.0,
"padding_causal_bottom_right",
"no_bias",
window_size=(4, 0),
),
}
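
For readers unfamiliar with the "t*" layouts, a small sketch of the packed THD format these configs target (an illustration of the general convention, not code from the test): variable-length sequences are concatenated along a single token axis and delimited by cumulative sequence lengths.

import torch

seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seqlens, 0, dtype=torch.int32)]
)
# cu_seqlens == tensor([0, 5, 8, 15]); sequence i occupies rows
# cu_seqlens[i]:cu_seqlens[i + 1] of the packed [total_tokens, h, d] tensor.
total_tokens = int(cu_seqlens[-1])
q_packed = torch.randn(total_tokens, 16, 64)  # "thd": tokens, heads, head_dim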


@@ -651,11 +708,13 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
config = model_configs[model]
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True")
pad_between_seqs = True
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
if get_cudnn_version() >= (9, 3, 0):
logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
# cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
pad_between_seqs = False
test_dot_product_attention(
@@ -695,9 +754,12 @@ def _run_dot_product_attention(
)
seqlens_kv = seqlens_q
if config.attn_type == "cross":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
if config.max_seqlen_q > 1:
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
seqlens_kv = torch.randint(
1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
)
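
The new max_seqlen_q > 1 guard above matters because the decode-style configs added in this commit use sq = 1, and torch.randint samples from the half-open range [low, high), which is empty when low == high == 1. A quick illustration (an aside, not test code):

import torch

try:
    torch.randint(1, 1, [2])  # empty range: low must be strictly less than high
except RuntimeError as err:
    print("randint failed:", err)

# so for max_seqlen_q == 1 the test falls back to all-ones sequence lengths
seqlens_q = torch.ones([2], dtype=torch.int32)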
2 changes: 2 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -121,6 +121,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
if dtype == "fp8" and get_device_compute_capability() < (9, 0):
pytest.skip("FP8 attention is only supported on sm90+!")
if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0):
pytest.skip("THD format is not supported for cuDNN 9.6+!")

config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
