Add max_seqlen_q and max_seqlen_kv when running unit test with cp
kunlunl committed Apr 22, 2024
1 parent a41a108 commit 43add29
Showing 1 changed file with 6 additions and 1 deletion.
tests/pytorch/fused_attn/run_fused_attn_with_cp.py (6 additions, 1 deletion)
@@ -121,7 +121,12 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
     core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
 
-    out_ = core_attn(q_, k_, v_, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)
+    max_seqlen_q = config.max_seqlen_q
+    max_seqlen_kv = config.max_seqlen_kv
+
+    out_ = core_attn(q_, k_, v_,
+                     cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv,
+                     max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv)
     out_.backward(dout_)
 
     for x in [out_, q_.grad, k_.grad, v_.grad]:
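
Note (not part of the commit): in varlen/packed attention interfaces, cu_seqlens_q and cu_seqlens_kv are cumulative sequence lengths for the packed batch, while max_seqlen_q and max_seqlen_kv give the length of the longest individual sequence. The test above reads the maxima from config; the sketch below, with hypothetical values, only illustrates how the two quantities relate.

    import torch

    # Hypothetical cumulative sequence lengths for a packed batch of three
    # sequences with lengths 128, 256, and 128 (illustration only).
    cu_seqlens_q = torch.tensor([0, 128, 384, 512], dtype=torch.int32)
    cu_seqlens_kv = cu_seqlens_q.clone()

    # The max_seqlen_* values are the largest gap between consecutive
    # cumulative offsets, i.e. the longest sequence in the batch.
    max_seqlen_q = int((cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max())    # 256
    max_seqlen_kv = int((cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).max())  # 256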
