Conversation

@LuFinch (Contributor) commented Jan 12, 2026

Currently, the SYCLTLA FlashAttention fwd/bwd kernels only support BHSD/BSHD layouts, which is sufficient in single-process scenarios.

However, in distributed scenarios the batch_size/num_heads dimensions may be split by DP/TP, which makes the strides non-contiguous with respect to BHSD/BSHD, so the current support is not enough.

This PR uses PyTorch's tensor strides to compute offsets so that the FA kernels can support any PyTorch stride.
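
As a rough illustration of the idea (a minimal sketch, not code from this PR; slice_offset is a hypothetical helper): instead of hard-coding BHSD/BSHD index math, the base offset of each (batch, head) slice is derived from the tensor's own strides, so any stride pattern PyTorch produces, including DP/TP-sliced views, is handled uniformly.

```cpp
#include <ATen/ATen.h>
#include <cstdint>

// Hypothetical helper (not the PR's actual code): element offset of the
// (batch b, head h) slice of a 4-D tensor whose logical dimension order is
// [batch, num_heads, seq_len, head_dim]. Because it reads the strides from
// the tensor itself, it works for BHSD, BSHD, or any other layout PyTorch
// can produce.
inline int64_t slice_offset(const at::Tensor& t, int64_t b, int64_t h) {
  return b * t.stride(0) + h * t.stride(1);
}

// Usage sketch: the per-slice base pointer a kernel argument could be built from.
// const float* q_base = q.const_data_ptr<float>() + slice_offset(q, b, h);
```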

Copilot AI review requested due to automatic review settings January 12, 2026 07:44
Copilot AI (Contributor) left a comment

Pull request overview

This PR updates the SYCLTLA FlashAttention implementation to support all PyTorch tensor strides, removing the previous BHSD/BSHD layout restrictions. This enables the kernels to work correctly in distributed scenarios where batch_size/num_heads dimensions may be split by data/tensor parallelism.

Changes:

  • Removed layout-specific code and replaced it with stride-based offset calculations using PyTorch tensor strides
  • Introduced new parameter structs (QKV_params, FLASH_FWD_params, FLASH_BWD_params) to encapsulate stride information
  • Simplified kernel interfaces by passing parameter structs instead of individual arguments

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

Show a summary per file
  • src/ATen/native/transformers/xpu/flash_attn/utils.h: Removed the ATTN_TENSOR_LAYOUT enum and related layout-checking functions; deprecated check_flash_attention_layout
  • src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp: Replaced layout-based logic with stride-based offset calculations; updated function signatures to use FLASH_FWD_params
  • src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_common.h: Added QKV_params, FLASH_FWD_params, and FLASH_BWD_params structs plus helper functions to populate them from tensors
  • src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.h: Updated the Param struct with stride fields and offset-calculation methods; removed layout-specific stride setup functions
  • src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp: Refactored to use FLASH_BWD_params; updated offset calculations to use do_offset/dqaccum_offset; made grad_out contiguous
  • src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd.h: Removed the is_bshd parameter and layout-specific branching; simplified coordinate calculations
  • src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/tile_scheduler_sdpa_fwd.h: Removed the is_bshd parameter; eliminated the BSHD-specific grid calculation path
  • src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_mma.h: Changed Arguments to use stride fields instead of packed strides; added a RuntimeParams struct; simplified offset calculations
  • src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_epilogue.h: Updated to use stride-based offset calculations; removed is_bshd branching; separated Params and RuntimeParams
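
For orientation, a minimal sketch of what a stride-carrying parameter struct and its populate helper could look like; the names below (QKVParamsSketch, set_qkv_params) are illustrative assumptions, not the actual definitions in mha_common.h.

```cpp
#include <ATen/ATen.h>
#include <cstdint>

// Illustrative stand-in for a QKV_params-style struct: raw pointers plus the
// PyTorch strides (in elements) of the logical [batch, num_heads, seq_len,
// head_dim] view, so the kernels no longer need a BHSD/BSHD layout flag.
struct QKVParamsSketch {
  const void* q_ptr;
  const void* k_ptr;
  const void* v_ptr;
  int64_t q_batch_stride, q_head_stride, q_seq_stride;
  int64_t k_batch_stride, k_head_stride, k_seq_stride;
  int64_t v_batch_stride, v_head_stride, v_seq_stride;
};

// Hypothetical populate helper, analogous to the "helper functions to populate
// them from tensors" listed above: copies pointers and strides straight from
// the at::Tensor objects.
inline void set_qkv_params(QKVParamsSketch& p,
                           const at::Tensor& q,
                           const at::Tensor& k,
                           const at::Tensor& v) {
  p.q_ptr = q.const_data_ptr();
  p.k_ptr = k.const_data_ptr();
  p.v_ptr = v.const_data_ptr();
  p.q_batch_stride = q.stride(0);
  p.q_head_stride = q.stride(1);
  p.q_seq_stride = q.stride(2);
  p.k_batch_stride = k.stride(0);
  p.k_head_stride = k.stride(1);
  p.k_seq_stride = k.stride(2);
  p.v_batch_stride = v.stride(0);
  p.v_head_stride = v.stride(1);
  p.v_seq_stride = v.stride(2);
}
```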

Review comment context from src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp:
dropout == 0.0,
"FlashAttentionBackwardXPU does not only support dropout > 0.0 yet");

at::Tensor contiguous_grad_out = grad_out.contiguous();
Copilot AI commented Jan 12, 2026

Creating a contiguous copy of grad_out unconditionally may introduce unnecessary memory allocation and copy overhead when grad_out is already contiguous. Consider checking if grad_out is contiguous first and only creating the copy when needed, similar to the pattern: at::Tensor contiguous_grad_out = grad_out.is_contiguous() ? grad_out : grad_out.contiguous();

Suggested change
-  at::Tensor contiguous_grad_out = grad_out.contiguous();
+  at::Tensor contiguous_grad_out =
+      grad_out.is_contiguous() ? grad_out : grad_out.contiguous();

A reviewer replied:

@LuFinch, can you check this and other memory hogs? I am seeing similar memory use as the MATH kernel, which is not the case in CUDA. See the section "Using flash attention SDP kernel (without dropout), A100" in this blog.

@intel deleted a comment from Copilot AI on Jan 13, 2026
@LuFinch requested a review from cfgfung on January 19, 2026 06:13