362 changes: 341 additions & 21 deletions onnxruntime/core/providers/cuda/llm/attention.cc
@@ -1,11 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <vector>
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cpu/llm/attention_helper.h"
#include "core/providers/cuda/llm/attention.h"
#include "contrib_ops/cuda/bert/attention_data.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"

using namespace onnxruntime::cuda;

@@ -96,8 +100,344 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
Tensor* output_qk = context->Output(3, output_qk_shape);

// To reuse the existing attention-cuda implementation in contrib ops,
// map the parameters to contribop_parameters.
// map the parameters to contribop_parameters (MHA).
onnxruntime::contrib::AttentionParameters contribop_parameters;

// QKV format: Determine based on input dimensions
// 3D inputs (B, S, D): Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH
// transpose_output is true for 3D inputs, false for 4D inputs
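// For instance, a 4D query shaped (batch, num_heads, seq, head_size) is already BNSH and is used as-is,
// while a 3D query shaped (batch, seq, num_heads * head_size) starts as BSNH and is transposed downstream.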
if (!parameters.transpose_output) {
contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
contribop_parameters.is_output_bnsh = true;
} else {
// 3D inputs in BSNH format (will be transposed)
contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH;
contribop_parameters.is_output_bnsh = false;
}

typedef typename ToCudaType<T>::MappedType CudaT;

// Check if this is Group Query Attention (GQA)
const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads;

if (is_gqa) {
// Use GQA path with Flash Attention or Memory Efficient Attention
// GQA only supports float16 and bfloat16 types
if (std::is_same<T, float>::value) {
ORT_THROW("GQA in Attention op (CUDA) does not support float32. Please use float16 or bfloat16.");
}
// For now, GQA doesn't support qk_matmul_output_mode other than kNone
if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone) {
ORT_THROW("qk_matmul_output_mode is not supported yet in GQA path of Attention op (CUDA).");
}
// GQA doesn't support softmax_precision yet
if (parameters.softmax_precision != 0) {
ORT_THROW("softmax_precision is not supported yet in GQA path of Attention op (CUDA).");
}
// causal attention is required for GQA
if (!parameters.is_causal) {
ORT_THROW("Non-causal attention is not supported yet in GQA path of Attention op (CUDA).");
}
// GQA kernel expects K/V input sequence length == Q sequence length (self-attention only)
// Cross-attention (kv_sequence_length != q_sequence_length) is not supported
if (parameters.kv_sequence_length != parameters.q_sequence_length) {
ORT_THROW(
"Cross-attention (kv_sequence_length != q_sequence_length) is not supported in GQA path of Attention op (CUDA). "
"kv_sequence_length=",
parameters.kv_sequence_length, ", q_sequence_length=", parameters.q_sequence_length);
}

auto& device_prop = GetDeviceProp();

// Bridge parameters to GroupQueryAttentionParameters
onnxruntime::contrib::GroupQueryAttentionParameters gqa_parameters;
gqa_parameters.batch_size = parameters.batch_size;
gqa_parameters.sequence_length = parameters.q_sequence_length;
gqa_parameters.seqlen_past_kv_cache = parameters.past_sequence_length;
gqa_parameters.seqlen_present_kv_cache = parameters.total_sequence_length;
gqa_parameters.total_sequence_length = parameters.total_sequence_length;
gqa_parameters.kv_sequence_length = parameters.kv_sequence_length;
gqa_parameters.hidden_size = parameters.q_num_heads * parameters.head_size;
gqa_parameters.num_heads = parameters.q_num_heads;
gqa_parameters.head_size = parameters.head_size;
gqa_parameters.v_head_size = parameters.v_head_size;
gqa_parameters.kv_hidden_size = parameters.kv_num_heads * parameters.v_head_size;
gqa_parameters.kv_num_heads = parameters.kv_num_heads;
gqa_parameters.scale = parameters.scale;
gqa_parameters.softcap = parameters.softcap;
gqa_parameters.qkv_format = contribop_parameters.qkv_format;

// Unset or set to default values for GQA-specific fields
gqa_parameters.rotary_dim = 0; // New Attention op doesn't use rotary embeddings directly
gqa_parameters.is_unidirectional = true; // GQA requires causal attention
gqa_parameters.is_packed_qkv = false; // New Attention op has separate Q, K, V inputs
gqa_parameters.is_subsequent_prompt = false;
gqa_parameters.is_first_prompt = parameters.past_sequence_length == 0;
gqa_parameters.do_rotary = false; // New Attention op doesn't use rotary embeddings
gqa_parameters.rotary_interleaved = false;
gqa_parameters.use_smooth_softmax = false;
gqa_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE;
gqa_parameters.past_kv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
gqa_parameters.local_window_size = -1; // No local window for standard attention
gqa_parameters.zeros_count = 0;
gqa_parameters.zero_ptr = nullptr;
gqa_parameters.num_splits = 1;

// Construct GroupQueryAttentionData
onnxruntime::contrib::cuda::GroupQueryAttentionData<CudaT> gqa_data;

// Scratch buffers for flash/memory efficient attention
IAllocatorUniquePtr<void> k_buffer;
IAllocatorUniquePtr<void> v_buffer;
IAllocatorUniquePtr<void> fmha_buffer;
IAllocatorUniquePtr<void> unpacked_qkv_buffer;
IAllocatorUniquePtr<int> seq_lens_buffer;
IAllocatorUniquePtr<int> seqlens_k_buffer;

// Present KV cache buffers - GQA kernel uses these as working buffers
// If outputs are not provided, we allocate scratch buffers
IAllocatorUniquePtr<void> present_key_scratch;
IAllocatorUniquePtr<void> present_value_scratch;

// Set input pointers
gqa_data.query = reinterpret_cast<const CudaT*>(Q->Data<T>());
gqa_data.key = reinterpret_cast<const CudaT*>(K->Data<T>());
gqa_data.value = reinterpret_cast<const CudaT*>(V->Data<T>());
gqa_data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
gqa_data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());

// Set output pointers
gqa_data.output = reinterpret_cast<CudaT*>(Y->MutableData<T>());

// GQA kernel requires present_key/present_value buffers as working storage for KV cache
// Allocate scratch buffers if outputs are not provided
size_t present_kv_size = static_cast<size_t>(parameters.batch_size) *
static_cast<size_t>(parameters.kv_num_heads) *
static_cast<size_t>(parameters.total_sequence_length) *
static_cast<size_t>(parameters.head_size) * sizeof(CudaT);
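// Illustration with hypothetical shapes: batch_size=2, kv_num_heads=4, total_sequence_length=128,
// head_size=64 and fp16 give 2 * 4 * 128 * 64 * 2 bytes = 131072 bytes per present K/V buffer.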
if (present_key != nullptr) {
gqa_data.present_key = reinterpret_cast<CudaT*>(present_key->MutableData<T>());
} else {
present_key_scratch = GetScratchBuffer<void>(present_kv_size, context->GetComputeStream());
gqa_data.present_key = reinterpret_cast<CudaT*>(present_key_scratch.get());
}
if (present_value != nullptr) {
gqa_data.present_value = reinterpret_cast<CudaT*>(present_value->MutableData<T>());
} else {
present_value_scratch = GetScratchBuffer<void>(present_kv_size, context->GetComputeStream());
gqa_data.present_value = reinterpret_cast<CudaT*>(present_value_scratch.get());
}

// Compute past_present_share_buffer early since it's needed for flash attention path selection
gqa_parameters.past_present_share_buffer = (gqa_data.past_key == gqa_data.present_key);

// Flash Attention buffers
IAllocatorUniquePtr<void> softmax_lse_buffer;
IAllocatorUniquePtr<void> softmax_lse_accum_buffer;
IAllocatorUniquePtr<void> out_accum_buffer;

// Check Flash Attention support
#if USE_FLASH_ATTENTION
bool use_flash_attention = onnxruntime::flash::is_supported<T>(device_prop,
gqa_parameters.head_size,
gqa_parameters.num_heads,
gqa_parameters.kv_num_heads);

gqa_data.use_flash_attention = use_flash_attention;
gqa_data.use_flash_attention_fast_decode = use_flash_attention &&
!gqa_parameters.is_first_prompt &&
gqa_parameters.past_present_share_buffer;

if (use_flash_attention) {
// Allocate Flash specific buffers (Softmax LSE, Accum)
size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(
gqa_parameters.sequence_length, gqa_parameters.batch_size, gqa_parameters.num_heads);

int num_heads_for_split = gqa_data.use_flash_attention_fast_decode
? gqa_parameters.kv_num_heads
: gqa_parameters.num_heads;
auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] =
onnxruntime::flash::get_num_splits_and_buffer_sizes(
gqa_parameters.batch_size, gqa_parameters.sequence_length,
gqa_parameters.total_sequence_length, num_heads_for_split,
gqa_parameters.head_size, device_prop.multiProcessorCount);

gqa_parameters.num_splits = static_cast<int>(num_splits);

if (gqa_data.use_flash_attention_fast_decode && num_splits > 1) {
// The heuristic used kv_num_heads to maximize occupancy for the GQA-aware kernel.
// However, the LSE and Accum buffers must store results for ALL num_heads.
softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(
num_splits, gqa_parameters.batch_size, gqa_parameters.num_heads, gqa_parameters.sequence_length);
auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; };
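// e.g., round_multiple(80, 32) == 96: head_size is rounded up to the next multiple of 32 for the out_accum buffer.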
out_accum_bytes = onnxruntime::flash::get_out_accum_size(
num_splits, gqa_parameters.batch_size, gqa_parameters.num_heads, gqa_parameters.sequence_length,
round_multiple(gqa_parameters.head_size, 32));
}

softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());

gqa_data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
gqa_data.softmax_lse_accum = reinterpret_cast<CudaT*>(softmax_lse_accum_buffer.get());
gqa_data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
} else {
gqa_data.softmax_lse = nullptr;
gqa_data.softmax_lse_accum = nullptr;
gqa_data.out_accum = nullptr;
}
#else
gqa_data.use_flash_attention = false;
gqa_data.use_flash_attention_fast_decode = false;
gqa_data.softmax_lse = nullptr;
gqa_data.softmax_lse_accum = nullptr;
gqa_data.out_accum = nullptr;
#endif

// Check Memory Efficient Attention support (fallback if flash attention not available)
#if USE_MEMORY_EFFICIENT_ATTENTION
if (!gqa_data.use_flash_attention) {
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
onnxruntime::contrib::cuda::has_memory_efficient_attention(
sm, std::is_same<T, MLFloat16>::value, std::is_same<T, BFloat16>::value,
gqa_parameters.head_size, gqa_parameters.head_size);
gqa_data.use_memory_efficient_attention = use_memory_efficient_attention;

// KV buffer for head expansion (when num_heads != kv_num_heads)
size_t kv_buffer_bytes = (use_memory_efficient_attention &&
(gqa_parameters.num_heads != gqa_parameters.kv_num_heads))
? (sizeof(T) * gqa_parameters.batch_size * gqa_parameters.num_heads *
gqa_parameters.seqlen_present_kv_cache * gqa_parameters.head_size)
: 0;
// FMHA workspace
size_t fmha_buffer_bytes =
(use_memory_efficient_attention &&
onnxruntime::contrib::cuda::MemoryEfficientAttentionParams::need_workspace(
gqa_parameters.head_size, sizeof(T) == sizeof(float)))
? (sizeof(float) * gqa_parameters.batch_size * gqa_parameters.sequence_length *
gqa_parameters.num_heads * gqa_parameters.head_size)
: 0;

k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());

gqa_data.k = reinterpret_cast<CudaT*>(k_buffer.get());
gqa_data.v = reinterpret_cast<CudaT*>(v_buffer.get());
gqa_data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
} else {
gqa_data.use_memory_efficient_attention = false;
gqa_data.k = nullptr;
gqa_data.v = nullptr;
gqa_data.fmha_buffer = nullptr;
}
#else
gqa_data.use_memory_efficient_attention = false;
gqa_data.k = nullptr;
gqa_data.v = nullptr;
gqa_data.fmha_buffer = nullptr;
#endif

// Centralized scratch buffer allocation using GQABufferRequirements
auto buffer_req = onnxruntime::contrib::cuda::GQABufferRequirements::Compute<T>(
gqa_parameters,
gqa_data.use_flash_attention,
gqa_data.use_flash_attention_fast_decode,
gqa_data.use_memory_efficient_attention);

if (buffer_req.qkv_buffer_bytes > 0) {
unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.qkv_buffer_bytes, context->GetComputeStream());
gqa_data.qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
} else {
gqa_data.qkv_buffer = nullptr;
}

// Allocate a CPU buffer for seqlens_k (per-batch sequence length - 1) for GQA compatibility
// The GQA kernel expects sequence length information for flash/memory efficient attention
// We need a CPU buffer first, then copy to GPU
std::vector<int> seqlens_k_host(parameters.batch_size);

// GQA only supports masking, not additive bias.
// For bool mask, we need to convert it to sequence lengths.
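// Illustrative example: a mask row [true, true, true, false, false] with total_sequence_length = 5
// has its last valid position at index 2, so seq_len = 3 and seqlens_k_host[b] = 2.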
if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
const bool* b_mask = attn_mask->Data<bool>();

for (int b = 0; b < parameters.batch_size; ++b) {
const bool* row = b_mask + b * parameters.total_sequence_length;
int seq_len = 0;

// Find the actual sequence length by looking for the last valid (true) position
// Mask convention per Attention spec: true = valid (should participate), false = masked out
for (int i = parameters.total_sequence_length - 1; i >= 0; --i) {
if (row[i]) {
seq_len = i + 1;
break;
}
}
// seqlens_k stores the per-batch valid sequence length minus 1 for historical reasons (matching GroupQueryAttention convention)
seqlens_k_host[b] = seq_len - 1;
}
} else if (attn_mask != nullptr) {
ORT_THROW("Non-boolean attn_mask is not supported yet in GQA path of Attention op (CUDA).");
} else {
// No mask provided - use full sequence length for all batches
// seqlens_k is total_sequence_length - 1 for historical reasons (matching GroupQueryAttention convention)
for (int b = 0; b < parameters.batch_size; ++b) {
seqlens_k_host[b] = parameters.total_sequence_length - 1;
}
}

// Copy seqlens_k to GPU
seqlens_k_buffer = GetScratchBuffer<int>(parameters.batch_size, context->GetComputeStream());
auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(seqlens_k_buffer.get(), seqlens_k_host.data(),
sizeof(int) * parameters.batch_size,
cudaMemcpyHostToDevice, cuda_stream));

// Process seqlens_k to compute past_seq_lens, total_seq_lens, and padded_seq_lens
// This is always needed for flash/memory efficient attention
seq_lens_buffer = GetScratchBuffer<int>(3 * parameters.batch_size, context->GetComputeStream());
gqa_data.past_seq_lens = seq_lens_buffer.get();
gqa_data.total_seq_lens = seq_lens_buffer.get() + parameters.batch_size;
gqa_data.padded_seq_lens = gqa_data.total_seq_lens + parameters.batch_size;
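// Resulting layout of seq_lens_buffer (3 * batch_size ints):
// [ past_seq_lens (batch_size) | total_seq_lens (batch_size) | padded_seq_lens (batch_size) ]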

ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchGetSequenceLengths(
seqlens_k_buffer.get(),
gqa_data.past_seq_lens,
gqa_data.total_seq_lens,
gqa_data.padded_seq_lens,
parameters.batch_size,
parameters.q_sequence_length,
gqa_parameters.is_first_prompt,
cuda_stream,
device_prop.maxThreadsPerBlock));

// Set GQA-specific fields
gqa_data.cos_cache = nullptr; // No rotary embeddings
gqa_data.sin_cache = nullptr;
gqa_data.head_sink = nullptr;
gqa_data.position_ids = nullptr;

#ifndef NDEBUG
// Initialize debug tracking fields
gqa_data.unpacked_qkv_buffer_size = 0;
gqa_data.rotary_buffer_size = 0;
gqa_data.position_ids_buffer_size = 0;
gqa_data.unpacked_qkv_max_used = 0;
gqa_data.rotary_max_used = 0;
gqa_data.position_ids_max_used = 0;
#endif

// Call GQA kernel (with flash or memory efficient attention)
cublasHandle_t cublas = GetCublasHandle(context);

return onnxruntime::contrib::cuda::QkvToContext<CudaT>(
device_prop, cublas, context->GetComputeStream(), gqa_parameters, gqa_data);
}

// MHA path (kv_num_heads == q_num_heads)
contribop_parameters.batch_size = parameters.batch_size;
contribop_parameters.sequence_length = parameters.q_sequence_length;
contribop_parameters.kv_sequence_length = parameters.kv_sequence_length;
@@ -160,24 +500,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
contribop_parameters.mask_filter_value = -10000.0f;
contribop_parameters.scale = parameters.scale;
contribop_parameters.use_tf32 = UseTF32();

// QKV format: Determine based on input dimensions
// 3D inputs (B, S, D): Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH
// transpose_output is true for 3D inputs, false for 4D inputs
if (!parameters.transpose_output) {
contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
contribop_parameters.is_output_bnsh = true;
} else {
// 3D inputs in BSNH format (will be transposed)
contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH;
contribop_parameters.is_output_bnsh = false;
}

// TODO(titaiwang, xadupre): Group query attention is not supported yet
if (parameters.kv_num_heads != parameters.q_num_heads) {
ORT_THROW("Group query attention is not supported yet in Attention op (CUDA).");
}

// TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now
if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone &&
qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) {
@@ -191,9 +513,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA).");
}

// TODO(titaiwang): Continue on these parameters
// Construct AttentionData to pass to QkvToContext
typedef typename ToCudaType<T>::MappedType CudaT;
onnxruntime::contrib::cuda::AttentionData<CudaT> data;

// Set input pointers