diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 99f297bba6444..3fc64f07690dd 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -1,11 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "core/providers/cuda/cuda_common.h" #include "core/providers/cpu/llm/attention_helper.h" #include "core/providers/cuda/llm/attention.h" #include "contrib_ops/cuda/bert/attention_data.h" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; @@ -96,8 +100,344 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { Tensor* output_qk = context->Output(3, output_qk_shape); // To reuse the existing attention-cuda implementation in contrib ops, - // map the parameters to contribop_parameters. + // map the parameters to contribop_parameters (MHA). onnxruntime::contrib::AttentionParameters contribop_parameters; + + // QKV format: Determine based on input dimensions + // 3D inputs (B, S, D): Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH + // transpose_output is true for 3D inputs, false for 4D inputs + if (!parameters.transpose_output) { + contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH; + contribop_parameters.is_output_bnsh = true; + } else { + // 3D inputs in BSNH format (will be transposed) + contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH; + contribop_parameters.is_output_bnsh = false; + } + + typedef typename ToCudaType<T>::MappedType CudaT; + + // Check if this is Group Query Attention (GQA) + const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads; + + if (is_gqa) { + // Use GQA path with Flash Attention or Memory Efficient Attention + // GQA only supports float16 and bfloat16 types + if (std::is_same<T, float>::value) { + ORT_THROW("GQA in Attention op (CUDA) does not support float32. Please use float16 or bfloat16."); + } + // For now, GQA doesn't support qk_matmul_output_mode other than kNone + if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone) { + ORT_THROW("qk_matmul_output_mode is not supported yet in GQA path of Attention op (CUDA)."); + } + // GQA doesn't support softmax_precision yet + if (parameters.softmax_precision != 0) { + ORT_THROW("softmax_precision is not supported yet in GQA path of Attention op (CUDA)."); + } + // causal attention is required for GQA + if (!parameters.is_causal) { + ORT_THROW("Non-causal attention is not supported yet in GQA path of Attention op (CUDA)."); + } + // GQA kernel expects K/V input sequence length == Q sequence length (self-attention only) + // Cross-attention (kv_sequence_length != q_sequence_length) is not supported + if (parameters.kv_sequence_length != parameters.q_sequence_length) { + ORT_THROW( + "Cross-attention (kv_sequence_length != q_sequence_length) is not supported in GQA path of Attention op (CUDA). 
" "kv_sequence_length=", parameters.kv_sequence_length, ", q_sequence_length=", parameters.q_sequence_length); + } + + auto& device_prop = GetDeviceProp(); + + // Bridge parameters to GroupQueryAttentionParameters + onnxruntime::contrib::GroupQueryAttentionParameters gqa_parameters; + gqa_parameters.batch_size = parameters.batch_size; + gqa_parameters.sequence_length = parameters.q_sequence_length; + gqa_parameters.seqlen_past_kv_cache = parameters.past_sequence_length; + gqa_parameters.seqlen_present_kv_cache = parameters.total_sequence_length; + gqa_parameters.total_sequence_length = parameters.total_sequence_length; + gqa_parameters.kv_sequence_length = parameters.kv_sequence_length; + gqa_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; + gqa_parameters.num_heads = parameters.q_num_heads; + gqa_parameters.head_size = parameters.head_size; + gqa_parameters.v_head_size = parameters.v_head_size; + gqa_parameters.kv_hidden_size = parameters.kv_num_heads * parameters.v_head_size; + gqa_parameters.kv_num_heads = parameters.kv_num_heads; + gqa_parameters.scale = parameters.scale; + gqa_parameters.softcap = parameters.softcap; + gqa_parameters.qkv_format = contribop_parameters.qkv_format; + + // Unset or set to default values for GQA-specific fields + gqa_parameters.rotary_dim = 0; // New Attention op doesn't use rotary embeddings directly + gqa_parameters.is_unidirectional = true; // GQA requires causal attention + gqa_parameters.is_packed_qkv = false; // New Attention op has separate Q, K, V inputs + gqa_parameters.is_subsequent_prompt = false; + gqa_parameters.is_first_prompt = parameters.past_sequence_length == 0; + gqa_parameters.do_rotary = false; // New Attention op doesn't use rotary embeddings + gqa_parameters.rotary_interleaved = false; + gqa_parameters.use_smooth_softmax = false; + gqa_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; + gqa_parameters.past_kv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH; + gqa_parameters.local_window_size = -1; // No local window for standard attention + gqa_parameters.zeros_count = 0; + gqa_parameters.zero_ptr = nullptr; + gqa_parameters.num_splits = 1; + + // Construct GroupQueryAttentionData + onnxruntime::contrib::cuda::GroupQueryAttentionData<CudaT> gqa_data; + + // Scratch buffers for flash/memory efficient attention + IAllocatorUniquePtr<void> k_buffer; + IAllocatorUniquePtr<void> v_buffer; + IAllocatorUniquePtr<void> fmha_buffer; + IAllocatorUniquePtr<void> unpacked_qkv_buffer; + IAllocatorUniquePtr<int> seq_lens_buffer; + IAllocatorUniquePtr<int> seqlens_k_buffer; + + // Present KV cache buffers - GQA kernel uses these as working buffers + // If outputs are not provided, we allocate scratch buffers + IAllocatorUniquePtr<void> present_key_scratch; + IAllocatorUniquePtr<void> present_value_scratch; + + // Set input pointers + gqa_data.query = reinterpret_cast<const CudaT*>(Q->Data<T>()); + gqa_data.key = reinterpret_cast<const CudaT*>(K->Data<T>()); + gqa_data.value = reinterpret_cast<const CudaT*>(V->Data<T>()); + gqa_data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>()); + gqa_data.past_value = (past_value == nullptr) ? 
nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>()); + + // Set output pointers + gqa_data.output = reinterpret_cast<CudaT*>(Y->MutableData<T>()); + + // GQA kernel requires present_key/present_value buffers as working storage for KV cache + // Allocate scratch buffers if outputs are not provided + size_t present_kv_size = static_cast<size_t>(parameters.batch_size) * + static_cast<size_t>(parameters.kv_num_heads) * + static_cast<size_t>(parameters.total_sequence_length) * + static_cast<size_t>(parameters.head_size) * sizeof(CudaT); + if (present_key != nullptr) { + gqa_data.present_key = reinterpret_cast<CudaT*>(present_key->MutableData<T>()); + } else { + present_key_scratch = GetScratchBuffer<void>(present_kv_size, context->GetComputeStream()); + gqa_data.present_key = reinterpret_cast<CudaT*>(present_key_scratch.get()); + } + if (present_value != nullptr) { + gqa_data.present_value = reinterpret_cast<CudaT*>(present_value->MutableData<T>()); + } else { + present_value_scratch = GetScratchBuffer<void>(present_kv_size, context->GetComputeStream()); + gqa_data.present_value = reinterpret_cast<CudaT*>(present_value_scratch.get()); + } + + // Compute past_present_share_buffer early since it's needed for flash attention path selection + gqa_parameters.past_present_share_buffer = (gqa_data.past_key == gqa_data.present_key); + + // Flash Attention buffers + IAllocatorUniquePtr<void> softmax_lse_buffer; + IAllocatorUniquePtr<void> softmax_lse_accum_buffer; + IAllocatorUniquePtr<void> out_accum_buffer; + + // Check Flash Attention support +#if USE_FLASH_ATTENTION + bool use_flash_attention = onnxruntime::flash::is_supported(device_prop, + gqa_parameters.head_size, + gqa_parameters.num_heads, + gqa_parameters.kv_num_heads); + + gqa_data.use_flash_attention = use_flash_attention; + gqa_data.use_flash_attention_fast_decode = use_flash_attention && + !gqa_parameters.is_first_prompt && + gqa_parameters.past_present_share_buffer; + + if (use_flash_attention) { + // Allocate Flash specific buffers (Softmax LSE, Accum) + size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size( + gqa_parameters.sequence_length, gqa_parameters.batch_size, gqa_parameters.num_heads); + + int num_heads_for_split = gqa_data.use_flash_attention_fast_decode + ? gqa_parameters.kv_num_heads + : gqa_parameters.num_heads; + auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] = + onnxruntime::flash::get_num_splits_and_buffer_sizes( + gqa_parameters.batch_size, gqa_parameters.sequence_length, + gqa_parameters.total_sequence_length, num_heads_for_split, + gqa_parameters.head_size, device_prop.multiProcessorCount); + + gqa_parameters.num_splits = static_cast<int>(num_splits); + + if (gqa_data.use_flash_attention_fast_decode && num_splits > 1) { + // The heuristic used kv_num_heads to maximize occupancy for the GQA-aware kernel. + // However, the LSE and Accum buffers must store results for ALL num_heads.
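+ // Illustrative sizing note (hypothetical numbers, not taken from this change): with batch_size = 1, sequence_length = 1, num_heads = 32 and kv_num_heads = 8, the split heuristic above only saw 8 heads, yet softmax_lse_accum/out_accum must cover num_splits * batch_size * num_heads * sequence_length entries (out_accum additionally scaled by head_size rounded up to a multiple of 32), hence the recomputation below with num_heads.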
+ softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size( + num_splits, gqa_parameters.batch_size, gqa_parameters.num_heads, gqa_parameters.sequence_length); + auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; }; + out_accum_bytes = onnxruntime::flash::get_out_accum_size( + num_splits, gqa_parameters.batch_size, gqa_parameters.num_heads, gqa_parameters.sequence_length, + round_multiple(gqa_parameters.head_size, 32)); + } + + softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream()); + softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream()); + out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream()); + + gqa_data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get()); + gqa_data.softmax_lse_accum = reinterpret_cast<CudaT*>(softmax_lse_accum_buffer.get()); + gqa_data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get()); + } else { + gqa_data.softmax_lse = nullptr; + gqa_data.softmax_lse_accum = nullptr; + gqa_data.out_accum = nullptr; + } +#else + gqa_data.use_flash_attention = false; + gqa_data.use_flash_attention_fast_decode = false; + gqa_data.softmax_lse = nullptr; + gqa_data.softmax_lse_accum = nullptr; + gqa_data.out_accum = nullptr; +#endif + + // Check Memory Efficient Attention support (fallback if flash attention not available) +#if USE_MEMORY_EFFICIENT_ATTENTION + if (!gqa_data.use_flash_attention) { + int sm = (device_prop.major * 10) + device_prop.minor; + bool use_memory_efficient_attention = + onnxruntime::contrib::cuda::has_memory_efficient_attention( + sm, std::is_same<T, MLFloat16>::value, std::is_same<T, BFloat16>::value, + gqa_parameters.head_size, gqa_parameters.head_size); + gqa_data.use_memory_efficient_attention = use_memory_efficient_attention; + + // KV buffer for head expansion (when num_heads != kv_num_heads) + size_t kv_buffer_bytes = (use_memory_efficient_attention && + (gqa_parameters.num_heads != gqa_parameters.kv_num_heads)) + ? (sizeof(T) * gqa_parameters.batch_size * gqa_parameters.num_heads * + gqa_parameters.seqlen_present_kv_cache * gqa_parameters.head_size) + : 0; + // FMHA workspace + size_t fmha_buffer_bytes = + (use_memory_efficient_attention && + onnxruntime::contrib::cuda::MemoryEfficientAttentionParams::need_workspace( + gqa_parameters.head_size, sizeof(T) == sizeof(float))) + ? 
(sizeof(float) * gqa_parameters.batch_size * gqa_parameters.sequence_length * + gqa_parameters.num_heads * gqa_parameters.head_size) + : 0; + + k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream()); + v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream()); + fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream()); + + gqa_data.k = reinterpret_cast<CudaT*>(k_buffer.get()); + gqa_data.v = reinterpret_cast<CudaT*>(v_buffer.get()); + gqa_data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get()); + } else { + gqa_data.use_memory_efficient_attention = false; + gqa_data.k = nullptr; + gqa_data.v = nullptr; + gqa_data.fmha_buffer = nullptr; + } +#else + gqa_data.use_memory_efficient_attention = false; + gqa_data.k = nullptr; + gqa_data.v = nullptr; + gqa_data.fmha_buffer = nullptr; +#endif + + // Centralized scratch buffer allocation using GQABufferRequirements + auto buffer_req = onnxruntime::contrib::cuda::GQABufferRequirements::Compute( + gqa_parameters, + gqa_data.use_flash_attention, + gqa_data.use_flash_attention_fast_decode, + gqa_data.use_memory_efficient_attention); + + if (buffer_req.qkv_buffer_bytes > 0) { + unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.qkv_buffer_bytes, context->GetComputeStream()); + gqa_data.qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get()); + } else { + gqa_data.qkv_buffer = nullptr; + } + + // Allocate CPU buffer for seqlens_k (total_sequence_length - 1) for GQA compatibility + // The GQA kernel expects sequence length information for flash/memory efficient attention + // We need a CPU buffer first, then copy to GPU + std::vector<int> seqlens_k_host(parameters.batch_size); + + // GQA only supports masking, not additive bias. + // For bool mask, we need to convert it to sequence lengths.
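+ // Worked example (illustrative): for a mask row [true, true, true, false, false], the last valid position is index 2, so seq_len = 3 and seqlens_k_host[b] = 2 (seq_len - 1, the GroupQueryAttention convention used below).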
+ if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) { + const bool* b_mask = attn_mask->Data<bool>(); + + for (int b = 0; b < parameters.batch_size; ++b) { + const bool* row = b_mask + b * parameters.total_sequence_length; + int seq_len = 0; + + // Find the actual sequence length by looking for the last valid (true) position + // Mask convention per Attention spec: true = valid (should participate), false = masked out + for (int i = parameters.total_sequence_length - 1; i >= 0; --i) { + if (row[i]) { + seq_len = i + 1; + break; + } + } + // seqlens_k is total_sequence_length - 1 for historical reasons (matching GroupQueryAttention convention) + seqlens_k_host[b] = seq_len - 1; + } + } else if (attn_mask != nullptr) { + ORT_THROW("Non-boolean attn_mask is not supported yet in GQA path of Attention op (CUDA)."); + } else { + // No mask provided - use full sequence length for all batches + // seqlens_k is total_sequence_length - 1 for historical reasons (matching GroupQueryAttention convention) + for (int b = 0; b < parameters.batch_size; ++b) { + seqlens_k_host[b] = parameters.total_sequence_length - 1; + } + } + + // Copy seqlens_k to GPU + seqlens_k_buffer = GetScratchBuffer<int>(parameters.batch_size, context->GetComputeStream()); + auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(seqlens_k_buffer.get(), seqlens_k_host.data(), + sizeof(int) * parameters.batch_size, + cudaMemcpyHostToDevice, cuda_stream)); + + // Process seqlens_k to compute past_seq_lens, total_seq_lens, and padded_seq_lens + // This is always needed for flash/memory efficient attention + seq_lens_buffer = GetScratchBuffer<int>(3 * parameters.batch_size, context->GetComputeStream()); + gqa_data.past_seq_lens = seq_lens_buffer.get(); + gqa_data.total_seq_lens = seq_lens_buffer.get() + parameters.batch_size; + gqa_data.padded_seq_lens = gqa_data.total_seq_lens + parameters.batch_size; + + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchGetSequenceLengths( + seqlens_k_buffer.get(), + gqa_data.past_seq_lens, + gqa_data.total_seq_lens, + gqa_data.padded_seq_lens, + parameters.batch_size, + parameters.q_sequence_length, + gqa_parameters.is_first_prompt, + cuda_stream, + device_prop.maxThreadsPerBlock)); + + // Set GQA-specific fields + gqa_data.cos_cache = nullptr; // No rotary embeddings + gqa_data.sin_cache = nullptr; + gqa_data.head_sink = nullptr; + gqa_data.position_ids = nullptr; + +#ifndef NDEBUG + // Initialize debug tracking fields + gqa_data.unpacked_qkv_buffer_size = 0; + gqa_data.rotary_buffer_size = 0; + gqa_data.position_ids_buffer_size = 0; + gqa_data.unpacked_qkv_max_used = 0; + gqa_data.rotary_max_used = 0; + gqa_data.position_ids_max_used = 0; +#endif + + // Call GQA kernel (with flash or memory efficient attention) + cublasHandle_t cublas = GetCublasHandle(context); + + return onnxruntime::contrib::cuda::QkvToContext<CudaT>( + device_prop, cublas, context->GetComputeStream(), gqa_parameters, gqa_data); + } + + // MHA path (kv_num_heads == q_num_heads) contribop_parameters.batch_size = parameters.batch_size; contribop_parameters.sequence_length = parameters.q_sequence_length; contribop_parameters.kv_sequence_length = parameters.kv_sequence_length; @@ -160,24 +500,6 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { contribop_parameters.mask_filter_value = -10000.0f; contribop_parameters.scale = parameters.scale; contribop_parameters.use_tf32 = UseTF32(); - - // QKV format: Determine based on input dimensions - // 3D inputs (B, S, D): 
Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH - // transpose_output is true for 3D inputs, false for 4D inputs - if (!parameters.transpose_output) { - contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH; - contribop_parameters.is_output_bnsh = true; - } else { - // 3D inputs in BSNH format (will be transposed) - contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH; - contribop_parameters.is_output_bnsh = false; - } - - // TODO(titaiwang, xadupre): Group query attention is not supported yet - if (parameters.kv_num_heads != parameters.q_num_heads) { - ORT_THROW("Group query attention is not supported yet in Attention op (CUDA)."); - } - // TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { @@ -191,9 +513,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA)."); } - // TODO(titaiwang): Continue on these parameters // Construct AttentionData to pass to QkvToContext - typedef typename ToCudaType::MappedType CudaT; onnxruntime::contrib::cuda::AttentionData data; // Set input pointers diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index 5c5c2efdb50cd..50a7549b300e6 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -895,6 +895,41 @@ TEST(AttentionTest, Attention3DGqaAttn) { ); } +// GQA kernel only supports causal and fp16/bf16 self-attention +// This is a self-attention test where q_sequence_length == kv_sequence_length +TEST(AttentionTest, Attention3DGqaSelfAttnCausal) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 9; // Q.shape[1] + int q_sequence_length = 6; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + // {2, 6, 72} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 
0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 
0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 
0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 6, 24} + std::vector k = {0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 
0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 6, 24} + std::vector v = {0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 
0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f}; + // {2, 6, 72} + std::vector y = {0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.375184f, 0.915488f, 0.291794f, 0.369500f, 0.664971f, 0.529153f, 0.762684f, 0.338902f, 0.397444f, 0.920207f, 0.277550f, 0.358363f, 0.685605f, 0.525467f, 0.776164f, 0.323729f, 0.414434f, 0.923809f, 0.266678f, 0.349863f, 0.701353f, 0.522654f, 0.786453f, 0.312148f, 0.689214f, 0.467515f, 0.356314f, 0.432734f, 0.610069f, 0.353058f, 0.585354f, 0.416724f, 0.674467f, 0.454209f, 0.352567f, 0.434067f, 0.593396f, 0.353462f, 0.616405f, 0.446400f, 0.692585f, 0.470557f, 0.357171f, 0.432430f, 0.613881f, 0.352965f, 0.578255f, 0.409939f, 0.804387f, 0.615237f, 0.340902f, 0.549039f, 0.474621f, 0.226210f, 0.837629f, 0.384495f, 0.803948f, 0.613379f, 0.342255f, 0.550354f, 0.476454f, 0.226886f, 0.837247f, 0.386503f, 0.803713f, 0.612385f, 0.342978f, 0.551057f, 0.477434f, 0.227248f, 0.837044f, 0.387576f, 0.625910f, 0.606317f, 0.498135f, 0.435228f, 0.569180f, 0.454523f, 0.772291f, 0.316493f, 0.600102f, 0.640974f, 0.474028f, 0.426968f, 0.581606f, 0.462464f, 0.772414f, 0.317639f, 0.617036f, 0.602411f, 0.505863f, 0.440737f, 0.559120f, 0.455807f, 0.766122f, 0.323372f, 0.678009f, 0.485463f, 0.402291f, 0.599954f, 0.590488f, 0.522209f, 0.635031f, 0.281346f, 0.675364f, 0.483282f, 0.401978f, 0.601415f, 0.587446f, 0.523522f, 0.640793f, 0.285508f, 0.690747f, 0.497418f, 0.406332f, 0.601542f, 0.604776f, 0.524641f, 0.608640f, 0.253110f, 0.792985f, 0.471221f, 0.543753f, 0.367260f, 0.424255f, 0.180896f, 0.707169f, 0.541938f, 0.793076f, 0.469790f, 0.546655f, 0.362887f, 0.422014f, 0.179562f, 0.704959f, 0.543519f, 0.791576f, 0.458557f, 0.559829f, 0.356372f, 0.423285f, 0.178667f, 0.697510f, 0.555749f, 0.532676f, 0.522498f, 0.450644f, 0.518150f, 0.475008f, 0.408996f, 0.611546f, 0.408014f, 0.531107f, 0.533903f, 0.456242f, 0.512505f, 0.479370f, 0.415609f, 0.622785f, 0.404108f, 0.517890f, 0.501334f, 0.452817f, 0.535428f, 0.452436f, 0.401614f, 0.583764f, 0.425736f, 0.614342f, 0.446003f, 0.478324f, 0.524041f, 0.631157f, 0.583251f, 0.518747f, 0.449821f, 0.603049f, 0.441193f, 0.492507f, 0.533154f, 0.631851f, 0.609747f, 0.513797f, 0.460154f, 0.618372f, 0.452224f, 0.479473f, 0.552495f, 0.626672f, 0.602204f, 0.531833f, 0.417003f, 0.677216f, 0.450390f, 0.460556f, 0.372157f, 0.434257f, 0.245953f, 0.743416f, 0.592728f, 0.668075f, 0.458899f, 0.439168f, 
0.384565f, 0.438028f, 0.254483f, 0.755536f, 0.585862f, 0.656018f, 0.426196f, 0.475801f, 0.346222f, 0.428263f, 0.250239f, 0.730597f, 0.623930f, 0.462562f, 0.526164f, 0.389326f, 0.517561f, 0.427406f, 0.385959f, 0.569898f, 0.484484f, 0.474005f, 0.552637f, 0.376769f, 0.497320f, 0.452783f, 0.391908f, 0.596991f, 0.468663f, 0.445596f, 0.523559f, 0.380551f, 0.526076f, 0.418282f, 0.382977f, 0.551092f, 0.495460f, 0.498820f, 0.528269f, 0.528297f, 0.467093f, 0.682089f, 0.576168f, 0.637209f, 0.401392f, 0.490348f, 0.530672f, 0.529194f, 0.457350f, 0.683798f, 0.569488f, 0.646106f, 0.407858f, 0.480406f, 0.531174f, 0.536769f, 0.462804f, 0.683767f, 0.582784f, 0.653107f, 0.409162f, 0.622447f, 0.402091f, 0.428179f, 0.404020f, 0.529063f, 0.365775f, 0.737918f, 0.611661f, 0.618995f, 0.409366f, 0.412110f, 0.413444f, 0.535286f, 0.376103f, 0.745308f, 0.601499f, 0.646685f, 0.416543f, 0.436797f, 0.411064f, 0.518753f, 0.340880f, 0.740302f, 0.597846f, 0.526615f, 0.537294f, 0.446052f, 0.549804f, 0.413170f, 0.440843f, 0.508644f, 0.454324f, 0.531923f, 0.515024f, 0.456002f, 0.556102f, 0.399799f, 0.432958f, 0.504235f, 0.462042f, 0.531512f, 0.514512f, 0.448893f, 0.554885f, 0.405433f, 0.429084f, 0.505929f, 0.459431f, 0.484278f, 0.523931f, 0.553328f, 0.421070f, 0.663928f, 0.625415f, 0.547914f, 0.427544f, 0.463403f, 0.530154f, 0.568003f, 0.414951f, 0.673562f, 0.635938f, 0.556527f, 0.433707f, 0.471436f, 0.530938f, 0.563255f, 0.417469f, 0.676073f, 0.628029f, 0.555403f, 0.427696f, 0.624993f, 0.347656f, 0.428891f, 0.468551f, 0.456436f, 0.447665f, 0.709757f, 0.507535f, 0.613676f, 0.360608f, 0.408412f, 0.466924f, 0.454163f, 0.445416f, 0.722256f, 0.508416f, 0.618408f, 0.338037f, 0.443792f, 0.444639f, 0.461626f, 0.441243f, 0.700985f, 0.534042f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.564179f, 0.958988f, 0.647053f, 0.222007f, 0.271883f, 0.366312f, 0.071173f, 0.465545f, 0.525511f, 0.958384f, 0.639686f, 0.208972f, 0.273665f, 0.376599f, 0.065457f, 0.459408f, 0.511507f, 0.958165f, 0.637018f, 0.204251f, 0.274311f, 0.380325f, 0.063387f, 0.457185f, 0.313102f, 0.280737f, 0.607950f, 0.740205f, 0.128510f, 0.523919f, 0.645965f, 0.900198f, 0.315797f, 0.280171f, 0.607655f, 0.740404f, 0.130150f, 0.525395f, 0.645868f, 0.900936f, 0.312634f, 0.280836f, 0.608001f, 0.740170f, 0.128225f, 0.523663f, 0.645982f, 0.900070f, 0.589133f, 0.923803f, 0.350077f, 0.078063f, 0.676116f, 0.208496f, 0.691386f, 0.170909f, 0.576824f, 0.928713f, 0.340944f, 0.071081f, 0.655312f, 0.213646f, 0.705698f, 0.177599f, 0.578667f, 0.927978f, 0.342312f, 0.072127f, 0.658427f, 0.212874f, 0.703555f, 0.176598f, 0.461675f, 0.908043f, 0.662587f, 0.314431f, 0.403545f, 0.389025f, 0.047470f, 0.324577f, 0.446757f, 0.908095f, 0.659543f, 0.308669f, 0.403499f, 0.393020f, 0.045315f, 0.322930f, 0.465259f, 0.913913f, 0.659139f, 0.300679f, 0.388424f, 0.388608f, 
0.049008f, 0.339838f, 0.393977f, 0.206322f, 0.401871f, 0.823817f, 0.402496f, 0.391177f, 0.407877f, 0.917636f, 0.380842f, 0.214048f, 0.420364f, 0.816133f, 0.375065f, 0.399293f, 0.428687f, 0.914445f, 0.395250f, 0.200438f, 0.382438f, 0.831506f, 0.425258f, 0.374556f, 0.384825f, 0.917522f, 0.512971f, 0.734163f, 0.439597f, 0.087740f, 0.576454f, 0.402861f, 0.674433f, 0.258468f, 0.501254f, 0.716540f, 0.445702f, 0.086453f, 0.559595f, 0.424232f, 0.677365f, 0.269516f, 0.499800f, 0.737298f, 0.431231f, 0.080779f, 0.554471f, 0.409937f, 0.688732f, 0.266073f, 0.462157f, 0.870872f, 0.546805f, 0.488834f, 0.525958f, 0.438242f, 0.047659f, 0.400733f, 0.478201f, 0.880733f, 0.536080f, 0.473352f, 0.501806f, 0.437440f, 0.052016f, 0.435873f, 0.477773f, 0.879052f, 0.549043f, 0.470968f, 0.503898f, 0.433002f, 0.051116f, 0.417409f, 0.515061f, 0.284437f, 0.392020f, 0.653122f, 0.483621f, 0.509133f, 0.335149f, 0.878162f, 0.523926f, 0.283378f, 0.380616f, 0.651208f, 0.501330f, 0.506441f, 0.319676f, 0.877617f, 0.518601f, 0.284137f, 0.393970f, 0.652876f, 0.483085f, 0.513057f, 0.337976f, 0.879356f, 0.370769f, 0.773959f, 0.503227f, 0.221362f, 0.495027f, 0.570949f, 0.552470f, 0.424992f, 0.389319f, 0.785114f, 0.496812f, 0.216956f, 0.516126f, 0.544699f, 0.553702f, 0.406750f, 0.374517f, 0.800049f, 0.487073f, 0.216024f, 0.497095f, 0.548627f, 0.562448f, 0.416828f, 0.490503f, 0.875994f, 0.553352f, 0.370920f, 0.474758f, 0.438421f, 0.062575f, 0.405557f, 0.538935f, 0.862195f, 0.544292f, 0.410861f, 0.505648f, 0.437161f, 0.071189f, 0.409054f, 0.537057f, 0.864310f, 0.544899f, 0.393280f, 0.496453f, 0.437731f, 0.072069f, 0.410655f, 0.490903f, 0.262608f, 0.397790f, 0.576431f, 0.497750f, 0.421786f, 0.436751f, 0.865953f, 0.460771f, 0.256774f, 0.412950f, 0.587792f, 0.466756f, 0.403886f, 0.468647f, 0.867206f, 0.510473f, 0.271384f, 0.393981f, 0.554321f, 0.513314f, 0.436839f, 0.425226f, 0.861592f, 0.410089f, 0.811830f, 0.596259f, 0.187669f, 0.517756f, 0.659598f, 0.669717f, 0.464897f, 0.412661f, 0.810483f, 0.599894f, 0.188068f, 0.521217f, 0.661192f, 0.669521f, 0.464721f, 0.419832f, 0.776928f, 0.601120f, 0.181108f, 0.524039f, 0.659178f, 0.664446f, 0.452487f, 0.501286f, 0.819148f, 0.451081f, 0.397779f, 0.578520f, 0.423382f, 0.158569f, 0.410852f, 0.472688f, 0.816542f, 0.456622f, 0.355501f, 0.572246f, 0.415580f, 0.171291f, 0.385089f, 0.461622f, 0.820564f, 0.458296f, 0.366592f, 0.570558f, 0.420437f, 0.160864f, 0.388639f, 0.569614f, 0.345505f, 0.387646f, 0.603702f, 0.428365f, 0.410074f, 0.408705f, 0.831131f, 0.563201f, 0.340637f, 0.390655f, 0.579616f, 0.437486f, 0.404695f, 0.423845f, 0.829337f, 0.562342f, 0.340988f, 0.390814f, 0.614464f, 0.422665f, 0.411699f, 0.410831f, 0.834902f, 0.358183f, 0.741080f, 0.523627f, 0.245418f, 0.493038f, 0.643709f, 0.543054f, 0.481367f, 0.365844f, 0.750295f, 0.542675f, 0.237461f, 0.496768f, 0.656764f, 0.565628f, 0.484818f, 0.357216f, 0.731385f, 0.524484f, 0.251065f, 0.495546f, 0.648040f, 0.532744f, 0.484403f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), 
std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + true, false, true // disable_cpu, disable_cuda, disable_dml (GQA with flash attention only works on CUDA) + ); +} + +// GQA only supports fp16 and bf16 in current implementation. TEST(AttentionTest, Attention4DGqaAttnMask) { int batch_size = 2; // Q.shape[0] int q_num_heads = 9; // Q.shape[1] @@ -930,6 +965,7 @@ TEST(AttentionTest, Attention4DGqaAttnMask) { ); } +// GQA only supports fp16 and bf16 in current implementation. TEST(AttentionTest, Attention4DGqaWithPastAndPresent) { int batch_size = 2; // Q.shape[0] int q_num_heads = 9; // Q.shape[1] diff --git a/onnxruntime/test/python/transformers/test_onnx_attention.py b/onnxruntime/test/python/transformers/test_onnx_attention.py new file mode 100644 index 0000000000000..e9356a9ec7746 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_onnx_attention.py @@ -0,0 +1,1129 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# ------------------------------------------------------------------------- +import math +import os +import unittest +from dataclasses import dataclass + +import numpy +import torch +from einops import rearrange, repeat +from onnx import TensorProto, helper +from parameterized import parameterized + +from onnxruntime import ( + InferenceSession, + SessionOptions, + get_available_providers, + get_build_info, +) + +# Set seed for reproducibility +torch.manual_seed(0) + +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" + +# Number of values per parameter (compared to pipeline mode) +param_count = int(os.getenv("PARAM_COUNT", "3")) if not pipeline_mode else 2 + +# When quick build is used, flash attention only supports head_size=128 +quick_build = ", quick-build=" in get_build_info() + +enable_debug_print = quick_build + +enable_deterministic_check = True + +# ################################################################################################# +# Configuration and Helper Classes +# ################################################################################################# + + +# --- ONNX and Torch/Numpy Dtype Mappings --- +ONNX_TENSOR_TYPE_MAP = { + "float32": TensorProto.FLOAT, + "float16": TensorProto.FLOAT16, + "bfloat16": TensorProto.BFLOAT16, + "int32": TensorProto.INT32, + "int8": TensorProto.INT8, + "int4": TensorProto.UINT8, +} + +TORCH_DTYPE_TO_ONNX_MAP = { + torch.float32: TensorProto.FLOAT, + torch.float16: TensorProto.FLOAT16, + torch.bfloat16: TensorProto.BFLOAT16, + torch.int32: TensorProto.INT32, + torch.int8: TensorProto.INT8, +} + +TORCH_DTYPE_MAP = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "int8": torch.int8, + "int4": torch.uint8, +} + + +@dataclass +class AttentionConfig: + batch_size: int + q_sequence_length: int + kv_sequence_length: int + 
q_num_heads: int + kv_num_heads: int + head_size: int + is_causal: int = 0 + past_kv_sequence_length: int = 0 + softcap: float = 0.0 + kv_cache_type: str = "" + has_attn_mask: bool = False + + +# ################################################################################################# +# ONNX Graph Creation +# ################################################################################################# + + +def create_attention_node_and_io( + config: AttentionConfig, + ort_type, + is_past=False, + output_qk: int = 0, # CUDA does not support output_qk for GQA path +): + """ + Create ONNX Attention op node and I/O definitions for testing. + + ONNX Attention op (opset 23) inputs: + - 0: Q (query) - required + - 1: K (key) - required + - 2: V (value) - required + - 3: attn_mask - optional + - 4: past_key - optional + - 5: past_value - optional + + ONNX Attention op outputs: + - 0: Y (output) + - 1: present_key (optional) + - 2: present_value (optional) + - 3: output_qk (optional) + """ + # For ONNX Attention op, present KV cache grows (not fixed buffer like GQA) + if is_past: + past_kv_seqlen = config.past_kv_sequence_length + present_kv_seqlen = config.past_kv_sequence_length + config.kv_sequence_length + else: # Prompt + past_kv_seqlen = 0 + present_kv_seqlen = config.kv_sequence_length + + if not config.kv_cache_type: + config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" + + # --- Node Definition --- + outputs = [ + "output", + "present_key", + "present_value", + ] + + if output_qk > 0: + outputs.append("output_qk") + + # ONNX Attention op inputs: Q, K, V, attn_mask, past_key, past_value + # attn_mask is used as padding mask (additive bias: 0.0 for valid, -inf for masked) + inputs = [ + "query", + "key", + "value", + "attn_mask" if config.has_attn_mask else "", + "past_key" if is_past else "", + "past_value" if is_past else "", + ] + + # Remove trailing empty strings + while inputs and inputs[-1] == "": + inputs.pop() + + # ONNX Attention op attributes (opset 23) + node = helper.make_node( + op_type="Attention", + inputs=inputs, + outputs=outputs, + name="Attention_0", + is_causal=config.is_causal, + kv_num_heads=config.kv_num_heads, + q_num_heads=config.q_num_heads, + softcap=config.softcap, + qk_matmul_output_mode=output_qk, + domain="", # ai.onnx domain + ) + + # --- Graph Inputs --- + # ONNX Attention op uses 3D inputs: [batch, seq_len, hidden_size] + q_hidden_size = config.q_num_heads * config.head_size + kv_hidden_size = config.kv_num_heads * config.head_size + + graph_input = [ + helper.make_tensor_value_info("query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size]), + helper.make_tensor_value_info("key", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size]), + helper.make_tensor_value_info( + "value", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] + ), + ] + + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + + # attn_mask for ONNX Attention op - boolean padding mask + # GQA path expects shape: [batch, total_seq_len] - True for valid, False for masked + # The kernel converts this to seqlens_k internally + if config.has_attn_mask: + mask_seq_len = present_kv_seqlen + graph_input.append( + helper.make_tensor_value_info("attn_mask", TensorProto.BOOL, [config.batch_size, mask_seq_len]) + ) + + # past_key and past_value for ONNX Attention op 
+ # Shape: [batch, num_heads, past_seq_len, head_size] (4D BNSH format) + if is_past: + past_k_shape = [config.batch_size, config.kv_num_heads, past_kv_seqlen, config.head_size] + graph_input.extend( + [ + helper.make_tensor_value_info("past_key", cache_ort_type, past_k_shape), + helper.make_tensor_value_info("past_value", cache_ort_type, past_k_shape), + ] + ) + + # --- Graph Outputs --- + output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] + + graph_output = [ + helper.make_tensor_value_info( + "output", ort_type, [config.batch_size, config.q_sequence_length, config.q_num_heads * config.head_size] + ), + helper.make_tensor_value_info("present_key", cache_ort_type, output_k_shape), + helper.make_tensor_value_info("present_value", cache_ort_type, output_k_shape), + ] + + if output_qk > 0: + graph_output.append( + helper.make_tensor_value_info( + "output_qk", + ort_type, + [config.batch_size, config.q_num_heads, config.q_sequence_length, present_kv_seqlen], + ) + ) + + return node, graph_input, graph_output + + +def create_attention_graph_prompt(config: AttentionConfig, ort_type): + """Create ONNX graph for prompt phase (no past KV cache).""" + node, graph_input, graph_output = create_attention_node_and_io(config, ort_type, is_past=False) + graph = helper.make_graph([node], "Attention_Graph", graph_input, graph_output) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)]) + return model.SerializeToString() + + +def create_attention_graph_past(config: AttentionConfig, ort_type): + """Create ONNX graph for decoding phase (with past KV cache).""" + node, graph_input, graph_output = create_attention_node_and_io(config, ort_type, is_past=True) + graph = helper.make_graph([node], "Attention_Graph", graph_input, graph_output) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)]) + return model.SerializeToString() + + +# ################################################################################################# +# ONNX Runtime Execution Functions +# ################################################################################################# + + +def bind_tensor(io_binding, name, tensor, device, ort_type): + # Helper to bind a tensor to ONNX Runtime based on its device and type + if tensor is None: + return + # Assuming tensor is a torch tensor. This works for both CPU and GPU tensors. + io_binding.bind_input( + name, + tensor.device.type, + 0, + ort_type, + tuple(tensor.shape), + tensor.data_ptr(), + ) + + +def bind_output_tensor(io_binding, name, tensor, device, ort_type): + if tensor is None: + return + io_binding.bind_output( + name, + tensor.device.type, + 0, + ort_type, + tuple(tensor.shape), + tensor.data_ptr(), + ) + + +def attention_prompt_func( + q, + k, + v, + config: AttentionConfig, + attn_mask, + ep, + device, + ort_type=TensorProto.FLOAT16, +): + """ + Run ONNX Attention op for prompt phase (no past KV cache). 
+ + Args: + q: Query tensor [batch, q_seq_len, q_num_heads, head_size] + k: Key tensor [batch, kv_seq_len, kv_num_heads, head_size] + v: Value tensor [batch, kv_seq_len, kv_num_heads, head_size] + config: AttentionConfig with model parameters + attn_mask: Optional attention mask tensor (additive bias, 0.0 for valid, -inf for masked) + ep: Execution provider (e.g., "CUDAExecutionProvider") + device: Device string (e.g., "cuda") + ort_type: ONNX tensor type + """ + if not config.kv_cache_type: + config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" + + onnx_model_str = create_attention_graph_prompt( + config=config, + ort_type=ort_type, + ) + + # Reshape to 3D [batch, seq_len, hidden_size] + q_3d = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + k_3d = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) + v_3d = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + io_binding = ort_session.io_binding() + + # Bind inputs + bind_tensor(io_binding, "query", q_3d, device, ort_type) + bind_tensor(io_binding, "key", k_3d, device, ort_type) + bind_tensor(io_binding, "value", v_3d, device, ort_type) + + # Bind optional attention mask (boolean padding mask: True=valid, False=masked) + if config.has_attn_mask and attn_mask is not None: + bind_tensor(io_binding, "attn_mask", attn_mask, device, TensorProto.BOOL) + + # Bind Outputs + hidden_size = config.q_num_heads * config.head_size + + out_dtype = torch.float16 + if ort_type == TensorProto.BFLOAT16: + out_dtype = torch.bfloat16 + elif ort_type == TensorProto.FLOAT16: + out_dtype = torch.float16 + else: + out_dtype = torch.float32 + + out_torch = torch.zeros((config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device) + bind_output_tensor(io_binding, "output", out_torch, device, ort_type) + + # present KV shape for prompt (no past) + present_seqlen = config.kv_sequence_length + present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + + # Determine dtype for cache tensors + cache_dtype = out_dtype + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + + present_k = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + present_v = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) + bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) + + ort_session.run_with_iobinding(io_binding) + + return out_torch, present_k, present_v + + +def attention_past_func( + q, + past_k, + past_v, + new_k, + new_v, + config: AttentionConfig, + attn_mask, + ep, + device, + ort_type=TensorProto.FLOAT16, +): + """ + Run ONNX Attention op for decoding phase (with past KV cache). 
+ + Args: + q: Query tensor [batch, q_seq_len, q_num_heads, head_size] + past_k: Past key tensor [batch, kv_num_heads, past_seq_len, head_size] + past_v: Past value tensor [batch, kv_num_heads, past_seq_len, head_size] + new_k: New key tensor [batch, kv_seq_len, kv_num_heads, head_size] + new_v: New value tensor [batch, kv_seq_len, kv_num_heads, head_size] + config: AttentionConfig with model parameters + attn_mask: Optional attention mask tensor (additive bias, 0.0 for valid, -inf for masked) + ep: Execution provider (e.g., "CUDAExecutionProvider") + device: Device string (e.g., "cuda") + ort_type: ONNX tensor type + """ + if not config.kv_cache_type: + config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" + + onnx_model_str = create_attention_graph_past( + config=config, + ort_type=ort_type, + ) + + # Reshape to 3D [batch, seq_len, hidden_size] + q_3d = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + new_k_3d = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v_3d = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + io_binding = ort_session.io_binding() + + # Total sequence length for present KV + total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length + + # Bind inputs + bind_tensor(io_binding, "query", q_3d, device, ort_type) + bind_tensor(io_binding, "key", new_k_3d, device, ort_type) + bind_tensor(io_binding, "value", new_v_3d, device, ort_type) + + # Bind optional attention mask (boolean padding mask: True=valid, False=masked) + if config.has_attn_mask and attn_mask is not None: + bind_tensor(io_binding, "attn_mask", attn_mask, device, TensorProto.BOOL) + + # Bind past_key and past_value + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + + # past_k and past_v should be sliced to actual past length + past_len = config.past_kv_sequence_length + past_k_sliced = past_k[:, :, :past_len, :].contiguous() + past_v_sliced = past_v[:, :, :past_len, :].contiguous() + bind_tensor(io_binding, "past_key", past_k_sliced, device, cache_ort_type) + bind_tensor(io_binding, "past_value", past_v_sliced, device, cache_ort_type) + + # Bind Outputs + hidden_size = config.q_num_heads * config.head_size + + out_dtype = torch.float16 + if ort_type == TensorProto.BFLOAT16: + out_dtype = torch.bfloat16 + elif ort_type == TensorProto.FLOAT16: + out_dtype = torch.float16 + else: + out_dtype = torch.float32 + + out_torch = torch.zeros((config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device) + bind_output_tensor(io_binding, "output", out_torch, device, ort_type) + + # present KV shape (past + new) + present_seqlen = total_seq_len + present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + + cache_dtype = out_dtype + present_k = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + present_v = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) + bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) + + ort_session.run_with_iobinding(io_binding) + + return out_torch, present_k, present_v + + +# 
################################################################################################# +# Reference Attention Implementation +# ################################################################################################# + + +def construct_causal_mask(seqlen_q, seqlen_k, device): + """Construct a causal mask for attention.""" + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + # Causal: positions can only attend to earlier positions + return col_idx > row_idx + seqlen_k - seqlen_q + + +def attention_ref( + q, + k, + v, + key_padding_mask=None, + causal=False, + softcap=0.0, +): + """ + Reference implementation of scaled dot-product attention with GQA support. + + Args: + q: Query tensor [batch, seq_q, num_heads, head_size] + k: Key tensor [batch, seq_k, kv_num_heads, head_size] + v: Value tensor [batch, seq_k, kv_num_heads, head_size] + key_padding_mask: Boolean mask [batch, seq_k] - True for valid, False for masked + causal: Whether to apply causal masking + softcap: Softcap value for attention scores (0.0 = disabled) + + Returns: + output: Attention output [batch, seq_q, num_heads, head_size] + attention: Attention weights [batch, num_heads, seq_q, seq_k] + """ + dtype_og = q.dtype + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + + # Repeat K/V heads for Grouped-Query Attention + if k.shape[2] != q.shape[2]: + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + if v.shape[2] != q.shape[2]: + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + + scores = torch.einsum("bthd,bshd->bhts", q, k) / math.sqrt(q.shape[-1]) + + if softcap > 0: + scores = (scores / softcap).tanh() * softcap + + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + + if causal: + causal_mask = construct_causal_mask(seqlen_q, seqlen_k, q.device) + scores.masked_fill_(causal_mask, float("-inf")) + + attention = torch.softmax(scores, dim=-1) + + output = torch.einsum("bhts,bshd->bthd", attention, v) + + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +# ################################################################################################# +# Parity Check (Core Test Logic) +# ################################################################################################# + + +def parity_check_attention_prompt( + config: AttentionConfig, + ep, + device, + torch_type, + ort_type, + causal, + rtol, + atol, + std=0.2, +): + """ + Parity check for ONNX Attention op in prompt phase (no past KV cache). + + This tests that the ONNX Attention op produces the same output as a PyTorch + reference implementation for the initial prompt processing. 
+ """ + torch.manual_seed(0) + + # Generate Q, K, V tensors in BSNH format (batch, seq, num_heads, head_size) + q = ( + torch.randn( + config.batch_size, + config.q_sequence_length, + config.q_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + k = ( + torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + v = torch.randn_like(k) * std + + # --- Create attn_mask as boolean padding mask (simulating seqlens_k) --- + # For testing, we use full sequence length (no actual padding) + # attn_mask: [batch, kv_seq_len] - True for valid, False for masked + # GQA kernel converts this to seqlens_k internally + attn_mask = None + key_padding_mask = None + if config.has_attn_mask: + # All positions are valid (no padding) for this test + # Create a 2D boolean mask of True (all valid positions) + attn_mask = torch.ones( + config.batch_size, + config.kv_sequence_length, + device=device, + dtype=torch.bool, + ) + # key_padding_mask for reference: all True (all valid) + key_padding_mask = torch.ones( + config.batch_size, + config.kv_sequence_length, + device=device, + dtype=torch.bool, + ) + + # --- PyTorch Reference Path --- + out_ref, _ = attention_ref( + q=q, + k=k, + v=v, + key_padding_mask=key_padding_mask, + causal=causal, + softcap=config.softcap, + ) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + + # --- ONNX Runtime Path --- + num_runs = 2 if enable_deterministic_check else 1 + for i in range(num_runs): + out, present_k, present_v = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep=ep, + device=device, + ort_type=ort_type, + ) + if i == 0: + first_out = out.clone() + first_present_k = present_k.clone() if present_k is not None else None + first_present_v = present_v.clone() if present_v is not None else None + else: + if present_k is not None: + torch.testing.assert_close( + present_k, first_present_k, rtol=0, atol=0, msg="present_k mismatch between two runs" + ) + if present_v is not None: + torch.testing.assert_close( + present_v, first_present_v, rtol=0, atol=0, msg="present_v mismatch between two runs" + ) + torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + out_np = out.to(torch.float32).detach().cpu().numpy() + + # --- Comparison --- + # Check for NaN in output + nan_count = numpy.sum(numpy.isnan(out_np)) + if nan_count > 0: + nan_indices = numpy.argwhere(numpy.isnan(out_np)) + print(f"DEBUG_NAN: Found {nan_count} NaN values in output!") + print(f"DEBUG_NAN: First 5 NaN indices: {nan_indices[:5]}") + + # Compare KV cache (present_k should match k, present_v should match v) + # K/V are in BSNH, present_k/v are in BNSH - need to transpose for comparison + k_ref_bnsh = k.transpose(1, 2) # BSNH -> BNSH + v_ref_bnsh = v.transpose(1, 2) # BSNH -> BNSH + + k_ref_np = k_ref_bnsh.to(torch.float32).detach().cpu().numpy() + v_ref_np = v_ref_bnsh.to(torch.float32).detach().cpu().numpy() + present_k_np = present_k.to(torch.float32).detach().cpu().numpy() + present_v_np = present_v.to(torch.float32).detach().cpu().numpy() + + print_diff_statistics(torch.tensor(present_k_np - k_ref_np), "present_k") + numpy.testing.assert_allclose(present_k_np, k_ref_np, rtol=rtol, atol=atol) + print_diff_statistics(torch.tensor(present_v_np - v_ref_np), "present_v") + 
numpy.testing.assert_allclose(present_v_np, v_ref_np, rtol=rtol, atol=atol) + + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + + +def parity_check_attention_past( + config: AttentionConfig, + ep, + device, + torch_type, + ort_type, + causal, + rtol, + atol, + std=0.2, +): + """ + Parity check for ONNX Attention op in decoding phase (with past KV cache). + + This tests that the ONNX Attention op produces the same output as a PyTorch + reference implementation for token-by-token decoding with KV cache. + """ + if ort_type == TensorProto.FLOAT16: + torch_type = torch.float16 + elif ort_type == TensorProto.BFLOAT16: + torch_type = torch.bfloat16 + else: + torch_type = torch.float32 + torch.manual_seed(0) + + # --- Test Data Generation --- + # Query for new tokens + q = ( + torch.randn( + config.batch_size, + config.q_sequence_length, + config.q_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + + # Past KV cache in BNSH format + past_k = ( + torch.randn( + config.batch_size, + config.kv_num_heads, + config.past_kv_sequence_length, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + past_v = torch.randn_like(past_k) * std + + # New K/V for current tokens in BSNH format + new_k = ( + torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + new_v = torch.randn_like(new_k) * std + + # --- PyTorch Reference Path --- + # Concatenate past and new KV for reference + # past_k is BNSH, new_k is BSNH - need to transpose new_k + new_k_bnsh = new_k.transpose(1, 2) # BSNH -> BNSH + new_v_bnsh = new_v.transpose(1, 2) # BSNH -> BNSH + + full_k_bnsh = torch.cat([past_k, new_k_bnsh], dim=2) # [B, N, past+new, H] + full_v_bnsh = torch.cat([past_v, new_v_bnsh], dim=2) # [B, N, past+new, H] + + # Convert to BSNH for reference attention + full_k_bsnh = full_k_bnsh.transpose(1, 2) + full_v_bsnh = full_v_bnsh.transpose(1, 2) + + total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length + + # --- Create attn_mask as boolean padding mask (simulating seqlens_k) --- + # For testing, we use full sequence length (no actual padding) + # attn_mask: [batch, total_seq_len] - True for valid, False for masked + # GQA kernel converts this to seqlens_k internally + attn_mask = None + key_padding_mask = None + if config.has_attn_mask: + # All positions are valid (no padding) for this test + attn_mask = torch.ones( + config.batch_size, + total_seq_len, + device=device, + dtype=torch.bool, + ) + # key_padding_mask for reference: all True (all valid) + key_padding_mask = torch.ones( + config.batch_size, + total_seq_len, + device=device, + dtype=torch.bool, + ) + + out_ref, _ = attention_ref( + q=q, + k=full_k_bsnh, + v=full_v_bsnh, + key_padding_mask=key_padding_mask, + causal=causal, + softcap=config.softcap, + ) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + + # --- ONNX Runtime Path --- + num_runs = 2 if enable_deterministic_check else 1 + for i in range(num_runs): + out, present_k, present_v = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=new_k, + new_v=new_v, + config=config, + attn_mask=attn_mask, + ep=ep, + device=device, + ort_type=ort_type, + ) + if i == 0: + first_out = out.clone() + first_present_k = present_k.clone() if present_k is not None else None + first_present_v = present_v.clone() if present_v is not None 
else None + else: + torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + if present_k is not None: + torch.testing.assert_close( + present_k, first_present_k, rtol=0, atol=0, msg="present_k mismatch between two runs" + ) + if present_v is not None: + torch.testing.assert_close( + present_v, first_present_v, rtol=0, atol=0, msg="present_v mismatch between two runs" + ) + + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + out_np = out.to(torch.float32).detach().cpu().numpy() + + if enable_debug_print: + print(f"[DEBUG] out_np non-zeros: {numpy.count_nonzero(out_np)} / {out_np.size}") + print(f"[DEBUG] out_ref_np non-zeros: {numpy.count_nonzero(out_ref_np)} / {out_ref_np.size}") + + if numpy.count_nonzero(out_ref_np) > 0 and numpy.count_nonzero(out_np) == 0: + raise RuntimeError("Output is all zeros") + + # --- Comparison --- + # Compare KV cache (present should be concat of past + new) + full_k_ref_np = full_k_bnsh.to(torch.float32).detach().cpu().numpy() + full_v_ref_np = full_v_bnsh.to(torch.float32).detach().cpu().numpy() + present_k_np = present_k.to(torch.float32).detach().cpu().numpy() + present_v_np = present_v.to(torch.float32).detach().cpu().numpy() + + print_diff_statistics(torch.tensor(present_k_np - full_k_ref_np), "present_k") + numpy.testing.assert_allclose(present_k_np, full_k_ref_np, rtol=rtol, atol=atol) + print_diff_statistics(torch.tensor(present_v_np - full_v_ref_np), "present_v") + numpy.testing.assert_allclose(present_v_np, full_v_ref_np, rtol=rtol, atol=atol) + + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + + +# ################################################################################################# +# Test Utilities +# ################################################################################################# + + +def print_diff_statistics(diff_tensor: torch.Tensor, prefix: str = ""): + """ + Print percentile statistics (75%, 95%, 99%) for a difference tensor. + This helps assess parity quality beyond just max difference. + + Args: + diff_tensor: Tensor containing absolute differences between expected and actual outputs. + prefix: Optional prefix string for the output message. 
+ """ + if not enable_debug_print: + return + + diff_flat = diff_tensor.flatten().float() + if diff_flat.numel() == 0: + print(f"{prefix}Diff statistics: empty tensor") + return + + # Compute percentiles + sorted_diff, _ = torch.sort(diff_flat) + n = sorted_diff.numel() + + p75_idx = min(int(n * 0.75), n - 1) + p90_idx = min(int(n * 0.90), n - 1) + p95_idx = min(int(n * 0.95), n - 1) + p99_idx = min(int(n * 0.99), n - 1) + p999_idx = min(int(n * 0.999), n - 1) + + p75 = sorted_diff[p75_idx].item() + p90 = sorted_diff[p90_idx].item() + p95 = sorted_diff[p95_idx].item() + p99 = sorted_diff[p99_idx].item() + p999 = sorted_diff[p999_idx].item() + max_val = sorted_diff[-1].item() + mean_val = diff_flat.mean().item() + + print( + f"{prefix} Diff stats - mean: {mean_val:.6f}, p75: {p75:.6f}, p90: {p90:.6f}, p95: {p95:.6f}, p99: {p99:.6f}, p999: {p999:.6f}, max: {max_val:.6f}" + ) + + +# ################################################################################################# +# Test Case Generators +# ################################################################################################# + + +def attention_prompt_test_cases(): + """ + Generate test cases for ONNX Attention op in prompt phase. + + The ONNX Attention op (opset 23) supports: + - GQA (kv_num_heads != q_num_heads) + - MHA (kv_num_heads == q_num_heads) + - Causal attention via is_causal attribute + - softcap + + It does NOT support (handled by external ops): + - Rotary embeddings + - Smooth softmax / head_sink + - Local window attention + - Packed QKV + """ + batches = [1, 2, 3] + seqs = [(16, 16), (64, 64), (128, 128)] + # GQA head configurations only (kv_heads != q_heads) + heads = [(8, 2), (8, 4)] # (q_heads, kv_heads) + h_sizes = [128] if quick_build else [64, 128] + softcap_opts = [0.0] # softcap not yet supported in CUDA implementation + + h_sizes_to_test = h_sizes[:1] if pipeline_mode else h_sizes + + combo_index = 0 + for h in h_sizes_to_test: + for b in batches[:2] if pipeline_mode else batches: + for sq, skv in seqs[:2] if pipeline_mode else seqs: + for n, n2 in heads: + softcap = softcap_opts[combo_index % len(softcap_opts)] + combo_index += 1 + + config = AttentionConfig( + batch_size=b, + q_sequence_length=sq, + kv_sequence_length=skv, + past_kv_sequence_length=0, + q_num_heads=n, + kv_num_heads=n2, + head_size=h, + is_causal=1, # Causal attention + softcap=softcap, + ) + name = f"b{b}_sq{sq}_skv{skv}_nh{n}_{n2}_h{h}_sc{softcap}" + yield name, config + + +def attention_past_test_cases(): + """ + Generate test cases for ONNX Attention op in decoding phase (with past KV cache). 
+ """ + batches = [1, 2] + # (new_seq_len, past_seq_len) + seqs = [(1, 32), (1, 128), (1, 512)] + # GQA head configurations only (kv_heads != q_heads) + heads = [(8, 2), (8, 4)] # (q_heads, kv_heads) + h_sizes = [128] if quick_build else [64, 128] + softcap_opts = [0.0] + + h_sizes_to_test = h_sizes[:1] if pipeline_mode else h_sizes + + combo_index = 0 + for h in h_sizes_to_test: + for b in batches[:1] if pipeline_mode else batches: + for s, s2 in seqs[:2] if pipeline_mode else seqs: + for n, n2 in heads: + softcap = softcap_opts[combo_index % len(softcap_opts)] + combo_index += 1 + + config = AttentionConfig( + batch_size=b, + q_sequence_length=s, + kv_sequence_length=s, # new K/V has same length as Q + past_kv_sequence_length=s2, + q_num_heads=n, + kv_num_heads=n2, + head_size=h, + is_causal=1, # Causal attention + softcap=softcap, + ) + name = f"b{b}_s{s}_past{s2}_nh{n}_{n2}_h{h}_sc{softcap}" + yield name, config + + +# ################################################################################################# +# Unit Test Classes +# ################################################################################################# + + +def has_cuda_provider(): + return "CUDAExecutionProvider" in get_available_providers() + + +def has_cuda_device(min_capability: int = 80): + if not has_cuda_provider() or not torch.cuda.is_available(): + return False + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor >= min_capability + + +def has_flash_attention(): + return has_cuda_device(80) + + +rtol = {"fp16": 5e-3, "bf16": 5e-2} +atol = {"fp16": 5e-3, "bf16": 1e-2} + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestONNXAttentionFlashGQA(unittest.TestCase): + """Test ONNX Attention op (opset 23) GQA path with Flash Attention.""" + + @parameterized.expand(attention_prompt_test_cases()) + def test_attention_prompt_flash(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_attention_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + @parameterized.expand(attention_past_test_cases()) + def test_attention_past_flash(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_attention_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestONNXAttentionFlashGQABF16(unittest.TestCase): + """Test ONNX Attention op (opset 23) GQA path with Flash Attention using BFloat16.""" + + @parameterized.expand(attention_prompt_test_cases()) + def test_attention_prompt_flash_bf16(self, name, config): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + config.kv_cache_type = "bfloat16" + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_attention_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + + @parameterized.expand(attention_past_test_cases()) + def test_attention_past_flash_bf16(self, name, config): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not 
supported on this device") + + config.kv_cache_type = "bfloat16" + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_attention_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +class TestONNXAttentionMemoryEfficientGQA(unittest.TestCase): + """Test ONNX Attention op (opset 23) GQA path with Memory Efficient Attention.""" + + @parameterized.expand(attention_prompt_test_cases()) + def test_attention_prompt_memory_efficient(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + parity_check_attention_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + @parameterized.expand(attention_past_test_cases()) + def test_attention_past_memory_efficient(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + parity_check_attention_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +@unittest.skipIf(not has_cuda_device(80), "BF16 requires Ampere or higher GPU, skipping tests.") +class TestONNXAttentionMemoryEfficientGQABF16(unittest.TestCase): + """Test ONNX Attention op (opset 23) GQA path with Memory Efficient Attention using BFloat16.""" + + @parameterized.expand(attention_past_test_cases()) + def test_attention_past_memory_efficient_bf16(self, name, config): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + config.kv_cache_type = "bfloat16" + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + parity_check_attention_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 54e24cd1e0a83..298927299c53f 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -41,8 +41,6 @@ "^test_attention_4d_diff_heads_mask4d_padded_kv*", // Need nonpad_kv_seqlen "^test_l2normalization*", // LpNormalization(22) not implemented // TODO: support the following tests in Attention-cuda - "^test_attention_3d_gqa.*_cuda", // GQA not supported in Attention-cuda - "^test_attention_4d_gqa.*_cuda", // GQA not supported in Attention-cuda "^test_attention_3d_diff_heads_sizes_softcap_cuda", // softcap not supported in Attention-cuda "^test_attention_4d_diff_heads_sizes_softcap_cuda", // softcap not supported in Attention-cuda "^test_attention_3d_softcap_cuda", // softcap not supported in Attention-cuda
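
A note on the GQA head expansion used by attention_ref above: the einops repeat pattern replicates each KV head so that consecutive groups of query heads share one original KV head. The minimal sketch below (illustrative shapes only, assuming PyTorch and einops are installed; it is not part of the diff) shows the resulting head ordering.

```python
import torch
from einops import repeat

# Illustrative GQA shapes: 8 query heads share 2 KV heads (group size 4).
batch, seq, q_heads, kv_heads, head_size = 1, 3, 8, 2, 4
k = torch.randn(batch, seq, kv_heads, head_size)

# Same pattern as attention_ref: "b s h d -> b s (h g) d" with g = q_heads // kv_heads.
k_expanded = repeat(k, "b s h d -> b s (h g) d", g=q_heads // kv_heads)
print(k_expanded.shape)  # torch.Size([1, 3, 8, 4])

# Expanded heads 0..3 are copies of KV head 0 and heads 4..7 are copies of KV head 1,
# i.e. query head i reads KV head i // (q_heads // kv_heads).
print(torch.equal(k_expanded[:, :, 0], k[:, :, 0]))  # True
print(torch.equal(k_expanded[:, :, 7], k[:, :, 1]))  # True
```

The fused GQA kernels avoid materializing this expansion; the reference only does it to reuse a plain MHA computation.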
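The causal mask built by construct_causal_mask is aligned to the bottom-right corner (col_idx > row_idx + seqlen_k - seqlen_q), which is what keeps the decode-phase reference correct when seqlen_q < seqlen_k. A small illustration with made-up lengths:

```python
import torch
from einops import rearrange

def construct_causal_mask(seqlen_q, seqlen_k, device="cpu"):
    # Same formula as the reference above; True marks positions that must be masked.
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    return col_idx > row_idx + seqlen_k - seqlen_q

print(construct_causal_mask(seqlen_q=2, seqlen_k=5))
# tensor([[False, False, False, False,  True],
#         [False, False, False, False, False]])
# With 3 cached tokens and 2 new queries, the first new query may not attend to the
# key that comes after it (the last new token), while the last query sees all 5 keys.
```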
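attention_ref keeps Q/K/V in BSNH layout and computes scores with an einsum. The sketch below (random data, illustrative sizes; not part of the diff) checks that this is the same quantity as the transpose-to-BNSH matmul form, BNSH being the layout used by the present_key/present_value outputs in these tests.

```python
import math
import torch

b, sq, sk, h, d = 2, 4, 6, 3, 8   # illustrative sizes only
q = torch.randn(b, sq, h, d)      # BSNH
k = torch.randn(b, sk, h, d)      # BSNH

# Einsum form used by the reference implementation.
scores_einsum = torch.einsum("bthd,bshd->bhts", q, k) / math.sqrt(d)

# Equivalent BNSH matmul form.
q_bnsh = q.transpose(1, 2)        # BSNH -> BNSH
k_bnsh = k.transpose(1, 2)
scores_matmul = q_bnsh @ k_bnsh.transpose(-1, -2) / math.sqrt(d)

print(torch.allclose(scores_einsum, scores_matmul, atol=1e-6))  # True
```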
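Finally, a hedged usage sketch for running a single prompt-phase parity case by hand, mirroring the kwargs that attention_prompt_test_cases passes to AttentionConfig. It assumes the test module's definitions are in scope (for example, pasted under the file's __main__ block or run in a REPL after importing it) and that an SM80+ GPU with the CUDA execution provider is available; it is illustrative only, not part of the PR.

```python
# Illustrative driver only; relies on AttentionConfig, parity_check_attention_prompt,
# has_cuda_device, rtol/atol, torch and TensorProto from the test module above.
if has_cuda_device(80):
    config = AttentionConfig(
        batch_size=1,
        q_sequence_length=16,
        kv_sequence_length=16,
        past_kv_sequence_length=0,
        q_num_heads=8,
        kv_num_heads=2,
        head_size=128,
        is_causal=1,
        softcap=0.0,
    )
    parity_check_attention_prompt(
        config=config,
        ep="CUDAExecutionProvider",
        device="cuda",
        torch_type=torch.float16,
        ort_type=TensorProto.FLOAT16,
        causal=True,
        rtol=rtol["fp16"],
        atol=atol["fp16"],
    )
```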