
Commit fedd9dd

cyanguwa and pre-commit-ci[bot] authored and committed
[PyTorch] Disable determinism for sm100 (#2130)
* disable determinism for sm100+ and cudnn<9.14

  Signed-off-by: Charlene Yang <[email protected]>

* fix remaining CI failures

  Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* revert some changes

  Signed-off-by: Charlene Yang <[email protected]>

* revert more changes

  Signed-off-by: Charlene Yang <[email protected]>

* remove sm100 from determinism table

  Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9cd6d16 · commit fedd9dd

3 files changed: +23 −8 lines changed

tests/pytorch/test_numerics.py

Lines changed: 13 additions & 4 deletions
@@ -111,13 +111,18 @@
 
 
 def is_fused_attn_available(
-    config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
+    config: ModelConfig,
+    dtype: torch.dtype,
+    qkv_layout="bshd_bshd_bshd",
+    is_training=True,
+    deterministic=False,
 ):
     _, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
         is_training=is_training,
+        deterministic=deterministic,
     )
     return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
 
@@ -825,7 +830,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
 @pytest.mark.parametrize("model", ["126m"])
 def test_gpt_checkpointing(dtype, bs, model):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype):
+    if not is_fused_attn_available(config, dtype, deterministic=True):
         pytest.skip("No attention backend available.")
     outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
     outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
@@ -873,7 +878,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
 def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+    if not is_fused_attn_available(
+        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
+    ):
         pytest.skip("No attention backend available.")
 
     te_gpt = TransformerLayer(
@@ -986,7 +993,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
 @pytest.mark.parametrize("mask_type", mask_types)
 def test_mha_accuracy(dtype, bs, model, mask_type):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+    if not is_fused_attn_available(
+        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
+    ):
         pytest.skip("No attention backend available.")
 
     te_mha = MultiheadAttention(
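As a usage note, here is a minimal sketch of how a test can gate itself on the extended helper. `is_fused_attn_available` and `model_configs` are the helpers from tests/pytorch/test_numerics.py; the wrapper name and the example arguments below are illustrative, not repository code.

    # Hedged sketch: skip a test unless a fused attention backend exists that
    # also supports deterministic execution for training.
    import pytest
    import torch

    def _skip_unless_deterministic_fused_attn(model: str, dtype: torch.dtype) -> None:
        config = model_configs[model]  # e.g. "126m"
        # On sm100 (Blackwell) with cuDNN <= 9.14 this now returns False for
        # training, so the calling test is skipped instead of failing
        # non-deterministically.
        if not is_fused_attn_available(
            config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
        ):
            pytest.skip("No attention backend available.")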

tests/pytorch/utils.py

Lines changed: 1 addition & 1 deletion
@@ -266,8 +266,8 @@ def test():
         )
         (
             use_flash_attention,
-            use_fused_attention,
             flash_attention_backend,
+            use_fused_attention,
             fused_attention_backend,
             use_unfused_attention,
             available_backends,
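The swap above aligns the unpacking order with the order in which the backend flags and handles are produced. A self-contained sketch of why the order matters, with placeholder values standing in for the real flags and backend objects:

    # Hedged, self-contained illustration: if the produced tuple is
    # (flash_flag, flash_backend, fused_flag, fused_backend, unfused_flag,
    # available_backends), unpacking use_fused_attention before
    # flash_attention_backend would silently assign a backend handle to a
    # boolean variable.
    backends = (True, "FlashAttention-2", True, "F16_arbitrary_seqlen", False, {})

    (
        use_flash_attention,
        flash_attention_backend,
        use_fused_attention,
        fused_attention_backend,
        use_unfused_attention,
        available_backends,
    ) = backends

    assert isinstance(use_fused_attention, bool)
    assert fused_attention_backend == "F16_arbitrary_seqlen"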

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 9 additions & 3 deletions
@@ -822,7 +822,7 @@ def get_attention_backend(
     # flash-attn >=2.4.1                 | yes
     # FusedAttention                     |
     #     sub-backend 0                  | yes
-    #     sub-backend 1                  | workspace optimization path and sm90+: yes;
+    #     sub-backend 1                  | workspace optimization path and sm90: yes;
     #                                    | otherwise: no
     #     sub-backend 2                  | no
     # UnfusedDotProductAttention         | yes
@@ -838,8 +838,9 @@ def get_attention_backend(
         use_flash_attention_2 = False
     if use_fused_attention and deterministic:
         if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
-            logger.debug("Disabling FusedAttention for determinism reasons")
+            logger.debug("Disabling FusedAttention for determinism reasons with FP8")
             use_fused_attention = False
+            fused_attention_backend = None
         if (
             fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
             and is_training
@@ -849,8 +850,13 @@ def get_attention_backend(
                 or cudnn_version < (8, 9, 5)
             )
         ):
-            logger.debug("Disabling FusedAttention for determinism reasons")
+            logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
+            use_fused_attention = False
+            fused_attention_backend = None
+        if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0):
+            logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
             use_fused_attention = False
+            fused_attention_backend = None
 
     # use_flash_attention may have been set above
     use_flash_attention_2 = use_flash_attention and use_flash_attention_2
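A minimal, self-contained restatement of the new determinism gate, assuming `device_compute_capability` and `cudnn_version` are (major, minor[, patch]) tuples as in get_attention_backend; the helper name below is illustrative, not part of Transformer Engine:

    # Hedged sketch of the new gate: fused attention is treated as
    # non-deterministic for training on Blackwell (sm100+) until a cuDNN
    # release newer than 9.14.0 is available.
    def fused_attn_is_deterministic_on_this_gpu(device_compute_capability, cudnn_version, is_training):
        if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0):
            return False
        return True

    # Example: an sm100 GPU with cuDNN 9.12 must fall back to another backend
    # when deterministic training is requested, while sm90 is unaffected.
    print(fused_attn_is_deterministic_on_this_gpu((10, 0), (9, 12, 0), is_training=True))  # False
    print(fused_attn_is_deterministic_on_this_gpu((9, 0), (9, 12, 0), is_training=True))   # True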
