From 7fb22c375804f77f4f95df3eab606c7bd3e80aed Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Fri, 25 Oct 2024 13:29:56 -0700
Subject: [PATCH] [C/PyTorch] Add max_t support for THD (#1244)

* WIP: add max_t support for THD

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: save tensors for debug and point to new FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix stats in bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix stats in fwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add docstring for DPA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: first try on adding max_b and max_t

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit c3d522e9f5aef3c8ddfec5bf6ff24c3db97bb059.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "WIP: first try on adding max_b and max_t"

This reverts commit 3bc01ebaf2aa846fd16634e2d33b0d0f5803a076.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update docstring and fix max_seqlen logic for thd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert two lines of change in docstring

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add get_max_b/t

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix max_seqlen code and docstring

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* success: add max_b/max_t

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove debug code

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change max_b/max_t buckets

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix b vs orig_b

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix b vs orig_b with 0 fill

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE for T3HD/TH3D

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add max_b to conversion kernels

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix changes after last merge

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add Jax support for max_t

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE to 1.8.0-rc

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 1.8.0

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* code review/formatting fixes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix Stats shape for <9.6

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* return nullptr for offset_stats when cudnn < 9.6

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add more version control

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 3rdparty/cudnn-frontend                       |   2 +-
 tests/pytorch/fused_attn/test_fused_attn.py   |   2 +-
 .../common/fused_attn/fused_attn.cpp          |  68 ++-
 .../fused_attn_f16_arbitrary_seqlen.cu        | 497 ++++++++++++------
 .../fused_attn_f16_arbitrary_seqlen.h         |  63 +--
 .../fused_attn_f16_max512_seqlen.cu           |   4 +-
 transformer_engine/common/fused_attn/utils.cu |  98 +++-
 transformer_engine/common/fused_attn/utils.h  |  13 +-
 .../jax/cpp_extensions/attention.py           |  11 +-
transformer_engine/pytorch/attention.py | 74 ++- 10 files changed, 584 insertions(+), 248 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 2533f5e5c1..936021bfed 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b +Subproject commit 936021bfed8c91dc416af1588b2c4eca631a9e45 diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index a71827c6f4..0fd1e2590c 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -619,7 +619,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_1": ModelConfig(1, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), + "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), "layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 2654273356..4ea0ea5741 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -272,6 +272,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); } size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); @@ -292,7 +297,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type, + b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); @@ -349,6 +354,11 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); } size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); @@ -377,7 +387,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd_qkvpacked( - b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + b, h, max_seqlen, d, t, attn_scale, dropout, 
qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); @@ -442,6 +452,13 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const } else { NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_KV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -463,9 +480,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_KV, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, + input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -526,6 +543,13 @@ void nvte_fused_attn_bwd_kvpacked( } else { NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_KV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -556,9 +580,9 @@ void nvte_fused_attn_bwd_kvpacked( input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_KV, - input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, + input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -616,6 +640,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t h_kv = input_K->data.shape[ndim - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_K->data.shape[0]; + } auto handle = 
cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -637,9 +668,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, - input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, + input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -696,6 +727,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t h_kv = input_K->data.shape[ndim - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_K->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -726,10 +764,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, - input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, - output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, + input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, + output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 176ec50cd0..1a555a4999 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -49,14 +49,14 @@ namespace transformer_engine { namespace fused_attn { void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, - void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, - void *devPtrDropoutSeed, 
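The fused_attn.cpp hunks above read the total token count t (or t_q/t_kv) from the leading dimension of THD-format inputs and thread it through to the arbitrary-seqlen backend, leaving it at 0 for dense formats. A minimal sketch of that shape inspection, using illustrative stand-in types rather than the real NVTETensor structs:

#include <cstddef>
#include <vector>

// Illustrative stand-ins; the patch uses NVTE_QKV_Format and Tensor::data.shape.
enum class QkvFormat { BSHD, SBHD, THD };
using Shape = std::vector<std::size_t>;

// For THD ("ragged") inputs the leading dimension is the total number of
// tokens across the batch; for dense formats the wrappers keep passing 0,
// matching the hunks above.
std::size_t thd_total_tokens(QkvFormat format, const Shape &shape) {
  return (format == QkvFormat::THD && !shape.empty()) ? shape[0] : 0;
}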
void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, - cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, + void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, + void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, + size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -73,10 +73,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); - if (is_ragged) { + const auto cudnn_runtime_version = cudnnGetVersion(); + + // keep original batch size because cu_seqlens are created with [b+1] shape + int64_t actual_b = b; + if (is_ragged && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + // replace batch size and maximum sequence lengths with maximum token counts + // for query and key/value so the graph is static within each quantization bucket + b = max_b; + s_q = max_t_q; + s_kv = max_t_kv; } - const auto cudnn_runtime_version = cudnnGetVersion(); const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? 
DType::kInt64 : DType::kInt32; try { @@ -117,6 +125,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // offset_k std::shared_ptr, // offset_v std::shared_ptr, // offset_o + std::shared_ptr, // offset_stats std::shared_ptr, // dropout_seed std::shared_ptr>; // dropout_offset @@ -140,30 +149,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr bias, seq_q, seq_kv; - std::shared_ptr offset_q, offset_k, offset_v, offset_o; + std::shared_ptr offset_q, offset_k, offset_v, offset_o, + offset_stats; std::shared_ptr dropout_seed, dropout_offset; - offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -175,6 +164,21 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_QKV_Matrix::NVTE_V_Matrix); if (is_ragged) { + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -268,6 +272,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); O->set_output(true) .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) @@ -276,10 +285,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); } - Stats->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); + if (is_ragged && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, 1, h, 1}) + .set_ragged_offset(offset_stats); + } else { + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, 
s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}); + } std::tuple, // Q std::shared_ptr, // K @@ -291,8 +314,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) + ? std::make_tuple(offset_stats) + : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -302,15 +328,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, - bias_tuple, padding_tuple, offset_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + padding_tuple, offset_qkvo_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, offset_q, offset_k, - offset_v, offset_o, dropout_seed, dropout_offset] = + offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed @@ -318,10 +345,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl( // We do this by adding padding at the end of each separate allocation. auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size()); const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); - const size_t actual_seqlen_workspace_size = 2 * num_bytes_per_seqlen; + const size_t actual_seqlen_workspace_size = is_padding ? 
2 * num_bytes_per_seqlen : 0; const size_t num_bytes_per_ragged_offset = alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); - const size_t seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + size_t seqlen_offsets_workspace_size = 0; + if (is_ragged) { + if (cudnn_runtime_version >= 90600) { + seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset; + } else { + seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + } + } if (workspace == nullptr) { *workspace_size = plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size; @@ -348,7 +382,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + num_bytes_per_seqlen; cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlensQ), + actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); variant_pack[seq_q] = devActualSeqlenQ; @@ -363,15 +397,22 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; void *devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; void *devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + void *devOffsetsS = nullptr; + if (cudnn_runtime_version >= 90600) { + devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; + } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), + layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, - devOffsetsV, devOffsetsO); + devOffsetsV, devOffsetsO, devOffsetsS); variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_k] = devOffsetsK; variant_pack[offset_v] = devOffsetsV; variant_pack[offset_o] = devOffsetsO; + if (cudnn_runtime_version >= 90600) { + variant_pack[offset_stats] = devOffsetsS; + } } if (is_dropout) { @@ -386,12 +427,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, - void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, - void *devPtrBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, - void *devPtrdBias, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, + void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, + void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, + void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, 
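With ragged (THD) input on cuDNN 9.6 or newer, the forward workspace above reserves five ragged-offset buffers (Q, K, V, O, plus the new softmax-stats offset) instead of four, and the actual-seqlen scratch is reserved only when a padding mask is in use. The sketch below shows that partitioning with hypothetical names standing in for the local pointers in the hunk; it is an illustration of the layout, not the implementation itself.

#include <cstddef>

// Mirrors alignTo<16> from the patch.
constexpr std::size_t align16(std::size_t n) { return (n + 15) / 16 * 16; }

struct FwdWorkspaceLayout {
  void *actual_seqlen_q, *actual_seqlen_kv;         // only when padding mask is used
  void *offset_q, *offset_k, *offset_v, *offset_o;  // ragged offsets
  void *offset_stats;                               // fifth buffer, cuDNN >= 9.6 only
};

FwdWorkspaceLayout carve_workspace(char *ws, std::size_t plan_bytes,
                                   std::size_t seqlen_bytes,
                                   std::size_t offset_bytes, bool is_padding,
                                   bool has_stats_offset) {
  FwdWorkspaceLayout l{};
  // Layout: [plan workspace][actual seqlens (if padding)][ragged offsets Q,K,V,O(,S)]
  char *p = ws + align16(plan_bytes);
  if (is_padding) {
    l.actual_seqlen_q = p;
    l.actual_seqlen_kv = p + seqlen_bytes;
    p += 2 * seqlen_bytes;
  }
  l.offset_q = p;
  l.offset_k = p + offset_bytes;
  l.offset_v = p + 2 * offset_bytes;
  l.offset_o = p + 3 * offset_bytes;
  l.offset_stats = has_stats_offset ? p + 4 * offset_bytes : nullptr;
  return l;
}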
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -414,6 +456,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); + // keep original batch size because cu_seqlens are created with [b+1] shape + int64_t actual_b = b; + if (is_ragged && cudnn_runtime_version >= 90600) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + // replace batch size and maximum sequence lengths with maximum token counts + // for query and key/value so the graph is static within each quantization bucket + b = max_b; + s_q = max_t_q; + s_kv = max_t_kv; + } // We choose between 32-bit and 64-bit offsets depending on need. // This allows us to support older cuDNN runtimes gracefully. @@ -462,6 +514,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr, // offset_k std::shared_ptr, // offset_v std::shared_ptr, // offset_o + std::shared_ptr, // offset_stats std::shared_ptr, // dropout_seed std::shared_ptr>; // dropout_offset @@ -485,29 +538,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr q, k, v, o, dO, stats, attn_scale; std::shared_ptr bias, dBias, seq_q, seq_kv; - std::shared_ptr offset_q, offset_k, offset_v, offset_o; + std::shared_ptr offset_q, offset_k, offset_v, offset_o, + offset_stats; std::shared_ptr dropout_seed, dropout_offset; - offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -522,6 +556,26 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -569,11 +623,26 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_dim({b, h, s_q, d_v}) .set_stride(o_stride)); } - stats = 
mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + if (is_ragged && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, 1, h, 1}) + .set_data_type(fe::DataType_t::FLOAT) + .set_ragged_offset(offset_stats)); + } else { + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + } attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") @@ -589,6 +658,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + if (is_ragged && cudnn_runtime_version >= 90600) { + sdpa_backward_options.set_max_total_seq_len_q(s_q); + } + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_backward_options.set_sliding_window_length(window_size_left + 1); } @@ -682,8 +755,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) + ? std::make_tuple(offset_stats) + : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -693,15 +769,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, offset_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple, + offset_qkvo_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv, - offset_q, offset_k, offset_v, offset_o, dropout_seed, dropout_offset] = + offset_q, offset_k, offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed @@ -709,10 +786,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( // We do this by adding padding at the end of each separate allocation. auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size()); const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); - const size_t actual_seqlen_workspace_size = 2 * num_bytes_per_seqlen; + const size_t actual_seqlen_workspace_size = is_padding ? 
2 * num_bytes_per_seqlen : 0; const size_t num_bytes_per_ragged_offset = alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); - const size_t seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + size_t seqlen_offsets_workspace_size = 0; + if (is_ragged) { + if (cudnn_runtime_version >= 90600) { + seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset; + } else { + seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + } + } if (workspace == nullptr) { *workspace_size = plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size; @@ -752,7 +836,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + num_bytes_per_seqlen; cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlensQ), + actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); variant_pack[seq_q] = devActualSeqlenQ; @@ -767,15 +851,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void *devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; void *devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; void *devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + void *devOffsetsS = nullptr; + if (cudnn_runtime_version >= 90600) { + devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; + } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), + layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, - devOffsetsV, devOffsetsO); + devOffsetsV, devOffsetsO, devOffsetsS); variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_k] = devOffsetsK; variant_pack[offset_v] = devOffsetsV; variant_pack[offset_o] = devOffsetsO; + if (cudnn_runtime_version >= 90600) { + variant_pack[offset_stats] = devOffsetsS; + } } if (is_dropout) { @@ -792,10 +883,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -803,6 +894,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( const auto QKV_type = input_QKV->data.dtype; void *devPtrQKV = input_QKV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); size_t stride = 0; if 
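The conversion kernels invoked above, cu_seqlens_to_actual_seqlens and cu_seqlens_padded_to_offsets, now receive both the real batch size (actual_b) and the bucketed one (b = max_b); their implementations live in utils.cu, whose hunks are not shown here. A plausible host-side reading of that contract is sketched below under the assumption that the extra bucket slots are zero-filled so they stay inert under the padding mask (the "fix b vs orig_b with 0 fill" commit); the actual kernel may handle the tail differently.

#include <cstdint>
#include <vector>

// Hypothetical host-side equivalent of cu_seqlens_to_actual_seqlens for one of
// the two (Q or KV) buffers; the real implementation is a CUDA kernel in utils.cu.
std::vector<int32_t> actual_seqlens(int64_t actual_b, int64_t max_b,
                                    const std::vector<int32_t> &cu_seqlens) {
  std::vector<int32_t> lens(max_b, 0);  // bucket slots beyond actual_b stay 0
  for (int64_t i = 0; i < actual_b; ++i) {
    lens[i] = cu_seqlens[i + 1] - cu_seqlens[i];
  }
  return lens;
}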
(layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { stride = typeToSize(QKV_type) * num_attn_heads * head_dim; @@ -821,17 +913,30 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; + size_t max_batch_size = 0; + size_t max_tokens = 0; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens = get_max_tokens(num_tokens); + } + if (Aux_CTX_Tensors->size == 0) { + const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -845,7 +950,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -875,12 +984,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, - bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, - devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, - handle); + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, + max_batch_size, max_tokens, max_tokens, bias_b, bias_h, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, + devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -898,10 +1007,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( } void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - const Tensor 
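The qkvpacked wrapper above is the first of several call sites that round the runtime batch size and token count up to bucket values via get_max_batch_size(batch) and get_max_tokens(num_tokens) (declared in utils.h, with hunks not shown here), so that the cuDNN graph stays static within each bucket and the softmax-stats aux tensor can be allocated as {max_tokens, num_attn_heads, 1}. The helpers below are only an assumed shape of such bucketing, rounding up to the next power of two above a floor; the real bucket scheme lives in utils.cu.

#include <cstddef>

// Assumed bucketing helpers; the actual get_max_batch_size/get_max_tokens in
// utils.cu may use different bucket boundaries.
std::size_t round_up_pow2(std::size_t n, std::size_t floor_value) {
  std::size_t bucket = floor_value;
  while (bucket < n) bucket *= 2;
  return bucket;
}

std::size_t assumed_max_batch_size(std::size_t batch) {
  return round_up_pow2(batch, /*floor_value=*/16);
}

std::size_t assumed_max_tokens(std::size_t num_tokens) {
  return round_up_pow2(num_tokens, /*floor_value=*/1024);
}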
*input_QKV, const Tensor *input_O, const Tensor *input_dO, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -909,7 +1018,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( const auto QKV_type = input_QKV->data.dtype; void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { @@ -934,6 +1042,14 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( bias_h = output_dBias->data.shape[1]; } + size_t max_batch_size = 0; + size_t max_tokens = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens = get_max_tokens(num_tokens); + } + void *devPtrdQKV = output_dQKV->data.dptr; void *devPtrdQ = devPtrdQKV; void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); @@ -952,12 +1068,13 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, - bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, - devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, + max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, + bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK, + devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, + devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -975,19 +1092,21 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( } void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, 
cudnnHandle_t handle) { + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; void *devPtrKV = input_KV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; @@ -1005,6 +1124,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; @@ -1013,12 +1133,26 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + if (Aux_CTX_Tensors->size == 0) { + const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1032,7 +1166,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1063,11 +1201,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, 
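The same softmax-stats shape selection recurs in each forward wrapper: THD input on cuDNN 9.6+ stores the stats per token as {max_tokens_q, num_attn_heads, 1}, while every other case keeps the dense {batch, num_attn_heads, max_seqlen_q, 1}. A small helper capturing that choice is sketched below for illustration only; the patch itself repeats the if/else inline at each call site.

#include <cstddef>
#include <vector>

enum class QkvFormat { BSHD, SBHD, THD };  // stand-in for NVTE_QKV_Format

std::vector<std::size_t> softmax_stats_shape(QkvFormat format,
                                             std::size_t cudnn_version,
                                             std::size_t batch, std::size_t heads,
                                             std::size_t max_seqlen_q,
                                             std::size_t max_tokens_q) {
  if (format == QkvFormat::THD && cudnn_version >= 90600) {
    return {max_tokens_q, heads, 1};       // ragged, token-major stats
  }
  return {batch, heads, max_seqlen_q, 1};  // dense stats
}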
devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, + devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1086,12 +1224,13 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1122,6 +1261,16 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( bias_h = output_dBias->data.shape[1]; } + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + void *devPtrdQ = output_dQ->data.dptr; void *devPtrdKV = output_dKV->data.dptr; void *devPtrdK = devPtrdKV; @@ -1143,12 +1292,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, - devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, + devPtrK, 
devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, + devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1167,8 +1316,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -1177,6 +1327,7 @@ void fused_attn_arbitrary_seqlen_fwd( using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); void *devPtrQ = input_Q->data.dptr; void *devPtrK = input_K->data.dptr; void *devPtrV = input_V->data.dptr; @@ -1196,12 +1347,26 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + if (Aux_CTX_Tensors->size == 0) { + const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1215,7 +1380,11 @@ void fused_attn_arbitrary_seqlen_fwd( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1246,11 +1415,11 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, 
head_dim_qk, head_dim_v, - bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, + devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1269,13 +1438,13 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, - Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1296,6 +1465,16 @@ void fused_attn_arbitrary_seqlen_bwd( bias_h = output_dBias->data.shape[1]; } + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + void *devPtrdQ = output_dQ->data.dptr; void *devPtrdK = output_dK->data.dptr; void *devPtrdV = output_dV->data.dptr; @@ -1315,12 +1494,12 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, - devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, - devPtrDropoutOffset, 
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, + devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, + devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 4b523cca1a..3a1216f891 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -19,47 +19,50 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, 
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -68,13 +71,13 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const 
Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, - Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 88c1490c01..d3422de481 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -746,7 +746,7 @@ void fused_attn_max_512_fwd_impl( void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlenQ), + b, b, static_cast(devPtrCuSeqlenQ), static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -1169,7 +1169,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlenQ), + b, b, static_cast(devPtrCuSeqlenQ), static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); NVTE_CHECK_CUDA(cudaGetLastError()); diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 7f76dcad77..ca00218d9a 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include +#include #include "../common.h" #include "transformer_engine/fused_attn.h" @@ -353,66 +354,75 @@ __global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t * } // convert cu_seqlens to actual_seqlens -__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens, +__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, + int32_t const *const q_cu_seqlens, int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, int32_t *kv_seqlens) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < b) { + if (tid < actual_b) { q_seqlens[tid] = q_cu_seqlens[tid + 1] - q_cu_seqlens[tid]; kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid]; + } else if (tid < max_b) { + q_seqlens[tid] = 0; + kv_seqlens[tid] = 0; } } // convert cu_seqlens_padded to offsets template -__device__ void cu_seqlens_padded_to_offsets_impl(NVTE_QKV_Layout_Group layout_group, 
int64_t b, - int64_t h, int64_t hg, int64_t d_qk, int64_t d_v, - const int32_t *cu_seqlens_q_padded, - const int32_t *cu_seqlens_kv_padded, - OFFSETS_T *offsets_q, OFFSETS_T *offsets_k, - OFFSETS_T *offsets_v, OFFSETS_T *offsets_o) { +__device__ void cu_seqlens_padded_to_offsets_impl( + NVTE_QKV_Layout_Group layout_group, int64_t actual_b, int64_t max_b, int64_t h, int64_t hg, + int64_t d_qk, int64_t d_v, const int32_t *cu_seqlens_q_padded, + const int32_t *cu_seqlens_kv_padded, OFFSETS_T *offsets_q, OFFSETS_T *offsets_k, + OFFSETS_T *offsets_v, OFFSETS_T *offsets_o, OFFSETS_T *offsets_s) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < b + 1) { - offsets_o[tid] = h * d_v * cu_seqlens_q_padded[tid]; + auto cu_seqlens_id = min(tid, actual_b); + if (tid <= max_b) { + offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id]; + if (offsets_s != nullptr) { + offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id]; + } switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; - offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[tid]; - offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; + offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id]; break; case NVTE_QKV_Layout_Group::NVTE_3HD: case NVTE_QKV_Layout_Group::NVTE_H3D: - offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[tid]; - offsets_k[tid] = offsets_q[tid]; - offsets_v[tid] = offsets_q[tid]; + offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + offsets_k[tid] = offsets_q[tid]; + offsets_v[tid] = offsets_q[tid]; break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; - offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[tid]; - offsets_v[tid] = offsets_k[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; + offsets_v[tid] = offsets_k[tid]; break; } } } -__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t b, - int64_t h, int64_t hg, int64_t d_qk, int64_t d_v, - const int32_t *cu_seqlens_q_padded, +__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b, + int64_t max_b, int64_t h, int64_t hg, int64_t d_qk, + int64_t d_v, const int32_t *cu_seqlens_q_padded, const int32_t *cu_seqlens_kv_padded, DType offset_dtype, void *offsets_q, void *offsets_k, - void *offsets_v, void *offsets_o) { + void *offsets_v, void *offsets_o, void *offsets_s) { if (offset_dtype == DType::kInt32) { cu_seqlens_padded_to_offsets_impl( - layout_group, b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded, + layout_group, actual_b, max_b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded, reinterpret_cast(offsets_q), reinterpret_cast(offsets_k), - reinterpret_cast(offsets_v), reinterpret_cast(offsets_o)); + reinterpret_cast(offsets_v), reinterpret_cast(offsets_o), + reinterpret_cast(offsets_s)); } else { assert(offset_dtype == DType::kInt64 && "expect int64"); cu_seqlens_padded_to_offsets_impl( - layout_group, b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded, + layout_group, actual_b, max_b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded, reinterpret_cast(offsets_q),
reinterpret_cast(offsets_k), - reinterpret_cast(offsets_v), reinterpret_cast(offsets_o)); + reinterpret_cast(offsets_v), reinterpret_cast(offsets_o), + reinterpret_cast(offsets_s)); } } @@ -450,6 +460,40 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at return DType::kInt32; } +// quantize batch size +size_t get_max_batch_size(size_t batch_size) { + size_t max_b = batch_size; + size_t log2_b = ceil(log2(batch_size)); + // batch size is expected to be 10s-100s + // b = 1, ..., 32 -> max_b = 32 + // b = 33, ..., 512 -> max_b = next power of 2 + // otherwise -> max_b = b + if (log2_b <= 5) { + max_b = 32; + } else if (log2_b <= 9) { + max_b = pow(2, log2_b); + } + return max_b; +} + +// quantize token count +size_t get_max_tokens(size_t num_tokens) { + // token count is expected to be 1k's-100k's + // t = 0, ..., 1024 -> max_t = 1024 + // t = 1025, ..., 32k -> max_t = next power of 2 + // t = 32k+1, ... -> max_t = increment by 32k + size_t log2_t = ceil(log2(num_tokens)); + size_t max_t = 0; + if (log2_t <= 10) { + max_t = 1024; + } else if (log2_t <= 15) { + max_t = pow(2, log2_t); + } else { + max_t = (num_tokens + 32767) / 32768 * 32768; + } + return max_t; +} + } // namespace fused_attn // get cuDNN data type diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index bea7ed05dd..c060c4907d 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -122,21 +122,24 @@ __global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t * int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset, int32_t *o_ragged_offset); -__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens, +__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, + int32_t const *const q_cu_seqlens, int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, int32_t *kv_seqlens); -__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t b, - int64_t h, int64_t hg, int64_t d_qk, int64_t d_v, - const int32_t *cu_seqlens_q_padded, +__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b, + int64_t max_b, int64_t h, int64_t hg, int64_t d_qk, + int64_t d_v, const int32_t *cu_seqlens_q_padded, const int32_t *cu_seqlens_kv_padded, DType offset_dtype, void *offsets_q, void *offsets_k, - void *offsets_v, void *offsets_o); + void *offsets_v, void *offsets_o, void *offsets_s); DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_attn_heads, int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv, int64_t head_dim_qk, int64_t head_dim_v); +size_t get_max_batch_size(size_t batch_size); +size_t get_max_tokens(size_t num_tokens); } // namespace fused_attn cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index b236e19e57..8a1e0e2ad7 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -277,7 +277,16 @@ def abstract( softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) softmax_dtype = q_dtype elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq) + # cuDNN 9.6 reduces the required softmax shape + if 
get_cudnn_version() >= (9, 6, 0): + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) + else: + softmax_shape = ( + *batch_shape, + attn_heads, + q_max_seqlen, + config.max_segments_per_seq, + ) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 5f8357a01b..2567542b87 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7671,6 +7671,60 @@ def forward( based on its internal logic. These optimizations trade memory for performance and should be used with care. + .. note:: + .. _cu_seqlens note: + + When training data has variable sequence lengths, users have two options. + + 1. Manipulate the data and pad all sequences to the same length. Use + :attr:`qkv_format` = {"bshd", "sbhd"} and + :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}. + Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask` + (which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide + the real sequence length information. For example, a batch of 3 sequences + [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative + sequence length tensors would be + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention. + + 2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and + :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}. + Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`, + as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed + without any padding, and the sequence length tensors would be + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention. + + In certain use cases, a varying number of identifier tokens are inserted between + sequences. These tokens do not participate in the attention calculation. + :attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified + in such cases to correctly identify the start and end of each sequence in a batch. + For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and + :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13] + for self-attention. + + .. note:: + .. _max_seqlen note: + + When :attr:`qkv_format` = {"bshd", "sbhd"}, sequences are of equal length in a batch. + :attr:`max_seqlen_q` and :attr:`max_seqlen_kv` should be the same as the "s" dimension of + :attr:`query_layer` and :attr:`key_layer` tensors. When unset, Transformer Engine will + infer them as such. + + When :attr:`qkv_format` = "thd", sequences have varying lengths. :attr:`max_seqlen_q` and + :attr:`max_seqlen_kv` should be the maximum query and key/value sequence length in a batch. + When unset, Transformer Engine deduces them from :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`. + This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this + overhead, users are recommended to obtain the maximum sequence lengths from the data loaders + and pass them in. + + - As the maximum sequence lengths, batch size, and number of tokens change from batch to batch, + dynamic shapes need to be supported for tensor construction. 
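As a minimal sketch of the bookkeeping described in the two notes above, assuming the example batch [a a a 1 b b 2 2 c c c c 3] from the cu_seqlens note and per-sequence lengths that are already known on the host (the 64-element alignment Transformer Engine applies when it deduces the values itself is omitted here):

    import torch

    # Real tokens per sequence, and lengths including the inserted identifier tokens.
    seqlens = torch.tensor([3, 2, 4], dtype=torch.int32)
    padded_lens = torch.tensor([4, 4, 5], dtype=torch.int32)

    zero = torch.zeros(1, dtype=torch.int32)
    cu_seqlens_q = torch.cat([zero, torch.cumsum(seqlens, 0, dtype=torch.int32)])             # [0, 3, 5, 9]
    cu_seqlens_q_padded = torch.cat([zero, torch.cumsum(padded_lens, 0, dtype=torch.int32)])  # [0, 4, 8, 13]
    cu_seqlens_kv, cu_seqlens_kv_padded = cu_seqlens_q, cu_seqlens_q_padded                   # self-attention

    # Passing the maxima explicitly avoids the deduction kernel and the CPU-GPU
    # synchronization mentioned in the max_seqlen note.
    max_seqlen_q = max_seqlen_kv = int(padded_lens.max())  # 5

These are the tensors that option 2 above would pass to the forward call together with qkv_format = "thd".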
FlashAttention and + UnfusedDotProductAttention support dynamic shapes natively, while FusedAttention requires these parameters + to be static in order to create its graphs before the performance heuristics analysis. To reduce the number of graphs created + per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size, + :attr:`max_seqlen_q`, :attr:`max_seqlen_kv`}, and for cuDNN >= 9.6, {"t" dimension of + :attr:`query_layer`, "t" dimension of :attr:`key_layer`}. + Parameters ---------- query_layer : torch.Tensor @@ -7693,25 +7747,29 @@ def forward( cu_seqlens_q: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, with shape [batch_size + 1] and dtype torch.int32. + See :ref:`note<cu_seqlens note>` for more details. cu_seqlens_kv: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + See :ref:`note<cu_seqlens note>` for more details. cu_seqlens_q_padded: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`, with shape [batch_size + 1] and dtype torch.int32. When there is no padding between sequences in a batch, `cu_seqlens_q_padded = cu_seqlens_q`. + See :ref:`note<cu_seqlens note>` for more details. cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (with offset) in a batch for `key_layer` and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. When there is no padding between sequences in a batch, `cu_seqlens_kv_padded = cu_seqlens_kv`. + See :ref:`note<cu_seqlens note>` for more details. max_seqlen_q: Optional[int], default = `None` Maximum sequence length in `query_layer`. - Calculated from `cu_seqlens_q` if not provided. + See :ref:`note<max_seqlen note>` for more details. max_seqlen_kv: Optional[int], default = `None` Maximum sequence length in `key_layer` and `value_layer`. - Calculated from `cu_seqlens_kv` if not provided. + See :ref:`note<max_seqlen note>` for more details. attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = `None`. Type of attention mask passed into @@ -7902,6 +7960,7 @@ def forward( assert ( cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 ), "cu_seqlens_q and cu_seqlens_kv must both be in dtype torch.int32!" + batch_size = len(cu_seqlens_q) - 1 if max_seqlen_q is None: if cu_seqlens_q_padded is not None: seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] @@ -7914,7 +7973,6 @@ def forward( else: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) - batch_size = len(cu_seqlens_q) - 1 cp_size = 1 if isinstance(self.cp_group, dist_group_type): @@ -7929,10 +7987,12 @@ def forward( len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
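As a minimal sketch of the bucketing behind the quantization note above, mirroring the get_max_batch_size/get_max_tokens helpers added to utils.cu in this patch (the Python below is a simplified restatement, not the library code, and assumes positive inputs):

    import math

    def get_max_batch_size(batch_size: int) -> int:
        # b = 1..32 -> 32; b = 33..512 -> rounded up to a power of 2; larger batches are left as-is.
        log2_b = math.ceil(math.log2(batch_size))
        if log2_b <= 5:
            return 32
        if log2_b <= 9:
            return 2 ** log2_b
        return batch_size

    def get_max_tokens(num_tokens: int) -> int:
        # t = 1..1024 -> 1024; t = 1025..32768 -> rounded up to a power of 2;
        # beyond that, rounded up to a multiple of 32768.
        log2_t = math.ceil(math.log2(num_tokens))
        if log2_t <= 10:
            return 1024
        if log2_t <= 15:
            return 2 ** log2_t
        return (num_tokens + 32767) // 32768 * 32768

    assert get_max_batch_size(48) == 64
    assert get_max_tokens(40_000) == 65_536

All batches that land in the same {max_b, max_t} bucket can reuse one cuDNN graph instead of triggering a new graph build per shape, which is the point of the quantization described above.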
if qkv_format == "sbhd": - max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0]) + max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv batch_size = query_layer.shape[1] else: - max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1]) + max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv batch_size = query_layer.shape[0] max_seqlen_q *= cp_size max_seqlen_kv *= cp_size @@ -7941,13 +8001,13 @@ def forward( assert all( seqlens_q <= max_seqlen_q ), """Sequence lengths indicated by cu_seqlens_q must be no greater than - the sequence dimention in 'query_layer'!""" + the sequence dimension in 'query_layer'!""" if cu_seqlens_kv is not None: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] assert all( seqlens_kv <= max_seqlen_kv ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than - the sequence dimention in 'key_layer' and 'value_layer'!""" + the sequence dimension in 'key_layer' and 'value_layer'!""" if cu_seqlens_q is None or cu_seqlens_kv is None: if "padding" in attn_mask_type: assert (