Support group query attention in Attention(23) CUDA #27082
base: main
Conversation
@copilot Can you enable the following three tests: Attention3DGqaAttn, Attention4DGqaAttnMask, and Attention4DGqaWithPastAndPresent?
Just turn
@copilot See onnxruntime/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc, lines 44 to 45 in fd21d0a.
Removed GQA test filters from onnx_backend_test_series_filters.jsonc. GQA tests will now run on CUDA.
Please fix the CI: https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082 The error message:
Fixed CI linker error in commit 213a82d. The issue was that the GQA
@copilot
Reverted float support for GQA in commit f79c509, following the GroupQueryAttention pattern:
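A minimal sketch of what that pattern amounts to; the type names and the `RunGqaPath` helper are illustrative stand-ins, not the actual onnxruntime code:

```cpp
#include <type_traits>

// Stand-ins for the CUDA half-precision types (__half / __nv_bfloat16).
struct half16 {};
struct bfloat16 {};

template <typename T>
inline constexpr bool kGqaSupportedType =
    std::is_same_v<T, half16> || std::is_same_v<T, bfloat16>;

template <typename T>
void RunGqaPath() {
  // Rejecting float at compile time keeps the GQA path fp16/bf16-only,
  // which is the same restriction GroupQueryAttention itself enforces.
  static_assert(kGqaSupportedType<T>,
                "GQA path supports only float16/bfloat16");
  // ... launch GQA kernels ...
}
```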
Description
Enables GQA in the Attention(23) CUDA operator by reusing the existing GroupQueryAttention kernel. Follows the same pattern as PR #27030, which bridged MHA parameters.
Implementation (sketched below):
- Detect `kv_num_heads != q_num_heads` and route to the GQA kernel path
- Bridge `AttentionParameters` to `GroupQueryAttentionParameters`
- Populate `GroupQueryAttentionData` with input/output pointers and configuration
- Call `QkvToContext<T>(device_prop, cublas, stream, parameters, data)`
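A minimal, self-contained sketch of the flow above. All struct layouts, field names, and the `BridgeAndRunGqa` helper are illustrative stand-ins; only `QkvToContext<T>(device_prop, cublas, stream, parameters, data)` is the real entry point named in this PR.

```cpp
// Stand-in for the Attention(23) operator's parsed parameters.
struct AttentionParameters {
  int batch_size, sequence_length, num_heads, kv_num_heads, head_size;
};

// Stand-in for the GQA kernel's parameter struct.
struct GroupQueryAttentionParameters {
  int batch_size, sequence_length, num_heads, kv_num_heads, head_size;
};

// Stand-in for the GQA kernel's I/O pointer bundle.
template <typename T>
struct GroupQueryAttentionData {
  const T* query;
  const T* key;
  const T* value;
  T* output;
};

template <typename T>
bool BridgeAndRunGqa(const AttentionParameters& attn,
                     const T* q, const T* k, const T* v, T* out) {
  // 1) Route: fall back to the existing MHA path when head counts match.
  if (attn.kv_num_heads == attn.num_heads) return false;

  // 2) Bridge AttentionParameters -> GroupQueryAttentionParameters.
  GroupQueryAttentionParameters gqa{};
  gqa.batch_size      = attn.batch_size;
  gqa.sequence_length = attn.sequence_length;
  gqa.num_heads       = attn.num_heads;     // query heads
  gqa.kv_num_heads    = attn.kv_num_heads;  // shared key/value heads
  gqa.head_size       = attn.head_size;

  // 3) Populate GroupQueryAttentionData with input/output pointers.
  GroupQueryAttentionData<T> data{q, k, v, out};

  // 4) The real code would now hand off to the existing GQA kernel:
  //    QkvToContext<T>(device_prop, cublas, stream, gqa, data);
  (void)gqa; (void)data;
  return true;
}
```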
Current scope:
- float16 and bfloat16 only, matching GroupQueryAttention (float32 support was reverted in f79c509)

Testing:
- Enabled `Attention3DGqaAttn`, `Attention4DGqaAttnMask`, and `Attention4DGqaWithPastAndPresent`
- Updated `onnx_backend_test_series_filters.jsonc` to allow ONNX backend GQA tests to run on CUDA

Backward compatibility:
- Behavior when `kv_num_heads == q_num_heads` is unchanged

Motivation and Context
Resolves the TODO at lines 177-179 that blocked GQA usage. The GQA kernel already exists and is well tested in the GroupQueryAttention operator; this change makes it accessible via the standard Attention(23) operator for models using grouped key-value heads with float16 or bfloat16 precision.
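For context on what "grouped key-value heads" means here, a small self-contained sketch of the head mapping (illustrative only, not onnxruntime code): each KV head is shared by a contiguous group of `num_heads / kv_num_heads` query heads, which is what makes `kv_num_heads != q_num_heads` a valid configuration.

```cpp
#include <cassert>
#include <cstdio>

// Map a query head index to the KV head it attends with under GQA.
int KvHeadForQueryHead(int q_head, int num_heads, int kv_num_heads) {
  assert(num_heads % kv_num_heads == 0);  // GQA requires an even split
  const int group_size = num_heads / kv_num_heads;
  return q_head / group_size;
}

int main() {
  // Example: 8 query heads sharing 2 KV heads -> groups of 4.
  for (int q = 0; q < 8; ++q)
    std::printf("q_head %d -> kv_head %d\n", q, KvHeadForQueryHead(q, 8, 2));
}
```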