Skip to content

Commit

Permalink
Update the unit test of uinsg THD format in Context Parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
kunlunl committed Mar 14, 2024
1 parent f59b6f3 commit 514b9b6
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/pytorch/fused_attn/run_fused_attn_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
cu_seqlens_q = None
cu_seqlens_kv = None
else:
seqlens_q = torch.randint(world_size*2, config.max_seqlen_q*2, [config.batch_size]).to(torch.int32)
assert(config.max_seqlen_q % (world_size * 2) == 0)
seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q = seqlens_q - seqlens_q % (world_size * 2)
cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)])
cu_seqlens_kv = cu_seqlens_q
Expand Down

0 comments on commit 514b9b6

Please sign in to comment.