Add assert check for seqlen % (cp_size*2) == 0 and make function names clearer
kunlunl committed Apr 22, 2024
1 parent 3dda0d1 commit dea2a00
Showing 5 changed files with 117 additions and 84 deletions.
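As background for the new check (an illustrative sketch, not code from this commit): with context parallelism and a causal mask, each sequence is split into 2*cp_size chunks and rank r is assigned chunks r and 2*cp_size-1-r so every rank gets a balanced share of the causal workload; that assignment only works when the sequence length divides evenly by cp_size*2, which is what the added asserts enforce. The helper name below is hypothetical:

import torch

def load_balanced_token_slice(seqlen: int, cp_size: int, rank: int) -> torch.Tensor:
    # Mirrors the new assert: the sequence must split evenly into 2*cp_size chunks.
    assert seqlen % (cp_size * 2) == 0, \
        f"seqlen={seqlen} must be divisible by cp_size*2={cp_size * 2}"
    chunks = torch.arange(seqlen).view(2 * cp_size, seqlen // (2 * cp_size))
    # Rank r owns chunk r (light causal work) and chunk 2*cp_size-1-r (heavy causal work).
    return torch.cat([chunks[rank], chunks[2 * cp_size - 1 - rank]])

# Example: seqlen=16, cp_size=2 -> rank 0 holds tokens 0..3 and 12..15.
print(load_balanced_token_slice(16, cp_size=2, rank=0))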
26 changes: 16 additions & 10 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -67,8 +67,7 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
cu_seqlens_q = None
cu_seqlens_kv = None
else:
assert(config.max_seqlen_q % (world_size * 2) == 0)
elif qkv_format == "thd":
seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q = seqlens_q - seqlens_q % (world_size * 2)
cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)])
@@ -78,6 +77,8 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
attn_output_shape = (cu_seqlens_q[-1], config.num_heads*config.head_dim)
cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda()
cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda()
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"

q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
@@ -94,7 +95,7 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
out = core_attn(q, k, v, cu_seqlens_q = cu_seqlens_q, cu_seqlens_kv = cu_seqlens_kv)
out = core_attn(q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)
out.backward(dout)

# run core_attn with CP
@@ -106,20 +107,21 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
else:
elif qkv_format == "thd":
q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
cu_seqlens_q = cu_seqlens_q // world_size
cu_seqlens_kv = cu_seqlens_kv // world_size
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"

q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())

if qkv_format == "bshd" or qkv_format == "sbhd":
out_ = core_attn(q_, k_, v_)
else:
out_ = core_attn(q_, k_, v_, cu_seqlens_q=cu_seqlens_q//world_size, cu_seqlens_kv=cu_seqlens_kv//world_size)
out_ = core_attn(q_, k_, v_, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)
out_.backward(dout_)

for x in [out_, q_.grad, k_.grad, v_.grad]:
@@ -137,10 +139,12 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
for x in [q_.grad, k_.grad, v_.grad, out_]]
else:
elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]]
dq_, dk_, dv_, out_ = [x.view(-1) for x in [q_.grad, k_.grad, v_.grad, out_]]
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"

if qkv_format == "bshd":
torch.testing.assert_close(out_[:, 0], out[:, 0], **tols)
@@ -160,11 +164,13 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
torch.testing.assert_close(dq_[1], dq[1], **tols)
torch.testing.assert_close(dk_[1], dk[1], **tols)
torch.testing.assert_close(dv_[1], dv[1], **tols)
else:
elif qkv_format == "thd":
torch.testing.assert_close(out_, out, **tols)
torch.testing.assert_close(dq_, dq, **tols)
torch.testing.assert_close(dk_, dk, **tols)
torch.testing.assert_close(dv_, dv, **tols)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"

def main(**kwargs):
run_dpa_with_cp(**kwargs)
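For the thd (packed) format exercised in this test, the same split is applied per sequence inside the packed buffer. The function below is a pure-Python stand-in for what tex.thd_get_partitioned_indices conceptually produces; its signature and internals are assumptions for illustration, not the extension's actual implementation:

import torch

def thd_partitioned_indices(cu_seqlens: torch.Tensor, cp_size: int, rank: int) -> torch.Tensor:
    # Token indices owned by `rank` when each packed sequence is cut into 2*cp_size chunks.
    indices = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        seqlen = end - start
        assert seqlen % (cp_size * 2) == 0
        chunk = seqlen // (cp_size * 2)
        indices.append(torch.arange(start + rank * chunk, start + (rank + 1) * chunk))
        indices.append(torch.arange(start + (2 * cp_size - 1 - rank) * chunk,
                                    start + (2 * cp_size - rank) * chunk))
    return torch.cat(indices)

# Example: two packed sequences of lengths 8 and 4, cp_size=2, rank=1 -> tokens 2,3,4,5,9,10.
cu = torch.tensor([0, 8, 12], dtype=torch.int32)
print(thd_partitioned_indices(cu, cp_size=2, rank=1))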
23 changes: 13 additions & 10 deletions transformer_engine/pytorch/attention.py
@@ -624,7 +624,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
if thd:
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i%2] = tex.thd_read_half_tensor(kv_inputs[i%2], cu_seqlens_k, 0)
kv_inputs[i%2] = tex.thd_read_half_tensor(
kv_inputs[i%2], cu_seqlens_k, 0)
else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
@@ -660,8 +661,9 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
# [t, np, hn] -> [t/2, np, hn]
q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
# [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
q_inputs[i%2] = \
q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
@@ -720,10 +722,10 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
softmax_lse_per_step[i-1])
else:
if thd:
tex.thd_lse_correction(softmax_lse,
softmax_lse_per_step[i-1],
cu_seqlens_q,
q.size(0))
tex.thd_second_half_lse_correction(softmax_lse,
softmax_lse_per_step[i-1],
cu_seqlens_q,
q.size(0))
else:
flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
softmax_lse_per_step[i-1])
@@ -801,10 +803,11 @@ def backward(ctx, dout):

if ctx.causal:
if ctx.thd:
softmax_lse_ = tex.thd_read_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
else:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
softmax_lse_ = \
softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
@@ -922,7 +925,7 @@ def backward(ctx, dout):
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
# [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
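For orientation on the renamed helpers (a sketch under stated assumptions, not the extension's implementation): thd_second_half_lse_correction folds each step's softmax log-sum-exp into the running statistics for the second-half tokens of every sequence, and thd_read_second_half_lse reads those statistics back out in the backward pass. Numerically the fold is the standard stable log-sum-exp merge:

import torch

def merge_lse(lse: torch.Tensor, lse_per_step: torch.Tensor) -> torch.Tensor:
    # log(exp(lse) + exp(lse_per_step)), computed without overflow.
    m = torch.maximum(lse, lse_per_step)
    return m + torch.log(torch.exp(lse - m) + torch.exp(lse_per_step - m))

# Running LSE for the second-half tokens, updated with one step's LSE.
lse = torch.tensor([1.0, 2.0])
lse_step = torch.tensor([0.5, 3.0])
print(merge_lse(lse, lse_step))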
16 changes: 8 additions & 8 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -648,23 +648,23 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor,
int half_idx
);

void thd_lse_correction(at::Tensor &lse,
const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens,
int total_tokens
void thd_second_half_lse_correction(at::Tensor &lse,
const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens,
int total_tokens
);

at::Tensor thd_read_half_lse(const at::Tensor &lse,
const at::Tensor &cu_seqlens,
int total_tokens
at::Tensor thd_read_second_half_lse(const at::Tensor &lse,
const at::Tensor &cu_seqlens,
int total_tokens
);

void thd_out_correction(at::Tensor &out,
const at::Tensor &out_per_step,
const at::Tensor &lse,
const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens,
bool is_half
bool only_second_half
);

void thd_grad_correction(at::Tensor &grad,
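Finally, renaming is_half to only_second_half makes explicit what the boolean on thd_out_correction restricts: whether the rescale-and-accumulate step touches all tokens or only the second-half tokens of each sequence. A minimal sketch of that accumulation (shapes and names are assumptions for illustration):

import torch

def out_correction(out: torch.Tensor, out_per_step: torch.Tensor,
                   lse: torch.Tensor, lse_per_step: torch.Tensor) -> None:
    # Accumulate a per-step partial output, weighted by its share of the softmax mass.
    # lse tensors have shape [tokens]; out tensors have shape [tokens, hidden].
    out += out_per_step * torch.exp(lse_per_step - lse).unsqueeze(-1)

# Toy check: two steps whose weights (0.25 and 0.75) sum to 1.
lse = torch.log(torch.tensor([4.0]))
lse1, lse2 = torch.log(torch.tensor([1.0])), torch.log(torch.tensor([3.0]))
out = torch.zeros(1, 2)
out_correction(out, torch.ones(1, 2), lse, lse1)
out_correction(out, 2 * torch.ones(1, 2), lse, lse2)
print(out)  # tensor([[1.7500, 1.7500]])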
