[SYCLTLA] Support all Pytorch strides for FlashAttention fwd/bwd kernel #2727
base: main
Conversation
Pull request overview
This PR updates the SYCLTLA FlashAttention implementation to support all PyTorch tensor strides, removing the previous BHSD/BSHD layout restrictions. This enables the kernels to work correctly in distributed scenarios where batch_size/num_heads dimensions may be split by data/tensor parallelism.
Changes:
- Removed layout-specific code and replaced it with stride-based offset calculations using PyTorch tensor strides (see the sketch after this list)
- Introduced new parameter structs (QKV_params, FLASH_FWD_params, FLASH_BWD_params) to encapsulate stride information
- Simplified kernel interfaces by passing parameter structs instead of individual arguments
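
As a rough sketch of the stride-based approach (the struct and function below are illustrative; names and fields are not the PR's exact definitions), the parameter structs carry the per-dimension strides of each tensor and the kernels compute element offsets directly from them:

```cpp
// Hypothetical sketch only; names and fields are illustrative, not the PR's API.
#include <cstdint>

struct QKV_params_sketch {
  // Element strides copied from the PyTorch tensors (e.g. q.stride(0), q.stride(1), ...),
  // so permuted or sliced views are addressed correctly without assuming BHSD/BSHD.
  int64_t q_batch_stride, q_head_stride, q_seq_stride;
  int64_t k_batch_stride, k_head_stride, k_seq_stride;
  int64_t v_batch_stride, v_head_stride, v_seq_stride;
};

// Offset of the (batch b, head h, query row s) element of Q, assuming the
// innermost head_dim dimension has stride 1.
inline int64_t q_offset(const QKV_params_sketch& p, int64_t b, int64_t h, int64_t s) {
  return b * p.q_batch_stride + h * p.q_head_stride + s * p.q_seq_stride;
}
```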
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.
Summary per file:
| File | Description |
|---|---|
| src/ATen/native/transformers/xpu/flash_attn/utils.h | Removed ATTN_TENSOR_LAYOUT enum and related layout checking functions, deprecated check_flash_attention_layout |
| src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp | Replaced layout-based logic with stride-based offset calculations, updated function signatures to use FLASH_FWD_params |
| src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_common.h | Added QKV_params, FLASH_FWD_params, FLASH_BWD_params structs and helper functions to populate them from tensors (see the sketch after this table) |
| src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.h | Updated Param struct with stride fields and offset calculation methods, removed layout-specific stride setup functions |
| src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp | Refactored to use FLASH_BWD_params, updated offset calculations to use do_offset/dqaccum_offset, made grad_out contiguous |
| src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd.h | Removed is_bshd parameter and layout-specific branching, simplified coordinate calculations |
| src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/tile_scheduler_sdpa_fwd.h | Removed is_bshd parameter, eliminated BSHD-specific grid calculation path |
| src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_mma.h | Changed Arguments to use stride fields instead of packed strides, added RuntimeParams struct, simplified offset calculations |
| src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_epilogue.h | Updated to use stride-based offset calculations, removed is_bshd branching, separated Params and RuntimeParams |
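
The helper functions in mha_common.h presumably read these strides straight from the ATen tensors; a minimal, self-contained sketch of that idea (hypothetical struct and helper names, assuming a (batch, num_heads, seq_len, head_dim) index order) might look like:

```cpp
// Illustrative only; the struct and helper are hypothetical, not the PR's API.
#include <ATen/ATen.h>
#include <cstdint>

struct TensorStrides {
  int64_t batch_stride, head_stride, seq_stride;
};

// Record the actual strides of a (batch, num_heads, seq_len, head_dim) tensor.
// The physical layout (BHSD, BSHD, or a non-contiguous sliced view) does not
// matter, because addressing is done through these recorded strides.
inline TensorStrides read_strides(const at::Tensor& t) {
  return {t.stride(0), t.stride(1), t.stride(2)};
}
```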
```cpp
    dropout == 0.0,
    "FlashAttentionBackwardXPU does not support dropout > 0.0 yet");

at::Tensor contiguous_grad_out = grad_out.contiguous();
```
Copilot AI commented on Jan 12, 2026
Creating a contiguous copy of grad_out unconditionally may introduce unnecessary memory allocation and copy overhead when grad_out is already contiguous. Consider checking whether grad_out is contiguous first and only creating the copy when needed, similar to the pattern: at::Tensor contiguous_grad_out = grad_out.is_contiguous() ? grad_out : grad_out.contiguous();
Suggested change:
```diff
- at::Tensor contiguous_grad_out = grad_out.contiguous();
+ at::Tensor contiguous_grad_out =
+     grad_out.is_contiguous() ? grad_out : grad_out.contiguous();
```
Hi, do you mean this contiguous() call? I am aligning with CUDA: https://github.com/pytorch/pytorch/blob/9f5d6ec4fe0ca9c219ac057ed1bdd62f6b759996/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L91
Currently, the SYCLTLA FlashAttention fwd/bwd kernels only support the BHSD/BSHD layouts, which works well in single-process scenarios.
However, in distributed scenarios the batch_size/num_heads dimensions may be split by DP/TP, which makes the strides non-contiguous in BHSD/BSHD terms, so the current support is not enough.
This PR uses PyTorch's tensor strides to compute offsets so that the FlashAttention kernels can support all PyTorch strides.
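
For example (an illustrative ATen snippet, not code from this PR; shapes, dtype, and device are assumptions), a tensor-parallel split of the heads yields a view that is no longer contiguous in BHSD terms, yet whose strides still describe the data exactly:

```cpp
// Illustrative only; shapes, dtype, and device are assumptions for the example.
#include <ATen/ATen.h>

int main() {
  const int64_t batch = 2, num_heads = 16, seq_len = 1024, head_dim = 64;
  at::Tensor q = at::randn({batch, num_heads, seq_len, head_dim},
                           at::TensorOptions().dtype(at::kHalf).device(at::kXPU));

  // Keep the first half of the heads on this rank. The view shares storage
  // with q, so its batch stride still spans all 16 heads and the tensor is
  // not contiguous; addressing it through q_local's strides remains correct.
  at::Tensor q_local = q.narrow(/*dim=*/1, /*start=*/0, /*length=*/num_heads / 2);
  bool is_contig = q_local.is_contiguous();  // false
  (void)is_contig;
  return 0;
}
```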