Add THD format support for Context Parallel
kunlunl committed Apr 26, 2024
1 parent 07bf4ac commit 7aea011
Showing 6 changed files with 872 additions and 70 deletions.
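
Background for the diff below (not part of the commit): in Transformer Engine's attention API, qkv_format="thd" means variable-length sequences are packed token-contiguously into a single tensor of shape [total_tokens, num_heads, head_dim], with sequence boundaries carried separately in a cu_seqlens tensor of cumulative lengths. A minimal, illustrative sketch; the names and sizes here are made up for the example:

import torch

# Illustrative THD packing (not from the commit): two sequences of 5 and 3
# tokens are concatenated along the token dimension; there is no batch or
# padding dimension, so boundaries must be carried in cu_seqlens.
num_heads, head_dim = 2, 4
seq_a = torch.randn(5, num_heads, head_dim)
seq_b = torch.randn(3, num_heads, head_dim)
packed = torch.cat([seq_a, seq_b], dim=0)                 # shape [8, num_heads, head_dim]
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32)   # cumulative sequence lengths
# tokens of sequence i live in packed[cu_seqlens[i]:cu_seqlens[i+1]]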
79 changes: 65 additions & 14 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -6,6 +6,7 @@
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
import transformer_engine_extensions as tex
from test_fused_attn_with_cp import model_configs

dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
@@ -58,12 +59,27 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
        q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
        kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim)
        attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim)
        cu_seqlens_q = None
        cu_seqlens_kv = None
    elif qkv_format == "sbhd":
        q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
        kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim)
        attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
        cu_seqlens_q = None
        cu_seqlens_kv = None
    elif qkv_format == "thd":
        seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
        seqlens_q = seqlens_q - seqlens_q % (world_size * 2)
        cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)])
        cu_seqlens_kv = cu_seqlens_q
        q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim)
        kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim)
        attn_output_shape = (cu_seqlens_q[-1], config.num_heads*config.head_dim)
        cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda()
        cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda()
    else:
        assert False, f"{qkv_format} is an unsupported qkv_format!"

    q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
    k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
    v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
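
Note on the thd branch above: torch.randint's lower bound of world_size * 2 plus the modulo correction force every generated sequence length to be a nonzero multiple of 2 * world_size, so each packed sequence can later be split into 2 * world_size equal chunks across the context-parallel ranks. A small worked example with made-up lengths:

import torch

world_size = 2                                   # illustrative CP group size
seqlens = torch.tensor([5, 12, 9, 7])            # raw per-sequence lengths
seqlens = seqlens - seqlens % (world_size * 2)   # -> tensor([ 4, 12,  8,  4])
assert (seqlens % (2 * world_size) == 0).all()   # each length splits into 4 equal chunks
cu_seqlens = torch.cat([torch.zeros(1, dtype=seqlens.dtype), seqlens.cumsum(0)])
# cu_seqlens -> tensor([ 0,  4, 16, 24, 28])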
@@ -72,24 +88,45 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
    # make sure all GPU ranks have same inputs
    for x in [q, k, v, dout]:
        dist.broadcast(x, 0, group=cp_comm_group)
    if qkv_format == "thd":
        for x in [cu_seqlens_q, cu_seqlens_kv]:
            dist.broadcast(x, 0, group=cp_comm_group)

    # run core_attn without CP
    for x in [q, k, v]:
        x.requires_grad = True
    out = core_attn(q, k, v)
    out = core_attn(q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)
    out.backward(dout)

    # run core_attn with CP
    q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
    seq_dim = qkv_format.index('s')
    q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
        for x in [q_, k_, v_, dout_]]
    seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
    q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
    q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
    if qkv_format == "bshd" or qkv_format == "sbhd":
        q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
        seq_dim = qkv_format.index('s')
        q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
            for x in [q_, k_, v_, dout_]]
        seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
        q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
        q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
    elif qkv_format == "thd":
        q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
        seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
        seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
        q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
        k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
        cu_seqlens_q = cu_seqlens_q // world_size
        cu_seqlens_kv = cu_seqlens_kv // world_size
    else:
        assert False, f"{qkv_format} is an unsupported qkv_format!"

    q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
    core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
    out_ = core_attn(q_, k_, v_)

    max_seqlen_q = config.max_seqlen_q
    max_seqlen_kv = config.max_seqlen_kv

    out_ = core_attn(q_, k_, v_,
                     cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv,
                     max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv)
    out_.backward(dout_)

    for x in [out_, q_.grad, k_.grad, v_.grad]:
@@ -100,11 +137,20 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
    tols = dict(atol=5e-3, rtol=5e-3)
    if dtype == 'bf16':
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
        for x in [q.grad, k.grad, v.grad, out]]
    dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
    dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
        for x in [q_.grad, k_.grad, v_.grad, out_]]

    if qkv_format == "bshd" or qkv_format == "sbhd":
        dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
            for x in [q.grad, k.grad, v.grad, out]]
        dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
        dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
            for x in [q_.grad, k_.grad, v_.grad, out_]]
    elif qkv_format == "thd":
        dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]]
        dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]]
        dq_, dk_, dv_, out_ = [x.view(-1) for x in [q_.grad, k_.grad, v_.grad, out_]]
    else:
        assert False, f"{qkv_format} is an unsupported qkv_format!"

if qkv_format == "bshd":
torch.testing.assert_close(out_[:, 0], out[:, 0], **tols)
torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols)
@@ -123,6 +169,11 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
        torch.testing.assert_close(dq_[1], dq[1], **tols)
        torch.testing.assert_close(dk_[1], dk[1], **tols)
        torch.testing.assert_close(dv_[1], dv[1], **tols)
    elif qkv_format == "thd":
        torch.testing.assert_close(out_, out, **tols)
        torch.testing.assert_close(dq_, dq, **tols)
        torch.testing.assert_close(dk_, dk, **tols)
        torch.testing.assert_close(dv_, dv, **tols)
    else:
        assert False, f"{qkv_format} is an unsupported qkv_format!"

2 changes: 1 addition & 1 deletion tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -31,7 +31,7 @@ def get_bash_arguments(**kwargs):
@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd'])
def test_cp_with_flash_attention(dtype, model, qkv_format):
    subprocess.run(
        get_bash_arguments(
(Diffs for the remaining 4 changed files are not shown.)
