# See LICENSE for license information.
"""Test TE Paddle Layer-level APIs"""

- import math
import os
from utils import assert_allclose, is_fused_attention_supported

@@ -785,7 +784,7 @@ def test_layernorm_mlp_fp8_microbatch(bs, hidden_size, ffn_hidden_size, activati
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
- @pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+ @pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('attn_type', ['self', 'cross'])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@@ -808,24 +807,18 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
            head_size=head_size,
            dtype=math_dtype,
            dropout=0.0,
-             qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
+             qkv_layout="bshd_bshd_bshd",
            bias_type="no_bias",
            mask_type=mask_type,
    ):
        pytest.skip("cuDNN fused attention is not supported")

-     self_attn_qkv_input = paddle.normal(mean=0.0,
-                                         std=0.02,
-                                         shape=(bs, q_seqlen, 3, num_heads,
-                                                head_size)).astype(math_dtype)
-     cross_attn_q_input = paddle.normal(mean=0.0,
-                                        std=0.02,
-                                        shape=(bs, q_seqlen, num_heads,
-                                               head_size)).astype(math_dtype)
-     cross_attn_kv_input = paddle.normal(mean=0.0,
-                                         std=0.02,
-                                         shape=(bs, kv_seqlen, 2, num_heads,
-                                                head_size)).astype(math_dtype)
+     attn_q_input = paddle.normal(mean=0.0, std=0.02,
+                                  shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype)
+     attn_k_input = paddle.normal(mean=0.0, std=0.02,
+                                  shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
+     attn_v_input = paddle.normal(mean=0.0, std=0.02,
+                                  shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)

    q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32')
    kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,),
@@ -841,57 +834,36 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
    for i in range(0, bs):
        attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False

-     norm_factor = math.sqrt(hidden_size // num_heads)
-     layer_te = te.DotProductAttention(norm_factor,
+     head_size = hidden_size // num_heads
+     layer_te = te.DotProductAttention(num_heads,
+                                       head_size,
                                      attention_dropout=0.0,
                                      attn_mask_type=mask_type,
                                      attention_type=attn_type,
                                      backend='transformer_engine')
-     layer_pd = te.DotProductAttention(norm_factor,
+     layer_pd = te.DotProductAttention(num_heads,
+                                       head_size,
                                      attention_dropout=0.0,
                                      attn_mask_type=mask_type,
                                      attention_type=attn_type,
                                      backend='paddle')

-     def calc_attn_output_and_grad(layer, q, kv, mask, dout):
+     def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
        _q = paddle.to_tensor(q, stop_gradient=False)
-         _kv = paddle.to_tensor(kv, stop_gradient=False) if kv is not None else None
+         _k = paddle.to_tensor(k, stop_gradient=False)
+         _v = paddle.to_tensor(v, stop_gradient=False)

-         out = layer(_q, _kv, mask)
+         out = layer(_q, _k, _v, mask)
        out.backward(dout)
-         return out, _q.grad, _kv.grad if _kv is not None else None
-
-     if attn_type == 'self':
-         out, qkv_grad, _ = calc_attn_output_and_grad(layer_te, self_attn_qkv_input, None, attn_mask,
-                                                      grad_out)
-         out_ref, qkv_grad_ref, _ = calc_attn_output_and_grad(layer_pd, self_attn_qkv_input, None,
-                                                              attn_mask, grad_out)
-         valid_out_ref = paddle.full_like(out_ref, 0)
-         for i in range(0, bs):
-             valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
-
-         q_grad = qkv_grad[:, :, 0]
-         k_grad = qkv_grad[:, :, 1]
-         v_grad = qkv_grad[:, :, 2]
-         q_grad_ref = qkv_grad_ref[:, :, 0]
-         k_grad_ref = qkv_grad_ref[:, :, 1]
-         v_grad_ref = qkv_grad_ref[:, :, 2]
-
-     else:
-         out, q_grad, kv_grad = calc_attn_output_and_grad(layer_te, cross_attn_q_input,
-                                                          cross_attn_kv_input, attn_mask, grad_out)
-         out_ref, q_grad_ref, kv_grad_ref = calc_attn_output_and_grad(layer_pd, cross_attn_q_input,
-                                                                      cross_attn_kv_input, attn_mask,
-                                                                      grad_out)
-
-         valid_out_ref = paddle.full_like(out_ref, 0)
-         for i in range(0, bs):
-             valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
+         return out, _q.grad, _k.grad, _v.grad

-         k_grad = kv_grad[:, :, 0]
-         v_grad = kv_grad[:, :, 1]
-         k_grad_ref = kv_grad_ref[:, :, 0]
-         v_grad_ref = kv_grad_ref[:, :, 1]
+     out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input,
+                                                             attn_v_input, attn_mask, grad_out)
+     out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
+         layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out)
+     valid_out_ref = paddle.full_like(out_ref, 0)
+     for i in range(0, bs):
+         valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]

    valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
    valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
@@ -910,17 +882,18 @@ def calc_attn_output_and_grad(layer, q, kv, mask, dout):


@pytest.mark.parametrize('bs', [1, 2, 8])
+ @pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
- @pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+ @pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
- def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
-                                    no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
-                                    output_layernorm, return_layernorm_output):
+ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
+                                    has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
+                                    math_dtype, output_layernorm, return_layernorm_output):
    """
    Test Transformer Encoder Layer
    """
@@ -932,13 +905,13 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
            num_heads=num_heads,
-             num_gqa_groups=num_heads,
+             num_gqa_groups=num_gqa_groups,
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
            head_size=hidden_size // num_heads,
            dtype=math_dtype,
            dropout=0.0,
-             qkv_layout="bs3hd",
+             qkv_layout="bshd_bshd_bshd",
            bias_type="no_bias",
            mask_type=mask_type,
    ):
@@ -962,6 +935,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    layer_te = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
+                                    num_gqa_groups=num_gqa_groups,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
@@ -975,6 +949,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    layer_pd = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
+                                    num_gqa_groups=num_gqa_groups,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
@@ -1088,18 +1063,19 @@ def calc_transformer_output_and_grad(layer, encoder_input, mask, dout):


@pytest.mark.parametrize('bs', [1, 2, 8])
+ @pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
- @pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+ @pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
@pytest.mark.parametrize('recompute_core_attention', [True, False])
- def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
-                                    no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
-                                    output_layernorm, return_layernorm_output,
+ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
+                                    has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
+                                    math_dtype, output_layernorm, return_layernorm_output,
                                   recompute_core_attention):
    """
    Test Transformer Decoder Layer
@@ -1112,39 +1088,35 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
            num_heads=num_heads,
-             num_gqa_groups=num_heads,
+             num_gqa_groups=num_gqa_groups,
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
            head_size=hidden_size // num_heads,
            dtype=math_dtype,
            dropout=0.0,
-             qkv_layout="bs3hd",
-             bias_type="no_bias",
-             mask_type=mask_type,
-     ):
-         pytest.skip("cuDNN fused attention is not supported")
-     if not is_fused_attention_supported(
-             head_size=hidden_size // num_heads,
-             num_heads=num_heads,
-             num_gqa_groups=num_heads,
-             q_seqlen=q_seqlen,
-             kv_seqlen=kv_seqlen,
-             dtype=math_dtype,
-             dropout=0.0,
-             qkv_layout="bshd_bs2hd",
+             qkv_layout="bshd_bshd_bshd",
            bias_type="no_bias",
            mask_type=mask_type,
    ):
        pytest.skip("cuDNN fused attention is not supported")

-     encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
-     encoder_output = paddle.uniform(shape=(bs, kv_seqlen, hidden_size), dtype=math_dtype)
+     encoder_input = paddle.normal(mean=0.0, std=0.1,
+                                   shape=(bs, q_seqlen, hidden_size)).astype(math_dtype)
+     encoder_output = paddle.normal(mean=0.0, std=0.1,
+                                    shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype)

    q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
    kv_actual_seqlen = q_actual_seqlen
    attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')

-     grad_out = paddle.normal(mean=0.0, std=0.2, shape=(bs, q_seqlen, hidden_size)).astype('float32')
+     grad_out = paddle.normal(mean=0.0, std=0.01,
+                              shape=(bs, q_seqlen, hidden_size)).astype('float32')
+
+     # rounding to avoid numerical issues
+     encoder_input = paddle.round(encoder_input * 1000) / 1000
+     encoder_output = paddle.round(encoder_output * 1000) / 1000
+     grad_out = paddle.round(grad_out * 1000) / 1000
+

    for i in range(0, bs):
        grad_out[i, q_actual_seqlen[i]:, :] = 0
    grad_out = grad_out.astype(math_dtype)
@@ -1155,6 +1127,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    layer_te = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
+                                    num_gqa_groups=num_gqa_groups,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
@@ -1168,6 +1141,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    layer_pd = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
+                                    num_gqa_groups=num_gqa_groups,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
@@ -1319,7 +1293,7 @@ def calc_transformer_output_and_grad(layer,
        assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
                        layer_pd.self_attention.layernorm_qkv.weight.grad.T,
                        rtol=rtol,
-                         atol=0.1)
+                         atol=atol)
        assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad,
                        layer_pd.inter_attention.layernorm_query.weight.grad.T,
                        rtol=rtol,
@@ -1328,7 +1302,7 @@ def calc_transformer_output_and_grad(layer,
    if output_layernorm:
        assert_allclose(layer_te.self_attention.qkv.bias.grad,
                        layer_pd.self_attention.qkv.bias.grad,
-                         rtol=0.01,
+                         rtol=0.5,
                        atol=0.6)
    else:
        assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,