[C/PyTorch] Add max_t support for THD (#1244)
* WIP: add max_t support for THD

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: save tensors for debug and point to new FE

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix stats in bwd

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix stats in fwd

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add docstring for DPA

Signed-off-by: Charlene Yang <[email protected]>

* add docstring

Signed-off-by: Charlene Yang <[email protected]>

* WIP: first try on adding max_b and max_t

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit c3d522e9f5aef3c8ddfec5bf6ff24c3db97bb059.

Signed-off-by: Charlene Yang <[email protected]>

* Revert "WIP: first try on adding max_b and max_t"

This reverts commit 3bc01eb.

Signed-off-by: Charlene Yang <[email protected]>

* update docstring and fix max_seqlen logic for thd

Signed-off-by: Charlene Yang <[email protected]>

* revert two lines of change in docstring

Signed-off-by: Charlene Yang <[email protected]>

* WIP: add get_max_b/t

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix max_seqlen code and docstring

Signed-off-by: Charlene Yang <[email protected]>

* success: add max_b/max_t

Signed-off-by: Charlene Yang <[email protected]>

* remove debug code

Signed-off-by: Charlene Yang <[email protected]>

* change max_b/max_t buckets

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix b vs orig_b

Signed-off-by: Charlene Yang <[email protected]>

* fix b vs orig_b with 0 fill

Signed-off-by: Charlene Yang <[email protected]>

* update FE for T3HD/TH3D

Signed-off-by: Charlene Yang <[email protected]>

* add max_b to conversion kernels

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix changes after last merge

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add JAX support for max_t

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE to 1.8.0-rc

Signed-off-by: Charlene Yang <[email protected]>

* update FE to 1.8.0

Signed-off-by: Charlene Yang <[email protected]>

* code review/formatting fixes

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix Stats shape for <9.6

Signed-off-by: Charlene Yang <[email protected]>

* return nullptr for offset_stats when cudnn < 9.6

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add more version control

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] authored Oct 25, 2024
1 parent 83f9cc0 commit 7fb22c3
Showing 10 changed files with 584 additions and 248 deletions.
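For the THD layouts ("t3hd", "th3d", "thd_thd_thd", ...), batch and sequence are packed into a single leading token dimension, so the fused-attention kernels need the total (padded) token count, which is what max_t in the commit title refers to, in addition to b and max_seqlen. The hunks below read that count from shape[0] of the packed tensors (t, t_q, t_kv) whenever the QKV format is THD. A minimal host-side sketch of how the token count relates to the per-sequence lengths, using a hypothetical cu_seqlens_padded vector (the real code receives these as device tensors through the NVTE API):

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical illustration: for a THD-packed tensor, the leading dimension is
// the total number of padded tokens across all sequences in the batch, i.e. the
// last entry of the cumulative padded sequence-length array.
size_t total_tokens_thd(const std::vector<size_t> &cu_seqlens_padded) {
  // cu_seqlens_padded holds b + 1 entries: [0, s0, s0 + s1, ..., t]
  return cu_seqlens_padded.back();
}

int main() {
  // Three sequences of padded length 128 each, matching the updated
  // "layout_0_1" test config (b = 3, sq = skv = 128).
  std::vector<size_t> cu_seqlens_padded = {0, 128, 256, 384};
  size_t b = cu_seqlens_padded.size() - 1;         // 3
  size_t t = total_tokens_thd(cu_seqlens_padded);  // 384
  std::printf("b = %zu, t = %zu\n", b, t);
  // A t3hd-packed QKV tensor would then have shape [t, 3, h, d] (th3d: [t, h, 3, d]),
  // and shape[0] is exactly what the new code passes down to the kernels as t.
  return 0;
}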
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 146 files
2 changes: 1 addition & 1 deletion tests/pytorch/fused_attn/test_fused_attn.py
@@ -619,7 +619,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_1": ModelConfig(1, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
68 changes: 53 additions & 15 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -272,6 +272,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!");
}
size_t d = input_QKV->data.shape[ndim - 1];
+ size_t t = 0;
+ NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+ if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+ t = input_QKV->data.shape[0];
+ }

auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
@@ -292,7 +297,7 @@
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
- b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type,
+ b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O,
Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
stream, handle);
@@ -349,6 +354,11 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!");
}
size_t d = input_QKV->data.shape[ndim - 1];
+ size_t t = 0;
+ NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+ if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+ t = input_QKV->data.shape[0];
+ }

auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
@@ -377,7 +387,7 @@
input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
- b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+ b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO,
input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded,
input_rng_state, wkspace, stream, handle);
@@ -442,6 +452,13 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
} else {
NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!");
}
+ size_t t_q = 0;
+ size_t t_kv = 0;
+ NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+ if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+ t_q = input_Q->data.shape[0];
+ t_kv = input_KV->data.shape[0];
+ }

auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -463,9 +480,9 @@
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
- b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
- bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_KV,
- input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
+ b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout,
+ qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
+ input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
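In the kv-packed and separate-tensor THD paths above and below, Q and KV are packed independently, so each side carries its own token count: t_q from Q and t_kv from KV (or K). A short sketch with hypothetical shape vectors standing in for the NVTETensor shapes; the numbers are illustrative only:

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical THD shapes for a "thd_t2hd" style input:
  //   Q  packed as thd  -> [t_q, h_q, d]
  //   KV packed as t2hd -> [t_kv, 2, h_kv, d]   ("th2d" would be [t_kv, h_kv, 2, d])
  std::vector<size_t> q_shape = {384, 16, 64};
  std::vector<size_t> kv_shape = {512, 2, 4, 64};

  // Mirrors the new logic in the kv-packed entry points: when the QKV format
  // is THD, the total token counts are simply the leading dimensions.
  size_t t_q = q_shape[0];    // 384
  size_t t_kv = kv_shape[0];  // 512
  std::printf("t_q = %zu, t_kv = %zu\n", t_q, t_kv);
  return 0;
}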
@@ -526,6 +543,13 @@ void nvte_fused_attn_bwd_kvpacked(
} else {
NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!");
}
+ size_t t_q = 0;
+ size_t t_kv = 0;
+ NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+ if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+ t_q = input_Q->data.shape[0];
+ t_kv = input_KV->data.shape[0];
+ }

auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -556,9 +580,9 @@
input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
- b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
- attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_KV,
- input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
+ b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
+ bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q,
+ input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
@@ -616,6 +640,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim - 1];
+ size_t t_q = 0;
+ size_t t_kv = 0;
+ NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+ if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+ t_q = input_Q->data.shape[0];
+ t_kv = input_K->data.shape[0];
+ }

auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -637,9 +668,9 @@
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
- b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout,
- qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
- input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
+ b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale,
+ dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
+ input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
wkspace, stream, handle);
#else
@@ -696,6 +727,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim - 1];
+ size_t t_q = 0;
+ size_t t_kv = 0;
+ NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+ if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+ t_q = input_Q->data.shape[0];
+ t_kv = input_K->data.shape[0];
+ }

auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -726,10 +764,10 @@
input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd(
- b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout,
- bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q,
- input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV,
- output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+ b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
+ qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic,
+ input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK,
+ output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
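The new code paths are also gated on the cuDNN version: at compile time through the #if (CUDNN_VERSION >= 8900 / 8903) blocks visible above, and, per the commit message, at the 9.6 boundary for the softmax-stats ragged offset ("fix Stats shape for <9.6", "return nullptr for offset_stats when cudnn < 9.6"). A hypothetical sketch of that gating pattern; the helper name and exact placement are assumptions rather than the actual implementation:

#include <cudnn.h>

// Illustrative only: below cuDNN 9.6 the stats tensor is not addressed through
// a ragged offset, so callers get a nullptr; from 9.6 on the THD stats can use
// an offset derived from the cumulative sequence lengths.
void *offset_stats_or_null(void *offset_stats_dev_ptr) {
#if (CUDNN_VERSION >= 90600)
  if (cudnnGetVersion() >= 90600) {  // also guard the runtime library version
    return offset_stats_dev_ptr;
  }
#endif
  return nullptr;
}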
