Rebase to the version that supports bias with CP
kunlunl committed Apr 26, 2024
1 parent 0bc060e commit 8ced458
Showing 2 changed files with 14 additions and 12 deletions.
1 change: 1 addition & 0 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -129,6 +129,7 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
cu_seqlens_q = cu_seqlens_q // world_size
cu_seqlens_kv = cu_seqlens_kv // world_size
+bias_ = None
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
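
For context on the hunk above (editorial sketch, not part of the commit): in the packed "thd" layout a batch is one token dimension described by cumulative sequence offsets, and context parallelism leaves each rank with 1/world_size of every sequence, which is why the test divides cu_seqlens_q and cu_seqlens_kv by world_size and, with this change, pins bias_ to None on that path. The sequence lengths and group size below are illustrative assumptions:

import torch

# Illustrative packed ("thd") batch: three sequences of even lengths
# concatenated along a single token dimension, described by cumulative offsets.
seqlens = torch.tensor([6, 2, 8], dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(seqlens, dim=0).to(torch.int32)])
# cu_seqlens is now [0, 6, 8, 16]

world_size = 2  # assumed context-parallel group size
# Each rank keeps 1/world_size of every sequence, so the offsets shrink
# proportionally -- the same arithmetic as `cu_seqlens_q // world_size` above.
cu_seqlens_local = cu_seqlens // world_size  # [0, 3, 4, 8]

# The thd branch of this test carries no attention bias, so it now defines
# bias_ = None (an inference about the surrounding test code, not shown here).
bias_ = None
print(cu_seqlens.tolist(), cu_seqlens_local.tolist())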
25 changes: 13 additions & 12 deletions transformer_engine/pytorch/attention.py
@@ -529,7 +529,6 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

causal = (attn_mask_type == "causal")
-thd = (qkv_format == "thd")

qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format

Expand Down Expand Up @@ -676,7 +675,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
-if thd:
+if qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i%2] = tex.thd_read_half_tensor(
kv_inputs[i%2], cu_seqlens_k, 0)
@@ -728,7 +727,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
if len(rest) > 0:
attn_biases[i] = rest[0]
else:
-if thd:
+if qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
else:
@@ -792,7 +791,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
if i == 1:
out = torch.empty_like(q).zero_()
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
-if causal and not thd:
+if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
@@ -801,7 +800,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
flash_attn_fwd_softmax_lse_correction(softmax_lse,
softmax_lse_per_step[i-1])
else:
-if thd:
+if qkv_format == "thd":
tex.thd_second_half_lse_correction(softmax_lse,
softmax_lse_per_step[i-1],
cu_seqlens_q,
@@ -816,7 +815,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

softmax_lse = softmax_lse.to(torch.float)
-seq_dim = qkv_format.index("s")
+if qkv_format == "bshd" or qkv_format == "sbhd":
+    seq_dim = qkv_format.index("s")
for i in range(cp_size):
if qkv_format == "bshd":
out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
@@ -868,7 +868,6 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
out = out.view(-1, *out.shape[-2:])

ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
-ctx.thd = thd
ctx.rng_states = rng_states
ctx.cp_group = cp_group
ctx.cp_global_ranks = cp_global_ranks
@@ -912,7 +911,7 @@ def backward(ctx, dout):
attn_dbias = None

if ctx.causal:
-if ctx.thd:
+if ctx.qkv_format == "thd":
softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
else:
# [b, np, sq] -> [b, np, 2, sq//2]
@@ -1057,7 +1056,7 @@ def backward(ctx, dout):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
-if ctx.thd:
+if ctx.qkv_format == "thd":
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
else:
@@ -1111,7 +1110,7 @@ def backward(ctx, dout):
attn_bias_type=ctx.attn_bias_type,
)
else:
-if ctx.thd:
+if ctx.qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
else:
@@ -1121,7 +1120,7 @@ def backward(ctx, dout):
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
-if ctx.thd:
+if ctx.qkv_format == "thd":
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1)
else:
@@ -1322,10 +1321,12 @@ def attn_forward_func_with_cp(
use_fused_attention=False
) -> torch.Tensor:
"""Attention implementation with context parallelism"""
-assert(qkv_format in ["bshd", "sbhd"]
+assert(qkv_format in ["bshd", "sbhd", "thd"]
), f"QKV format of {qkv_format} is not supported with context parallelism!"
assert(qkv_format != "sbhd" or use_fused_attention
), "FlashAttention does not support sbhd format!"
+assert(not(qkv_format == "thd" and use_fused_attention)
+), "FusedAttention does not support thd format!"
assert (attn_mask_type in ["causal", "no_mask"]
), f"Mask type of {attn_mask_type} is not supported with context parallelism!"
assert (attn_bias is None or use_fused_attention
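
Editorial note on the hunk above (not part of the commit): taken together, the assertions now admit the packed "thd" format under context parallelism, but only on the FlashAttention path, while "sbhd" remains FusedAttention-only. A minimal standalone sketch of that gating follows; the helper name is illustrative, not TransformerEngine API, and since the final bias assert is truncated in the hunk only its visible condition is reproduced:

def check_cp_inputs(qkv_format, attn_mask_type, use_fused_attention, attn_bias=None):
    # Mirrors the assertions shown in attn_forward_func_with_cp above.
    assert qkv_format in ["bshd", "sbhd", "thd"], \
        f"QKV format of {qkv_format} is not supported with context parallelism!"
    # "sbhd" is only reachable through FusedAttention...
    assert qkv_format != "sbhd" or use_fused_attention, \
        "FlashAttention does not support sbhd format!"
    # ...while "thd" is only reachable through FlashAttention (new in this commit).
    assert not (qkv_format == "thd" and use_fused_attention), \
        "FusedAttention does not support thd format!"
    assert attn_mask_type in ["causal", "no_mask"], \
        f"Mask type of {attn_mask_type} is not supported with context parallelism!"
    # The last assert of the hunk is truncated; only its condition is visible.
    assert attn_bias is None or use_fused_attention


# "thd" passes with FlashAttention but is rejected with FusedAttention.
check_cp_inputs("thd", "causal", use_fused_attention=False)
try:
    check_cp_inputs("thd", "causal", use_fused_attention=True)
except AssertionError as err:
    print(err)  # -> FusedAttention does not support thd format!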
