
Commit f05f12c

yuzhongw-nvidia authored and cyanguwa committed
Fix MLA CP Bugs (#1896)
* fix: (1) UT ignores MLA; (2) bshd format runtime error. Ban fp8 mla attn + cp due to correctness problem
  Signed-off-by: Yuzhong Wang <[email protected]>
* only disable FP8 CP for MLA
  Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Yuzhong Wang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
1 parent 8382eed commit f05f12c
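
For context before the diffs: "MLA" here means the query/key head dimension differs from the value head dimension, which TransformerEngine's DotProductAttention expresses by passing a (head_dim_qk, head_dim_v) tuple instead of a single int, exactly as the test change below does. A minimal, self-contained sketch of that configuration; the sizes and the bshd layout are illustrative assumptions, not values taken from this commit:

    import torch
    from transformer_engine.pytorch import DotProductAttention

    # Illustrative sizes; head_dim_qk != head_dim_v is what makes this an MLA-style config.
    batch, seqlen, num_heads = 2, 128, 16
    head_dim_qk, head_dim_v = 192, 128

    attn = DotProductAttention(
        num_heads,
        (head_dim_qk, head_dim_v),  # tuple form: separate QK and V head dims
        qkv_format="bshd",
    )

    q = torch.randn(batch, seqlen, num_heads, head_dim_qk, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(batch, seqlen, num_heads, head_dim_qk, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(batch, seqlen, num_heads, head_dim_v, dtype=torch.bfloat16, device="cuda")

    out = attn(q, k, v)  # last dim is num_heads * head_dim_v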

4 files changed (+38 −13 lines)


tests/pytorch/fused_attn/run_fused_attn_with_cp.py

Lines changed: 26 additions & 9 deletions
@@ -89,7 +89,7 @@ def run_dpa_with_cp(
     # instantiate core attn module
     core_attn = DotProductAttention(
         config.num_heads,
-        config.head_dim_qk,
+        (config.head_dim_qk, config.head_dim_v),
         num_gqa_groups=config.num_gqa_groups,
         attention_dropout=config.dropout_p,
         qkv_format=qkv_format,
@@ -106,16 +106,22 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
             config.batch_size,
             config.max_seqlen_kv,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
+        v_input_shape = (
+            config.batch_size,
+            config.max_seqlen_kv,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.batch_size,
             config.max_seqlen_q,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         cu_seqlens_q = None
         cu_seqlens_kv = None
@@ -128,16 +134,22 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
             config.max_seqlen_kv,
             config.batch_size,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
+        v_input_shape = (
+            config.max_seqlen_kv,
+            config.batch_size,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.max_seqlen_q,
             config.batch_size,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         cu_seqlens_q = None
         cu_seqlens_kv = None
@@ -149,14 +161,19 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
             config.batch_size * config.max_seqlen_q,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
+        v_input_shape = (
+            config.batch_size * config.max_seqlen_q,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.batch_size * config.max_seqlen_q,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
         seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
@@ -177,8 +194,8 @@ def run_dpa_with_cp(
         assert False, f"{qkv_format} is an unsupported qkv_format!"

     q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
-    k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
-    v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
+    k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
+    v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
     dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
     dout_quantizer = Float8Quantizer(
         fp8_dtype=tex.DType.kFloat8E5M2,
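
The switch from num_heads * config.head_dim_qk to num_heads * config.head_dim_v in attn_output_shape follows from attention returning a weighted sum of value vectors, so each head contributes head_dim_v channels. A quick shape check with plain PyTorch SDPA (not the fused kernel the test exercises; the sizes are illustrative) shows the rule:

    import torch
    import torch.nn.functional as F

    # Illustrative shapes only: q/k share head_dim_qk, v has its own head_dim_v.
    b, h, sq, skv, d_qk, d_v = 2, 16, 128, 128, 192, 128
    q = torch.randn(b, h, sq, d_qk)
    k = torch.randn(b, h, skv, d_qk)
    v = torch.randn(b, h, skv, d_v)

    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([2, 16, 128, 128]) -- the last dim is d_v, not d_qk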

tests/pytorch/fused_attn/test_fused_attn_with_cp.py

Lines changed: 2 additions & 0 deletions
@@ -173,6 +173,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("Only fp8 works with fp8_mha=True!")
     if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently only support KV P2P!")
+    if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
+        pytest.skip("MLA CP currently does not support FP8 attention!")

     subprocess.run(
         get_bash_arguments(

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 4 additions & 4 deletions
@@ -2559,8 +2559,8 @@ def backward(ctx, dout):

         if ctx.enable_mla:
             # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn]
-            dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape)
-            dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape)
+            dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape)
+            dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape)
             dk = ctx.dQKV_CP_quantizer.create_tensor_from_data(
                 dk_fp8, fake_dtype=torch.float32, internal=True
             )
@@ -2586,8 +2586,8 @@ def backward(ctx, dout):
         dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
         if ctx.enable_mla:
             # [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
-            dk = dk.view(*dk.shape[0], -1, *dk.shape[-2:])
-            dv = dv.view(*dv.shape[0], -1, *dv.shape[-2:])
+            dk = dk.view(dk.shape[0], -1, *dk.shape[-2:])
+            dv = dv.view(dv.shape[0], -1, *dv.shape[-2:])
         else:
             # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
             dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
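
Both hunks are indexing/reshape fixes in the MLA backward path. The gathered dKV buffer carries a leading context-parallel dimension, so K and V must be split along dim 1 rather than dim 0, and the later view used *dk.shape[0], which tries to unpack an int and raises a TypeError. A shape-only sketch under assumed, hypothetical sizes (the names mirror the diff, the shapes do not come from TE):

    import torch

    # Hypothetical sizes for illustration.
    cp_size, b, sk, np_, hn_k, hn_v = 2, 1, 8, 4, 192, 128
    k_numel = b * 2 * (sk // 2) * np_ * hn_k
    v_numel = b * 2 * (sk // 2) * np_ * hn_v

    # Gathered FP8 dKV buffer: one row per CP rank, K and V payloads flattened side by side.
    dkv_fp8 = torch.empty(cp_size, k_numel + v_numel)

    # The fix: slice along dim 1 (the per-rank payload), not dim 0 (the CP ranks).
    dk_fp8 = dkv_fp8[:, :k_numel].view(cp_size, b, 2, sk // 2, np_, hn_k)
    dv_fp8 = dkv_fp8[:, k_numel:].view(cp_size, b, 2, sk // 2, np_, hn_v)

    # Second hunk: dk.shape[0] is an int, so "*dk.shape[0]" is a TypeError.
    # The intended fold of the two half-sequence chunks is:
    dk = torch.empty(b, 2, sk // 2, np_, hn_k)       # [b, 2, sk//2, np, hn]
    dk = dk.view(dk.shape[0], -1, *dk.shape[-2:])    # [b, sk, np, hn]
    print(dk.shape)  # torch.Size([1, 8, 4, 192])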

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 6 additions & 0 deletions
@@ -608,6 +608,12 @@ def get_attention_backend(
                 " bias for THD format"
             )
             use_fused_attention = False
+        elif fp8 and head_dim_qk != head_dim_v:
+            logger.debug(
+                "Disabling FusedAttention as it does not support context parallelism with FP8"
+                " MLA attention"
+            )
+            use_fused_attention = False

     # Filter: Attention mask
     # attn_mask_type | attention_mask | supported backends
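
The new branch only affects backend selection: when FP8 execution is requested together with context parallelism and MLA head dims, FusedAttention is taken out of the running instead of producing incorrect results. A standalone sketch of the predicate, not TE's actual control flow (which lives inside get_attention_backend's CP filters):

    def fused_attention_allowed_with_cp(fp8: bool, head_dim_qk: int, head_dim_v: int) -> bool:
        # Mirrors the new elif: FP8 + CP + MLA (different QK/V head dims) is disabled
        # until the correctness problem mentioned in the commit message is resolved.
        if fp8 and head_dim_qk != head_dim_v:
            return False
        return True

    assert fused_attention_allowed_with_cp(fp8=False, head_dim_qk=192, head_dim_v=128)      # BF16 MLA + CP: allowed
    assert not fused_attention_allowed_with_cp(fp8=True, head_dim_qk=192, head_dim_v=128)   # FP8 MLA + CP: filtered out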
