Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Add support for cuDNN FusedAttention + THD + CP #885

Merged
merged 30 commits into from
Jun 10, 2024

Conversation

xrennvidia
Copy link
Collaborator

@xrennvidia xrennvidia commented Jun 3, 2024

Description

Add support for cuDNN FA+THD+CP

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)

Changes

Please list the changes introduced in this PR:

  • add CP support for THD format in FusedAttention, and add the related unit tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

xrennvidia added 19 commits May 31, 2024 11:36
Signed-off-by: Xiaowei Ren <[email protected]>
Signed-off-by: Xiaowei Ren <[email protected]>
Signed-off-by: Xiaowei Ren <[email protected]>
Signed-off-by: Xiaowei Ren <[email protected]>
Signed-off-by: Xiaowei Ren <[email protected]>
Signed-off-by: Xiaowei Ren <[email protected]>
Signed-off-by: Xiaowei Ren <[email protected]>
Signed-off-by: Xiaowei Ren <[email protected]>
@cyanguwa cyanguwa self-requested a review June 4, 2024 00:37
Signed-off-by: Xiaowei Ren <[email protected]>
@cyanguwa
Copy link
Collaborator

cyanguwa commented Jun 5, 2024

/te-ci pytorch

@cyanguwa
Copy link
Collaborator

cyanguwa commented Jun 5, 2024

Thanks for submitting the PR. Could you use our template to fill in the PR description please?

# Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

## Type of change

- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)

## Changes

Please list the changes introduced in this PR:

- Change A
- Change B

# Checklist:

- [ ] I have read and followed the [contributing guidelines](https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst)
- [ ] The functionality is complete
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes

@cyanguwa cyanguwa changed the title Add support for cuDNN FA+THD+CP [PyTorch] Add support for cuDNN FusedAttention + THD + CP Jun 5, 2024
@cyanguwa
Copy link
Collaborator

cyanguwa commented Jun 6, 2024

/te-ci pytorch

@cyanguwa
Copy link
Collaborator

cyanguwa commented Jun 7, 2024

/te-ci pytorch

Copy link
Collaborator

@cyanguwa cyanguwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@cyanguwa cyanguwa merged commit f68df15 into NVIDIA:main Jun 10, 2024
20 checks passed
@xrennvidia xrennvidia deleted the xren/cp_thd branch June 10, 2024 19:45
BoxiangW pushed a commit to BoxiangW/TransformerEngine that referenced this pull request Jun 11, 2024
* add seq_offsets_qkvo for cudnn thd

Signed-off-by: Xiaowei Ren <[email protected]>

* add seq_offsets_qkvo to AttnFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* fix seq_offsets calculation of cudnn thd

Signed-off-by: Xiaowei Ren <[email protected]>

* remove a thd assert

Signed-off-by: Xiaowei Ren <[email protected]>

* fix bias for thd test

Signed-off-by: Xiaowei Ren <[email protected]>

* add thd test for cudnn FA with CP

Signed-off-by: Xiaowei Ren <[email protected]>

* skip GQA/MQA test for cuDNN THD

Signed-off-by: Xiaowei Ren <[email protected]>

* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1

Signed-off-by: Xiaowei Ren <[email protected]>

* fix seq_offsets inputs

Signed-off-by: Xiaowei Ren <[email protected]>

* remove two comments

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attn mask type for cudnn thd with cp

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attn_mask_type check

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attn_mask_type for cudnn fa with thd

Signed-off-by: Xiaowei Ren <[email protected]>

* fix a typo

Signed-off-by: Xiaowei Ren <[email protected]>

* fix out dout in bwd

Signed-off-by: Xiaowei Ren <[email protected]>

* assert cudnn+thd does not support attn bias

Signed-off-by: Xiaowei Ren <[email protected]>

* check if attn_mask_type has padding

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* change cp test batch size to 2

Signed-off-by: Xiaowei Ren <[email protected]>

* fix code format

Signed-off-by: Xiaowei Ren <[email protected]>

* fix two assert info

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert comment

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert comments

Signed-off-by: Xiaowei Ren <[email protected]>

* minor fix

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert comments

Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
Signed-off-by: Boxiang Wang <[email protected]>
@wplf
Copy link
Contributor

wplf commented Aug 26, 2024

Hi, thank you for great works.

Could you tell me which commit hash is to support flash-attn and CP+THD, please?
And if I want to use a different strategy to splitting CP tensor, what should I do?

@xrennvidia
Copy link
Collaborator Author

Hi, thank you for great works.

Could you tell me which commit hash is to support flash-attn and CP+THD, please? And if I want to use a different strategy to splitting CP tensor, what should I do?

It has been done way back. Here is the PR. Now, THD + CP should work for both Tri Dao's Flash Attention and cuDNN Fused Attention. You can set NVTE_FLASH_ATTN and NVTE_FUSED_ATTN to select which backend to use.

I do not totally understand your second question. With THD format, CP implementation requires you to split each individual sequence into CP*2 chunks, and assign 2 chunks to each GPU for load balancing (like what we did with BSHD/SBHD format). You can refer here for an example.

@wplf
Copy link
Contributor

wplf commented Aug 27, 2024

Hi there, thank you for your kindness reply. Your repo helps me very much.
The second question is my tensor shape is [ F * H * W, H, D], I want to split tensor in H and W dim. So my strategy is different with LLM's load balancing strategy. Could you should show me how to gather tensor in another way before AttnFuncWithCP?

@xrennvidia
Copy link
Collaborator Author

Hi there, thank you for your kindness reply. Your repo helps me very much. The second question is my tensor shape is [ F * H * W, H, D], I want to split tensor in H and W dim. So my strategy is different with LLM's load balancing strategy. Could you should show me how to gather tensor in another way before AttnFuncWithCP?

We do not gather tensor explicitly, we do KV P2P in ring topology so that they can be overlapped.
We just added a new implementation that does true KV all-gather (refer here), but it does not support THD format yet.

@wplf
Copy link
Contributor

wplf commented Aug 27, 2024

Sorry for the naive question.
I saw the implementation of CP split is below. Could you please tell me how to change the load balancing strategy to naive strategy.
PS: I need to make sure that the tensor in each rank is intact.

    elif qkv_format == "thd":
        piece_q = q_.shape[0] // (world_size)
        seq_idx_q = torch.arange(piece_q * rank, piece_q * (rank + 1)).cuda()
        
        piece_kv = k_.shape[0] // (world_size)
        seq_idx_kv = torch.arange(piece_kv * rank, piece_kv * (rank + 1)).cuda()
        
        # seq_idx_q = tex.thd_get_partitioned_indices(
        #     cu_seqlens_q_padded, q_.shape[0], world_size, rank
        # )
        # seq_idx_kv = tex.thd_get_partitioned_indices(
        #     cu_seqlens_kv_padded, k_.shape[0], world_size, rank
        # )

@xrennvidia
Copy link
Collaborator Author

xrennvidia commented Aug 27, 2024

Current CP implementation only can support the load balancing strategy.

Do you have fixed F, H, W? If you do, why don't you change your format to SBDH or BSHD?

Or why do you have to avoid the load balancing strategy? are you doing bi-directional attention? The load balancing strategy also works with bi-directional attention.

@wplf
Copy link
Contributor

wplf commented Aug 27, 2024

We want to support [FHW,H D] CP with thd format.
The reason we want to avoid the load balancing strategy is we need to make tensor on each cp_rank is intact and get corresponding H or W or F offsets to compat with 2D/3D rope.

  • CP=2 [:, 0:H//2, 0:W, :] [:, H//2:H, 0:W, :]
  • CP=4 [:, 0:H//2, 0:W//2, :] [:, H//2:H, 0:W//2, :] [:, 0:H//2, W//2:W, :] [:, H//2:H, W//2:W, :]

@xrennvidia
Copy link
Collaborator Author

Load balancing also can allow you do rope, right?
You only need to make sure rope is initialized with full sequence, and then slice rope in the same way as input in each rank, right?

@wplf
Copy link
Contributor

wplf commented Aug 27, 2024

In the case of load balancing, one tensor on each rank has two offsets. which is not supported for our current api.

@xrennvidia
Copy link
Collaborator Author

OK, then sorry, our current implementation only can work with load balancing split. To meet your requirements, you probably need to hack AttnFuncWithCP.

@wplf
Copy link
Contributor

wplf commented Aug 27, 2024

Thank you very much!

@wplf
Copy link
Contributor

wplf commented Aug 27, 2024

Thank you very much, again!.
Could you please tell me the difficulty of hacking AttnFuncWithCP?
I am not familiar with cuda. If this is too hard for me , I will find another way to bypass this problem.

@xrennvidia
Copy link
Collaborator Author

This is a question very hard to answer, different people have different feeling of the difficulty. The implementation is at here. You can read it and try to estimate the difficulty by yourself.

It's mainly pytorch code, but THD format indeed has some cuda code.

@wplf
Copy link
Contributor

wplf commented Aug 28, 2024

Hi, ren, I'm back, again.
After reading ring attention thoroughly, I've got another conclusion.
If I use the AttnFuncWithCP, not the all-gather one,
and Q, K, V is in the right place, then out is in the right place. The only thing that need to note is to restore whole sequence from cp_groups in the same way as splitting. Ring attention doesn't care the order of iteration. Is that correct?

Gradly, I‘ve test the idea, which does work.
This idea passes
torchrun --nnodes 1 --nproc-per-node 8 tests/pytorch/fused_attn/run_fused_attn_with_cp.py model=cp_1_0 qkv_format=thd

Thank you very much.

@xrennvidia
Copy link
Collaborator Author

Hi, ren, I'm back, again. After reading ring attention thoroughly, I've got another conclusion. If I use the AttnFuncWithCP, not the all-gather one, and Q, K, V is in the right place, then out is in the right place. The only thing that need to note is to restore whole sequence from cp_groups in the same way as splitting. Ring attention doesn't care the order of iteration. Is that correct?

Gradly, I‘ve test the idea, which does work. This idea passes torchrun --nnodes 1 --nproc-per-node 8 tests/pytorch/fused_attn/run_fused_attn_with_cp.py model=cp_1_0 qkv_format=thd

Thank you very much.

If you do bidirectional attention (no_mask or padding_no_mask) and there is not padding tokens between each sequence, then current code should work for you even though you do not split sequence in load balancing way. Otherwise, you probably will encounter some issues.

But sounds like you already verified your idea. Glad you find the solution.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants