
Commit bd7fd0a

[Paddle] Support GQA (#595)
* use separate qkv
  Signed-off-by: jaywan <[email protected]>
* add support for GQA
  Signed-off-by: jaywan <[email protected]>
* minor changes
  Signed-off-by: Shijie Wang <[email protected]>
* change rtol
  Signed-off-by: Shijie Wang <[email protected]>
* fix reshape issue
  Signed-off-by: Shijie Wang <[email protected]>

---------

Signed-off-by: jaywan <[email protected]>
Signed-off-by: Shijie Wang <[email protected]>
1 parent e531cd2 commit bd7fd0a
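For reference, grouped-query attention (GQA) lets several query heads share one key/value head: with num_heads query heads and num_gqa_groups key/value heads, each group of num_heads // num_gqa_groups query heads attends with the same K and V. The snippet below is a minimal shape-level sketch of that sharing in plain Paddle, added here for illustration only; it is not the fused cuDNN path these tests exercise, and all names and shapes in it are invented for the example.

# Illustrative GQA reference (not part of this commit). Shapes follow the
# "bshd" convention used in the tests: (batch, seqlen, heads, head_size).
import math
import paddle
import paddle.nn.functional as F

bs, seqlen, head_size = 2, 512, 64
num_heads, num_gqa_groups = 16, 4                  # 4 query heads share each K/V head
group = num_heads // num_gqa_groups

q = paddle.normal(shape=(bs, seqlen, num_heads, head_size))
k = paddle.normal(shape=(bs, seqlen, num_gqa_groups, head_size))
v = paddle.normal(shape=(bs, seqlen, num_gqa_groups, head_size))

# Regroup query heads as (kv_head, group) and move the head dims in front of seqlen.
q = q.reshape([bs, seqlen, num_gqa_groups, group, head_size]).transpose([0, 2, 3, 1, 4])
k = k.transpose([0, 2, 1, 3]).unsqueeze(2)         # (bs, groups, 1, seqlen, head_size)
v = v.transpose([0, 2, 1, 3]).unsqueeze(2)

# The single K/V head of each group broadcasts across its `group` query heads.
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(head_size)
probs = F.softmax(scores, axis=-1)
out = paddle.matmul(probs, v)                      # (bs, groups, group, seqlen, head_size)
out = out.transpose([0, 3, 1, 2, 4]).reshape([bs, seqlen, num_heads, head_size])

With num_gqa_groups == num_heads this reduces to standard multi-head attention, and with num_gqa_groups == 1 to multi-query attention, which is why the tests below parametrize it over [1, 4, 16].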

File tree: 6 files changed (+751, -168 lines)


tests/paddle/test_layers.py (+57, -83)
@@ -3,7 +3,6 @@
 # See LICENSE for license information.
 """Test TE Paddle Layer-level APIs"""
 
-import math
 import os
 from utils import assert_allclose, is_fused_attention_supported
 
@@ -785,7 +784,7 @@ def test_layernorm_mlp_fp8_microbatch(bs, hidden_size, ffn_hidden_size, activati
 
 @pytest.mark.parametrize('bs', [1, 2, 8])
 @pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
 @pytest.mark.parametrize('attn_type', ['self', 'cross'])
 @pytest.mark.parametrize('mask_type', ['causal', 'padding'])
 @pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@@ -808,24 +807,18 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
             head_size=head_size,
             dtype=math_dtype,
             dropout=0.0,
-            qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
+            qkv_layout="bshd_bshd_bshd",
             bias_type="no_bias",
             mask_type=mask_type,
     ):
         pytest.skip("cuDNN fused attention is not supported")
 
-    self_attn_qkv_input = paddle.normal(mean=0.0,
-                                        std=0.02,
-                                        shape=(bs, q_seqlen, 3, num_heads,
-                                               head_size)).astype(math_dtype)
-    cross_attn_q_input = paddle.normal(mean=0.0,
-                                       std=0.02,
-                                       shape=(bs, q_seqlen, num_heads,
-                                              head_size)).astype(math_dtype)
-    cross_attn_kv_input = paddle.normal(mean=0.0,
-                                        std=0.02,
-                                        shape=(bs, kv_seqlen, 2, num_heads,
-                                               head_size)).astype(math_dtype)
+    attn_q_input = paddle.normal(mean=0.0, std=0.02,
+                                 shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype)
+    attn_k_input = paddle.normal(mean=0.0, std=0.02,
+                                 shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
+    attn_v_input = paddle.normal(mean=0.0, std=0.02,
+                                 shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
 
     q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32')
     kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,),
@@ -841,57 +834,36 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
     for i in range(0, bs):
         attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
 
-    norm_factor = math.sqrt(hidden_size // num_heads)
-    layer_te = te.DotProductAttention(norm_factor,
+    head_size = hidden_size // num_heads
+    layer_te = te.DotProductAttention(num_heads,
+                                      head_size,
                                       attention_dropout=0.0,
                                       attn_mask_type=mask_type,
                                       attention_type=attn_type,
                                       backend='transformer_engine')
-    layer_pd = te.DotProductAttention(norm_factor,
+    layer_pd = te.DotProductAttention(num_heads,
+                                      head_size,
                                       attention_dropout=0.0,
                                       attn_mask_type=mask_type,
                                       attention_type=attn_type,
                                       backend='paddle')
 
-    def calc_attn_output_and_grad(layer, q, kv, mask, dout):
+    def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
         _q = paddle.to_tensor(q, stop_gradient=False)
-        _kv = paddle.to_tensor(kv, stop_gradient=False) if kv is not None else None
+        _k = paddle.to_tensor(k, stop_gradient=False)
+        _v = paddle.to_tensor(v, stop_gradient=False)
 
-        out = layer(_q, _kv, mask)
+        out = layer(_q, _k, _v, mask)
         out.backward(dout)
-        return out, _q.grad, _kv.grad if _kv is not None else None
-
-    if attn_type == 'self':
-        out, qkv_grad, _ = calc_attn_output_and_grad(layer_te, self_attn_qkv_input, None, attn_mask,
-                                                     grad_out)
-        out_ref, qkv_grad_ref, _ = calc_attn_output_and_grad(layer_pd, self_attn_qkv_input, None,
-                                                             attn_mask, grad_out)
-        valid_out_ref = paddle.full_like(out_ref, 0)
-        for i in range(0, bs):
-            valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
-
-        q_grad = qkv_grad[:, :, 0]
-        k_grad = qkv_grad[:, :, 1]
-        v_grad = qkv_grad[:, :, 2]
-        q_grad_ref = qkv_grad_ref[:, :, 0]
-        k_grad_ref = qkv_grad_ref[:, :, 1]
-        v_grad_ref = qkv_grad_ref[:, :, 2]
-
-    else:
-        out, q_grad, kv_grad = calc_attn_output_and_grad(layer_te, cross_attn_q_input,
-                                                         cross_attn_kv_input, attn_mask, grad_out)
-        out_ref, q_grad_ref, kv_grad_ref = calc_attn_output_and_grad(layer_pd, cross_attn_q_input,
-                                                                     cross_attn_kv_input, attn_mask,
-                                                                     grad_out)
-
-        valid_out_ref = paddle.full_like(out_ref, 0)
-        for i in range(0, bs):
-            valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
+        return out, _q.grad, _k.grad, _v.grad
 
-        k_grad = kv_grad[:, :, 0]
-        v_grad = kv_grad[:, :, 1]
-        k_grad_ref = kv_grad_ref[:, :, 0]
-        v_grad_ref = kv_grad_ref[:, :, 1]
+    out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input,
+                                                            attn_v_input, attn_mask, grad_out)
+    out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
+        layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out)
+    valid_out_ref = paddle.full_like(out_ref, 0)
+    for i in range(0, bs):
+        valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
 
     valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
     valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
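The hunk above is the core API change for the dot-product-attention test: DotProductAttention is now constructed from (num_heads, head_size) instead of a precomputed norm_factor, and the layer is fed separate q, k and v tensors in the bshd layout rather than packed bs3hd/bshd_bs2hd inputs. A hedged usage sketch along the lines of the updated test follows; the transformer_engine.paddle import path and the exact forward signature are assumptions drawn from the test code, not from library documentation.

# Sketch of the separate-QKV call pattern exercised by the updated test (assumptions noted above).
import paddle
import transformer_engine.paddle as te   # import path assumed

bs, seqlen, num_heads, head_size = 2, 512, 16, 64

layer = te.DotProductAttention(num_heads,
                               head_size,
                               attention_dropout=0.0,
                               attn_mask_type='padding',
                               attention_type='self',
                               backend='transformer_engine')

q = paddle.normal(mean=0.0, std=0.02, shape=(bs, seqlen, num_heads, head_size)).astype('bfloat16')
k = paddle.normal(mean=0.0, std=0.02, shape=(bs, seqlen, num_heads, head_size)).astype('bfloat16')
v = paddle.normal(mean=0.0, std=0.02, shape=(bs, seqlen, num_heads, head_size)).astype('bfloat16')
mask = paddle.zeros(shape=(bs, 1, seqlen, seqlen), dtype='bool')   # False means "attend", as in the test

out = layer(q, k, v, mask)   # indexed as (bs, seqlen, heads, head_size) by the test's checks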
@@ -910,17 +882,18 @@ def calc_attn_output_and_grad(layer, q, kv, mask, dout):
 
 
 @pytest.mark.parametrize('bs', [1, 2, 8])
+@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
 @pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
 @pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
 @pytest.mark.parametrize('no_wgrad', [True, False])
 @pytest.mark.parametrize('mask_type', ['causal', 'padding'])
 @pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
 @pytest.mark.parametrize('output_layernorm', [True, False])
 @pytest.mark.parametrize('return_layernorm_output', [True, False])
-def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
-                                   no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
-                                   output_layernorm, return_layernorm_output):
+def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
+                                   has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
+                                   math_dtype, output_layernorm, return_layernorm_output):
     """
     Test Transformer Encoder Layer
     """
@@ -932,13 +905,13 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     # Skip if cuDNN fused attention is not supported
     if not is_fused_attention_supported(
             num_heads=num_heads,
-            num_gqa_groups=num_heads,
+            num_gqa_groups=num_gqa_groups,
             q_seqlen=q_seqlen,
             kv_seqlen=kv_seqlen,
             head_size=hidden_size // num_heads,
             dtype=math_dtype,
             dropout=0.0,
-            qkv_layout="bs3hd",
+            qkv_layout="bshd_bshd_bshd",
             bias_type="no_bias",
             mask_type=mask_type,
     ):
@@ -962,6 +935,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     layer_te = te.TransformerLayer(hidden_size,
                                    ffn_hidden_size,
                                    num_heads,
+                                   num_gqa_groups=num_gqa_groups,
                                    layernorm_epsilon=eps,
                                    hidden_dropout=0.0,
                                    attention_dropout=0.0,
@@ -975,6 +949,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     layer_pd = te.TransformerLayer(hidden_size,
                                    ffn_hidden_size,
                                    num_heads,
+                                   num_gqa_groups=num_gqa_groups,
                                    layernorm_epsilon=eps,
                                    hidden_dropout=0.0,
                                    attention_dropout=0.0,
@@ -1088,18 +1063,19 @@ def calc_transformer_output_and_grad(layer, encoder_input, mask, dout):
 
 
 @pytest.mark.parametrize('bs', [1, 2, 8])
+@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
 @pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
 @pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
 @pytest.mark.parametrize('no_wgrad', [True, False])
 @pytest.mark.parametrize('mask_type', ['causal', 'padding'])
 @pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
 @pytest.mark.parametrize('output_layernorm', [True, False])
 @pytest.mark.parametrize('return_layernorm_output', [True, False])
 @pytest.mark.parametrize('recompute_core_attention', [True, False])
-def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
-                                   no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
-                                   output_layernorm, return_layernorm_output,
+def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
+                                   has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
+                                   math_dtype, output_layernorm, return_layernorm_output,
                                    recompute_core_attention):
     """
     Test Transformer Decoder Layer
@@ -1112,39 +1088,35 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     # Skip if cuDNN fused attention is not supported
     if not is_fused_attention_supported(
             num_heads=num_heads,
-            num_gqa_groups=num_heads,
+            num_gqa_groups=num_gqa_groups,
             q_seqlen=q_seqlen,
             kv_seqlen=kv_seqlen,
             head_size=hidden_size // num_heads,
             dtype=math_dtype,
             dropout=0.0,
-            qkv_layout="bs3hd",
-            bias_type="no_bias",
-            mask_type=mask_type,
-    ):
-        pytest.skip("cuDNN fused attention is not supported")
-    if not is_fused_attention_supported(
-            head_size=hidden_size // num_heads,
-            num_heads=num_heads,
-            num_gqa_groups=num_heads,
-            q_seqlen=q_seqlen,
-            kv_seqlen=kv_seqlen,
-            dtype=math_dtype,
-            dropout=0.0,
-            qkv_layout="bshd_bs2hd",
+            qkv_layout="bshd_bshd_bshd",
             bias_type="no_bias",
             mask_type=mask_type,
     ):
         pytest.skip("cuDNN fused attention is not supported")
 
-    encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
-    encoder_output = paddle.uniform(shape=(bs, kv_seqlen, hidden_size), dtype=math_dtype)
+    encoder_input = paddle.normal(mean=0.0, std=0.1,
+                                  shape=(bs, q_seqlen, hidden_size)).astype(math_dtype)
+    encoder_output = paddle.normal(mean=0.0, std=0.1,
+                                   shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype)
 
     q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
     kv_actual_seqlen = q_actual_seqlen
     attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
 
-    grad_out = paddle.normal(mean=0.0, std=0.2, shape=(bs, q_seqlen, hidden_size)).astype('float32')
+    grad_out = paddle.normal(mean=0.0, std=0.01,
+                             shape=(bs, q_seqlen, hidden_size)).astype('float32')
+
+    # rounding to avoid numerical issues
+    encoder_input = paddle.round(encoder_input * 1000) / 1000
+    encoder_output = paddle.round(encoder_output * 1000) / 1000
+    grad_out = paddle.round(grad_out * 1000) / 1000
+
     for i in range(0, bs):
         grad_out[i, q_actual_seqlen[i]:, :] = 0
     grad_out = grad_out.astype(math_dtype)
@@ -1155,6 +1127,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
     layer_te = te.TransformerLayer(hidden_size,
                                    ffn_hidden_size,
                                    num_heads,
+                                   num_gqa_groups=num_gqa_groups,
                                    layernorm_epsilon=eps,
                                    hidden_dropout=0.0,
                                    attention_dropout=0.0,
@@ -1168,6 +1141,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    layer_pd = te.TransformerLayer(hidden_size,
                                    ffn_hidden_size,
                                    num_heads,
+                                   num_gqa_groups=num_gqa_groups,
                                    layernorm_epsilon=eps,
                                    hidden_dropout=0.0,
                                    attention_dropout=0.0,
@@ -1319,7 +1293,7 @@ def calc_transformer_output_and_grad(layer,
         assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
                         layer_pd.self_attention.layernorm_qkv.weight.grad.T,
                         rtol=rtol,
-                        atol=0.1)
+                        atol=atol)
         assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad,
                         layer_pd.inter_attention.layernorm_query.weight.grad.T,
                         rtol=rtol,
@@ -1328,7 +1302,7 @@ def calc_transformer_output_and_grad(layer,
     if output_layernorm:
         assert_allclose(layer_te.self_attention.qkv.bias.grad,
                         layer_pd.self_attention.qkv.bias.grad,
-                        rtol=0.01,
+                        rtol=0.5,
                         atol=0.6)
     else:
         assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
0 commit comments

Comments
 (0)