[PyTorch] Add support for cuDNN FusedAttention + THD + CP #885
Signed-off-by: Xiaowei Ren <[email protected]>
/te-ci pytorch |
Thanks for submitting the PR. Could you use our template to fill in the PR description please?
|
… into xren/cp_thd
/te-ci pytorch |
/te-ci pytorch |
LGTM
* add seq_offsets_qkvo for cudnn thd
* add seq_offsets_qkvo to AttnFuncWithCP
* fix seq_offsets calculation of cudnn thd
* remove a thd assert
* fix bias for thd test
* add thd test for cudnn FA with CP
* skip GQA/MQA test for cuDNN THD
* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1
* fix seq_offsets inputs
* remove two comments
* fix attn mask type for cudnn thd with cp
* fix attn_mask_type check
* fix attn_mask_type for cudnn fa with thd
* fix a typo
* fix out dout in bwd
* assert cudnn+thd does not support attn bias
* check if attn_mask_type has padding
* minor change
* change cp test batch size to 2
* fix code format
* fix two assert info
* fix assert comment
* fix assert comments
* minor fix
* fix assert comments
---------
Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
Signed-off-by: Boxiang Wang <[email protected]>
Hi, thank you for the great work. Could you please tell me which commit hash adds support for flash-attn with CP+THD? |
This was added a while back. Here is the PR. THD + CP should now work with both Tri Dao's Flash Attention and cuDNN Fused Attention. You can set I do not totally understand your second question. With the THD format, the CP implementation requires you to split each individual sequence into CP*2 chunks and assign 2 chunks to each GPU for load balancing (as we do for the BSHD/SBHD formats). You can refer here for an example. |
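To make the load-balancing argument concrete, here is a minimal sketch that counts causal attention work per rank. It assumes the usual pairing in which rank r holds chunk r and its mirror chunk 2*CP-1-r; the function names are hypothetical and are not part of Transformer Engine.

```python
def load_balanced_chunks(cp_size):
    """Assumed pairing: rank r holds chunk r and its mirror chunk 2*cp_size-1-r."""
    return [(r, 2 * cp_size - 1 - r) for r in range(cp_size)]


def contiguous_chunks(cp_size):
    """Naive split for comparison: rank r holds the two adjacent chunks 2r and 2r+1."""
    return [(2 * r, 2 * r + 1) for r in range(cp_size)]


def causal_work(chunks):
    """Count key chunks visible to each held query chunk under a causal mask.

    Query chunk c can attend to key chunks 0..c, that is c + 1 blocks.
    """
    return sum(c + 1 for c in chunks)


cp_size = 4
balanced = [causal_work(c) for c in load_balanced_chunks(cp_size)]
naive = [causal_work(c) for c in contiguous_chunks(cp_size)]
print("load-balanced work per rank:", balanced)
print("contiguous work per rank:   ", naive)
```

With the mirrored pairing every rank does the same number of block computations, while the contiguous split leaves the last rank with several times the work of the first.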
Hi there, thank you for your kind reply. Your repo has helped me a lot. |
We do not gather the tensors explicitly; we do KV P2P in a ring topology so that communication and computation can be overlapped. |
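The ring exchange can be pictured with a small single-process simulation. This is only a sketch: it assumes each rank sends its current KV block to rank+1 and receives from rank-1, the name `ring_schedule` is hypothetical, and the real implementation would post the send/receive for the next block while computing on the current one.

```python
def ring_schedule(cp_size):
    """Return schedule[step][rank] = owner of the KV block that rank processes."""
    held = list(range(cp_size))  # every rank starts with its own KV block
    schedule = []
    for _ in range(cp_size):
        schedule.append(list(held))
        # Assumed direction: rank r receives the block previously held by rank r-1.
        held = [held[(r - 1) % cp_size] for r in range(cp_size)]
    return schedule


schedule = ring_schedule(4)
for step, blocks in enumerate(schedule):
    print(f"step {step}: {blocks}")
```

After CP steps every rank has seen every KV block exactly once, without any rank ever holding the full gathered tensor.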
Sorry for the naive question.
elif qkv_format == "thd":
    piece_q = q_.shape[0] // world_size
    seq_idx_q = torch.arange(piece_q * rank, piece_q * (rank + 1)).cuda()
    piece_kv = k_.shape[0] // world_size
    seq_idx_kv = torch.arange(piece_kv * rank, piece_kv * (rank + 1)).cuda()
    # seq_idx_q = tex.thd_get_partitioned_indices(
    #     cu_seqlens_q_padded, q_.shape[0], world_size, rank
    # )
    # seq_idx_kv = tex.thd_get_partitioned_indices(
    #     cu_seqlens_kv_padded, k_.shape[0], world_size, rank
    # )
|
The current CP implementation only supports the load-balancing strategy. Do you have fixed F, H, and W? If so, why not change your format to SBHD or BSHD? Or why do you need to avoid the load-balancing strategy? Are you doing bidirectional attention? The load-balancing strategy also works with bidirectional attention. |
We want to support [FHW,H D] CP with thd format.
|
Load balancing also allows you to apply RoPE, right? |
In the case of load balancing, the tensor on each rank has two offsets, which is not supported by our current API. |
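The "two offsets" point can be illustrated with a short sketch that lists the in-sequence positions a rank holds for a packed THD batch. It assumes every sequence length is divisible by 2*CP and that rank r holds chunks r and 2*CP-1-r; both helper names are hypothetical.

```python
def local_position_ranges(seq_len, cp_size, rank):
    """Return the two (start, stop) position ranges of one sequence held by `rank`."""
    assert seq_len % (2 * cp_size) == 0, "sketch assumes divisible sequence lengths"
    chunk = seq_len // (2 * cp_size)
    first, second = rank, 2 * cp_size - 1 - rank  # assumed mirrored pairing
    return [
        (first * chunk, (first + 1) * chunk),
        (second * chunk, (second + 1) * chunk),
    ]


def local_position_ids(cu_seqlens, cp_size, rank):
    """Concatenate per-sequence position ids for all tokens local to `rank`."""
    ids = []
    for start, stop in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        for lo, hi in local_position_ranges(stop - start, cp_size, rank):
            ids.extend(range(lo, hi))
    return ids


cu_seqlens = [0, 8, 24]  # two packed sequences of length 8 and 16
print(local_position_ranges(16, cp_size=2, rank=0))
print(local_position_ids(cu_seqlens, cp_size=2, rank=0))
```

For rank 0 each sequence contributes two disjoint ranges with different starting offsets, so a positional encoding that accepts a single offset per tensor cannot describe the local tokens; explicit per-token position ids would be one way around that.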
OK, then sorry, our current implementation can only work with the load-balancing split. To meet your requirements, you will probably need to hack AttnFuncWithCP. |
Thank you very much! |
Thank you very much, again! |
This is a very hard question to answer; different people have a different sense of the difficulty. The implementation is here. You can read it and estimate the difficulty yourself. It is mainly PyTorch code, but the THD format does involve some CUDA code. |
Hi, Ren, I'm back again. Gladly, I've tested the idea, and it does work. Thank you very much. |
If you do bidirectional attention (no_mask or padding_no_mask) and there are no padding tokens between sequences, then the current code should work for you even if you do not split the sequences in the load-balancing way. Otherwise, you will probably run into issues. But it sounds like you have already verified your idea. Glad you found a solution. |
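One way to see why the split matters less without a mask: when every query attends to every key, partial softmax results can be merged chunk by chunk and the answer does not depend on how the KV tokens were grouped. The sketch below shows this for a single scalar query with scalar keys and values; it is an illustration of the general online-softmax idea, not the library's code, and it covers a single sequence only.

```python
import math


def attend(q, keys, values):
    """Reference softmax attention for one scalar query over scalar keys/values."""
    scores = [q * k for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def attend_chunked(q, kv_chunks):
    """Accumulate attention chunk by chunk with a running max and denominator."""
    m, denom, acc = -math.inf, 0.0, 0.0
    for keys, values in kv_chunks:
        for k, v in zip(keys, values):
            s = q * k
            new_m = max(m, s)
            scale = math.exp(m - new_m)  # rescale earlier partial sums
            w = math.exp(s - new_m)
            denom = denom * scale + w
            acc = acc * scale + w * v
            m = new_m
    return acc / denom


q = 1.5
keys = [0.1, -0.4, 0.7, 0.3]
values = [1.0, 2.0, 3.0, 4.0]
contiguous = [(keys[:2], values[:2]), (keys[2:], values[2:])]
interleaved = [(keys[1::2], values[1::2]), (keys[0::2], values[0::2])]
print(attend(q, keys, values))
print(attend_chunked(q, contiguous))
print(attend_chunked(q, interleaved))
```

All three values agree up to floating-point rounding. With a causal mask this freedom disappears, because which keys a query may see depends on where its chunk sits in the sequence.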
Description
Add support for cuDNN FA+THD+CP
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: