Skip to content

Conversation

Copy link
Contributor

Copilot AI commented Jan 20, 2026

Description

Enables GQA in Attention(23) CUDA operator by reusing the existing GroupQueryAttention kernel. Follows the same pattern as PR #27030 which bridged MHA parameters.

Implementation:

  • Detect GQA when kv_num_heads != q_num_heads and route to GQA kernel path
  • Bridge 28 parameters from AttentionParameters to GroupQueryAttentionParameters
  • Initialize GroupQueryAttentionData with input/output pointers and configuration
  • Call QkvToContext<T>(device_prop, cublas, stream, parameters, data)
  • Add type validation to reject float32 for GQA path (only float16 and bfloat16 supported, following GroupQueryAttention behavior)

Current scope:

  • GQA only supports float16 and bfloat16 types (not float32)
  • Uses unfused attention path (consistent with existing MHA path)
  • Validates unsupported features: qk_matmul_output_mode (except kNone), softcap, softmax_precision, attention_bias

Testing:

  • Enabled CUDA tests for three GQA test cases:
    • Attention3DGqaAttn
    • Attention4DGqaAttnMask
    • Attention4DGqaWithPastAndPresent
  • Removed GQA test filters from onnx_backend_test_series_filters.jsonc to allow ONNX backend GQA tests to run on CUDA
  • Tests run with float16 and bfloat16 types

Backward compatibility:

  • MHA path unchanged when kv_num_heads == q_num_heads
  • Attention operator remains registered for float, MLFloat16, and BFloat16 (float uses MHA path only)

Motivation and Context

Resolves the TODO at lines 177-179 which blocked GQA usage. The GQA kernel already exists and is well-tested in GroupQueryAttention operator - this change makes it accessible via the standard Attention(23) operator for models using grouped key-value heads with float16 or bfloat16 precision.

Original prompt

This section details on the original issue you should resolve

<issue_title>Support group query attention in Attention(23) CUDA</issue_title>
<issue_description>In #27030, the PR bridges the parameters between attention_parameters.h for Attention(23) and attention_parameters.h for MultiHeadAttention to re-use its kernel.

The Attention(23) bridging parameters:

// To reuse the existing attention-cuda implementation in contrib ops,
// map the parameters to contribop_parameters.
onnxruntime::contrib::AttentionParameters contribop_parameters;
contribop_parameters.batch_size = parameters.batch_size;
contribop_parameters.sequence_length = parameters.q_sequence_length;
contribop_parameters.kv_sequence_length = parameters.kv_sequence_length;
contribop_parameters.past_sequence_length = parameters.past_sequence_length;
contribop_parameters.total_sequence_length = parameters.total_sequence_length;
// max_sequence_length: For non-buffer-sharing case, this equals total_sequence_length (the present KV cache size)
contribop_parameters.max_sequence_length = parameters.total_sequence_length;
contribop_parameters.input_hidden_size = 0; // Not applicable - new Attention op takes pre-projected Q/K/V
contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size;
contribop_parameters.head_size = parameters.head_size;
contribop_parameters.v_head_size = parameters.v_head_size;
contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size;
contribop_parameters.num_heads = parameters.q_num_heads;
contribop_parameters.rotary_dim = 0;
contribop_parameters.num_splits = 1;
contribop_parameters.beam_width = 1;
contribop_parameters.is_unidirectional = parameters.is_causal;
contribop_parameters.past_present_share_buffer = false; // New Attention op doesn't share buffer
contribop_parameters.is_packed_qkv = false;
contribop_parameters.do_rotary = false;
// The new Attention op uses attn_mask as attention_bias (additive bias), not as key_padding_mask
// So mask_type should always be MASK_NONE since we don't have a separate padding mask input
contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE;
// Determine broadcast flags for attention_bias (if it exists)
// Note: The new Attention op uses attn_mask as attention_bias
// The attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length)
// attn_mask can be 2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions).
if (attn_mask != nullptr) {
// TODO(titaiwang, xadupre): attn_mask bool is not supported yet
if (attn_mask->IsDataType<bool>()) {
ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA).");
}
size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions();
auto attn_mask_dims = attn_mask->Shape().GetDims();
// For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting
// For 3D mask (X, q_seq_len, total_seq_len): batch needs broadcasting if X==1, heads always needs broadcasting
// For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1
if (attn_mask_dims_size == 2) {
// 2D mask: both dimensions need broadcasting
contribop_parameters.broadcast_attn_bias_dim_0 = true;
contribop_parameters.broadcast_attn_bias_dim_1 = true;
} else if (attn_mask_dims_size == 3) {
// 3D mask: dim 0 broadcasts if it's 1, dim 1 (heads) always broadcasts
contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1;
contribop_parameters.broadcast_attn_bias_dim_1 = true;
} else {
// 4D mask: check both dim 0 and dim 1 explicitly
contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1;
contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1;
}
} else {
contribop_parameters.broadcast_attn_bias_dim_0 = false;
contribop_parameters.broadcast_attn_bias_dim_1 = false;
}
contribop_parameters.mask_filter_value = -10000.0f;
contribop_parameters.scale = parameters.scale;
contribop_parameters.use_tf32 = UseTF32();
// QKV format: Determine based on input dimensions
// 3D inputs (B, S, D): Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH
// transpose_output is true for 3D inputs, false for 4D inputs
if (!parameters.transpose_output) {
contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
contribop_parameters.is_output_bnsh = true;
} else {
// 3D inputs in BSNH format (will be transposed)
contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH;
contribop_parameters.is_output_bnsh = false;
}
// TODO(titaiwang, xadupre): Group query attention is not supported yet
if (parameters.kv_num_heads != parameters.q_num_heads) {
ORT_THROW("Group query attention is not supported yet in Attention op (CUDA).");
}
// TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now
if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone &&
qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) {
ORT_THROW("qk_matmul_output_mode other than -1 (None) and 0 (QK) is not supported yet in Attention op (CUDA).");
}
// TODO(titaiwang, xadupre): softcap and softmax_precision are not used yet
if (parameters.softcap != 0.0f) {
ORT_THROW("softcap is not supported yet in Attention op (CUDA).");
}
if (parameters.softmax_precision != 0) {
ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA).");
}
// TODO(titaiwang): Continue on these parameters
// Construct AttentionData to pass to QkvToContext
typedef typename ToCudaType<T>::MappedType CudaT;
onnxruntime::contrib::cuda::AttentionData<CudaT> data;

The kernel reuse call:

return onnxruntime::contrib::cuda::QkvToContext<CudaT, CudaT>(

The kernel call from MultiHeadAttention originally:

return QkvToContext<CudaT, CudaQK>(

The kernel:

Similarly, we can use the same approach to bridge GQA parameters and use its kernel to address

if (parameters.kv_num_heads != parameters.q_num_heads) {
ORT_THROW("Group query attention is not supported yet in Attention op (CUDA).");
}

GQA parameters:

struct GroupQueryAttentionParameters : AttentionParameters {

GQA calls the kernel:

ORT_RETURN_IF_ERROR(QkvToContext<CudaT>(

For now, we can disregard the flash attention and memory attention. Like in the case we only reuse UnfusedAttention in MultiHeadAttention

Attention spec: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Attention
MultiHeadAttention spec: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MultiHeadAttention
GroupQueryAttention spec: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.GroupQueryAttention</issue_description>

<agent_instructions>This is a good reference: #27030.
It is reusing MHA kernel, while we need now is to reuse GQA kernel.

The files of GQA:
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h

Besides bridging parameters, filling the correct information to data is also important.
Like what the pr did:

// Set input pointers
data.query = reinterpret_cast<const CudaT*>(Q->Data<T>());
data.key = reinterpret_cast<const CudaT*>(K->Data<T>());
data.value = reinterpret_cast<const CudaT*>(V->Data<T>());
data.mask_index = nullptr; // New Attention op doesn't have key_padding_mask
data.mask_index_dims = gsl::span<const int64_t>();
data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
// Set output pointers
data.output = reinterpret_cast<CudaT*>(Y->MutableData<T>());
data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = (present_value == nullptr) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
if (nullptr != output_qk) {
data.output_qk = reinterpret_cast<CudaT*>(output_qk->MutableData<T>());
}
// Set additional fields
data.bias = nullptr; // New Attention op doesn't have bias
if (nullptr != attn_mask) {
data.attention_bias = reinterpret_cast<const CudaT*>(attn_mask->Data<T>());
}
data.qkv_format = contribop_parameters.qkv_format;
// TODO: Determine which kernel to use (Flash Attention, Memory Efficient Attention, etc.)
// For now, set flags to false and let QkvToContext use the unfused path
data.use_flash_attention = false;
data.use_memory_efficient_attention = false;
data.fused_runner = nullptr;
data.fused_cross_attention_kernel = nullptr;
// Allocate workspace for Q, K, V processing and scratch buffer
const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data);
size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize(
sizeof(T),
contribop_parameters.batch_size,
contribop_parameters.num_heads,
contribop_parameters.head_size,
contribop_parameters.v_head_size,
contribop_parameters.sequence_length,
contribop_parameters.kv_sequence_length,
contribop_parameters.total_sequence_length,
nullptr, // fused_runner
false, // use_flash_attention
false, // use_lean_attention
false, // use_fused_cross_attention
false, // use_memory_efficient_attention
false, // use_cudnn_flash_attention
no_qkv_workspace);
auto work_space = GetScratchBuffer<void>(workspace_bytes, context->GetComputeStream());
data.has_qkv_workspace = !no_qkv_workspace;
data.workspace = reinterpret_cast<CudaT*>(work_space.get());
data.workspace_bytes = workspace_bytes;
// Call QkvToContext to perform the attention computation
auto& device_prop = GetDeviceProp();
cublasHandle_t cublas = GetCublasHandle(context);
cudnnHandle_t cudnn = GetCudnnHandle(context);

The reference in GQA:

data.position_ids = (position_ids != nullptr) ? position_ids->Data<int64_t>() : nullptr;
// Input pointers for both paths
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
// Handle Past/Present pointers
data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
data.present_key = reinterpret_cast<CudaT*>(context->Output<Tensor>(1)->MutableData<T>());
data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
data.present_value = reinterpret_cast<CudaT*>(context->Output<Tensor>(2)->MutableData<T>());
</agent_instructions>

Comments on the Issue (you are @copilot in this section)


💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.

Copilot AI and others added 4 commits January 20, 2026 21:56
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Copilot AI changed the title [WIP] Implement support for group query attention in Attention(23) Support group query attention in Attention(23) CUDA Jan 20, 2026
Copilot AI requested a review from titaiwangms January 20, 2026 22:05
@titaiwangms
Copy link
Contributor

@copilot Can you enable the following three tests

TEST(AttentionTest, Attention3DGqaAttn) {

TEST(AttentionTest, Attention4DGqaAttnMask) {

TEST(AttentionTest, Attention4DGqaWithPastAndPresent) {

Just turn disable_cuda to false

Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Copy link
Contributor Author

Copilot AI commented Jan 20, 2026

@copilot Can you enable the following three tests

TEST(AttentionTest, Attention3DGqaAttn) {

...

Enabled CUDA tests for all three GQA test cases in commit 2e10874:

  • Attention3DGqaAttn
  • Attention4DGqaAttnMask
  • Attention4DGqaWithPastAndPresent

@titaiwangms
Copy link
Contributor

@copilot
Also remove these disabled tests about GQA from the list:

"^test_attention_3d_gqa.*_cuda", // GQA not supported in Attention-cuda
"^test_attention_4d_gqa.*_cuda", // GQA not supported in Attention-cuda

Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Copy link
Contributor Author

Copilot AI commented Jan 20, 2026

@copilot
Also remove these disabled tests about GQA from the list:
https://github.com/microsoft/onnxruntime/blob/fd21d0aa071edef70d715123b25e77891a63a5d9/onnxruntime/test/testdata/onnx_backend_test...

Removed GQA test filters from onnx_backend_test_series_filters.jsonc in commit b86acbd. The following lines were removed:

  • "^test_attention_3d_gqa.*_cuda"
  • "^test_attention_4d_gqa.*_cuda"

GQA tests will now run on CUDA.

@titaiwangms
Copy link
Contributor

@copilot

Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082

The error message:

LINK : warning LNK4075: ignoring '/INCREMENTAL' due to '/PROFILE' specification [E:\_work\_temp\build\RelWithDebInfo\onnxruntime_providers_cuda.vcxproj]
     Creating library E:/_work/_temp/build/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_cuda.lib and object E:/_work/_temp/build/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_cuda.exp
attention.cc.obj : error LNK2019: unresolved external symbol "class onnxruntime::common::Status __cdecl onnxruntime::contrib::cuda::QkvToContext<float>(struct cudaDeviceProp const &,struct cublasContext * &,class onnxruntime::Stream *,struct onnxruntime::contrib::GroupQueryAttentionParameters &,struct onnxruntime::contrib::cuda::GroupQueryAttentionData<float> &)" (??$QkvToContext@M@cuda@contrib@onnxruntime@@YA?AVStatus@common@2@AEBUcudaDeviceProp@@AEAPEAUcublasContext@@PEAVStream@2@AEAUGroupQueryAttentionParameters@12@AEAU?$GroupQueryAttentionData@M@012@@Z) referenced in function "public: virtual class onnxruntime::common::Status __cdecl onnxruntime::cuda::Attention<float>::ComputeInternal(struct onnxruntime::OpKernelContext *)const " (?ComputeInternal@?$Attention@M@cuda@onnxruntime@@UEBA?AVStatus@common@3@PEAUOpKernelContext@3@@Z) [E:\_work\_temp\build\RelWithDebInfo\onnxruntime_providers_cuda.vcxproj]
E:\_work\_temp\build\RelWithDebInfo\RelWithDebInfo\onnxruntime_providers_cuda.dll : fatal error LNK1120: 1 unresolved externals [E:\_work\_temp\build\RelWithDebInfo\onnxruntime_providers_cuda.vcxproj]
Traceback (most recent call last):
  File "E:\_work\onnxruntime\onnxruntime\tools\ci_build\build.py", line 2684, in <module>

Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Copy link
Contributor Author

Copilot AI commented Jan 20, 2026

@copilot

Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082
...

Fixed CI linker error in commit 213a82d. The issue was that the GQA QkvToContext template function only had instantiations for half and BFloat16, but not for float. Added the missing float template instantiation along with GroupQueryAttentionData<float> and LaunchUnpackQKV<float, LAYOUT_BNSH> to group_query_attention_impl.cu.

@titaiwangms
Copy link
Contributor

@copilot
Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082
...

Fixed CI linker error in commit 213a82d. The issue was that the GQA QkvToContext template function only had instantiations for half and BFloat16, but not for float. Added the missing float template instantiation along with GroupQueryAttentionData<float> and LaunchUnpackQKV<float, LAYOUT_BNSH> to group_query_attention_impl.cu.

@copilot
Do not add float type. Just throw an error saying float is not supported with GQA. We are following GQA kernel.

REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Copy link
Contributor Author

Copilot AI commented Jan 20, 2026

@copilot
Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082
...
...

Reverted float support for GQA in commit f79c509. Following GroupQueryAttention pattern:

  • Removed float template instantiations from group_query_attention_impl.cu
  • Added type validation in attention.cc to throw an error for float32 types when using GQA path
  • GQA now only supports float16 and bfloat16 types, matching the GroupQueryAttention operator behavior

@titaiwangms titaiwangms added the ep:CUDA issues related to the CUDA execution provider label Jan 21, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ep:CUDA issues related to the CUDA execution provider

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support group query attention in Attention(23) CUDA

2 participants