[PyTorch] Add support for cuDNN FusedAttention + THD + CP #885

Merged: 30 commits, merged Jun 10, 2024

Commits (changes from all commits)
d9ec133  add seq_offsets_qkvo for cudnn thd (xrennvidia, May 31, 2024)
9211a7c  add seq_offsets_qkvo to AttnFuncWithCP (xrennvidia, May 31, 2024)
541268b  fix seq_offsets calculation of cudnn thd (xrennvidia, May 31, 2024)
eb3f66d  remove a thd assert (xrennvidia, May 31, 2024)
9ac207a  fix bias for thd test (xrennvidia, May 31, 2024)
4239df9  add thd test for cudnn FA with CP (xrennvidia, May 31, 2024)
9b309ab  skip GQA/MQA test for cuDNN THD (xrennvidia, May 31, 2024)
4891fac  make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1 (xrennvidia, May 31, 2024)
e886136  fix seq_offsets inputs (xrennvidia, Jun 1, 2024)
6ae7384  remove two comments (xrennvidia, Jun 1, 2024)
4a38798  fix attn mask type for cudnn thd with cp (xrennvidia, Jun 1, 2024)
a8807e6  fix attn_mask_type check (xrennvidia, Jun 1, 2024)
5a6a0a5  fix attn_mask_type for cudnn fa with thd (xrennvidia, Jun 1, 2024)
d657950  fix a typo (xrennvidia, Jun 2, 2024)
3b8780c  fix out dout in bwd (xrennvidia, Jun 3, 2024)
b3997ce  assert cudnn+thd does not support attn bias (xrennvidia, Jun 3, 2024)
e7b9ea7  check if attn_mask_type has padding (xrennvidia, Jun 3, 2024)
df2e164  Merge branch 'main' into xren/cp_thd (xrennvidia, Jun 3, 2024)
12cd072  minor change (xrennvidia, Jun 3, 2024)
cc3eb4f  change cp test batch size to 2 (xrennvidia, Jun 4, 2024)
442f347  fix code format (xrennvidia, Jun 4, 2024)
70042d2  fix two assert info (xrennvidia, Jun 5, 2024)
d37b399  fix assert comment (xrennvidia, Jun 5, 2024)
5302419  Merge branch 'main' into xren/cp_thd (cyanguwa, Jun 5, 2024)
a6c620d  fix assert comments (xrennvidia, Jun 5, 2024)
a98abd4  Merge branch 'xren/cp_thd' of github.com:xrennvidia/TransformerEngine… (xrennvidia, Jun 5, 2024)
dfe2f03  minor fix (xrennvidia, Jun 5, 2024)
a8cebd7  fix assert comments (xrennvidia, Jun 6, 2024)
f22bc19  Merge branch 'main' into xren/cp_thd (cyanguwa, Jun 6, 2024)
b13e3aa  Merge branch 'main' into xren/cp_thd (cyanguwa, Jun 7, 2024)
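
For context when reading the diffs below: THD is the packed, variable-length layout for fused attention, in which the tokens of all sequences in a batch are concatenated along a single dimension and delimited by cumulative sequence lengths (cu_seqlens), while CP (context parallelism) shards the sequence dimension across ranks. The sketch below is illustrative only; the concrete lengths and shapes are assumptions, not code from this PR.

import torch

# Illustrative THD packing: two sequences of lengths 3 and 5, 12 heads,
# head_dim 128, packed along one token axis of length t = 3 + 5 = 8.
seqlens = torch.tensor([3, 5], dtype=torch.int32)
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)   # tensor([0, 3, 8], dtype=torch.int32)
total_tokens = int(cu_seqlens[-1])

q = torch.randn(total_tokens, 12, 128)          # [t, h, d] -- the "thd" layout
# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of q, k, v and the output.
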
19 changes: 14 additions & 5 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -22,6 +22,8 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
if qkv_format == 'thd' and (config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"):
return

rank = int(os.getenv('RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
@@ -45,6 +47,12 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=

assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"

if kernel_backend == 'FusedAttention' and qkv_format == 'thd':
if 'causal' in config.attn_mask_type:
config.attn_mask_type = 'padding_causal'
else:
config.attn_mask_type = 'padding'

# instantiate core attn module
core_attn = DotProductAttention(config.num_heads,
config.head_dim,
@@ -112,24 +120,22 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
out.backward(dout)

# run core_attn wit CP
q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
bias_ = rest[0] if len(rest) else None
if qkv_format == "bshd" or qkv_format == "sbhd":
q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
bias_ = rest[0] if len(rest) else None
seq_dim = qkv_format.index('s')
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q_, k_, v_, dout_]]
seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
elif qkv_format == "thd":
q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
cu_seqlens_q = cu_seqlens_q // world_size
cu_seqlens_kv = cu_seqlens_kv // world_size
bias_ = None
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
@@ -158,7 +164,10 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
# compare results with and without CP
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == 'bf16':
tols = dict(atol=2.5e-2, rtol=2.5e-2)
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)

if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
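
The bshd/sbhd branch above implements the load-balanced context-parallel split for causal attention: the sequence axis is cut into 2*world_size chunks and rank r keeps chunks r and 2*world_size-1-r, so early and late chunks are paired and every rank does comparable work under a causal mask. The new thd branch applies the same idea per packed sequence via tex.thd_get_partitioned_indices and then divides cu_seqlens by world_size. Below is a minimal standalone sketch of the bshd/sbhd selection; the function name is mine, not the test's.

import torch

# Load-balanced CP split of the sequence axis (mirrors the view/index_select/view
# sequence in the diff above): cut into 2*world_size chunks, keep chunks
# [rank, 2*world_size-1-rank], then flatten the two kept chunks back together.
def split_for_cp(x: torch.Tensor, seq_dim: int, world_size: int, rank: int) -> torch.Tensor:
    chunks = 2 * world_size
    x = x.view(*x.shape[:seq_dim], chunks, x.shape[seq_dim] // chunks, *x.shape[seq_dim + 1:])
    seq_idx = torch.tensor([rank, chunks - rank - 1], device=x.device)
    x = x.index_select(seq_dim, seq_idx)
    return x.view(*x.shape[:seq_dim], -1, *x.shape[seq_dim + 2:])

# Example: for a bshd tensor q of shape [2, 4096, 12, 128] with world_size=2,
# split_for_cp(q, seq_dim=1, world_size=2, rank=0) keeps chunks 0 and 3 of the
# sequence axis, giving a [2, 2048, 12, 128] per-rank shard.
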
26 changes: 13 additions & 13 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -14,10 +14,10 @@

model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
}

def get_bash_arguments(**kwargs):
@@ -47,21 +47,21 @@ def test_cp_with_flash_attention(dtype, model, qkv_format):

model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
}

@pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd'])
def test_cp_with_fused_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
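
tex.thd_get_partitioned_indices performs the analogous token selection directly on the packed THD layout. The pure-PyTorch sketch below only illustrates the assumed semantics, inferred from the bshd/sbhd path (two chunks per rank, applied to each sequence independently); it is not the extension's implementation, and the helper name is hypothetical. Because each rank keeps 2 of the 2*world_size chunks of every sequence, the per-rank cumulative lengths are simply cu_seqlens // world_size, which is what the thd branch of the test uses.

import torch

# Hypothetical, conceptual equivalent of the partitioned-indices selection for a
# packed THD tensor: for every sequence [start, end), keep the tokens of chunks
# `rank` and `2*world_size-1-rank`. Assumes each sequence length is divisible by
# 2*world_size, as the CP tests require.
def thd_partition_indices(cu_seqlens: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    indices = []
    chunks = 2 * world_size
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        chunk_len = (end - start) // chunks
        for c in (rank, chunks - rank - 1):
            lo = start + c * chunk_len
            indices.append(torch.arange(lo, lo + chunk_len))
    return torch.cat(indices)

# q_local = q.index_select(0, thd_partition_indices(cu_seqlens, world_size, rank))
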