4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -258,7 +258,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
assert(data.past_value == nullptr);
assert(data.present_key == nullptr);
assert(data.present_value == nullptr);
assert(!parameters.is_unidirectional);
// Note: is_unidirectional (causal) is supported by flash attention, memory efficient attention,
// cuDNN flash attention, and unfused kernel. TRT fused runner is only used when !is_unidirectional
// (enforced in MultiHeadAttention::ComputeInternal).
assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data));

if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/webgpu/moe/moe.h
@@ -3,6 +3,8 @@

#pragma once

#include <limits>

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

@@ -31,7 +33,7 @@ class MoE : public WebGpuKernel {
activation_alpha_ = static_cast<float>(info.GetAttrOrDefault<float>("activation_alpha", 1.0));
activation_beta_ = static_cast<float>(info.GetAttrOrDefault<float>("activation_beta", 1.0));
swiglu_fusion_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("swiglu_fusion", 0));
swiglu_limit_ = info.GetAttrOrDefault<float>("swiglu_limit", 0);
swiglu_limit_ = info.GetAttrOrDefault<float>("swiglu_limit", std::numeric_limits<float>::infinity());
k_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("k", 4));
normalize_routing_weights_ = info.GetAttrOrDefault<int64_t>("normalize_routing_weights", 0) == 1;
use_sparse_mixer_ = info.GetAttrOrDefault<int64_t>("use_sparse_mixer", 0) == 1;
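
Note on the default change above (illustration, not part of the diff): for a clamp-style limit, +infinity is the neutral value, while the old default of 0 would clamp activations whenever the swiglu_limit attribute is absent. The sketch below assumes the common SwiGLU-with-limit form (gate clamped from above, linear term clamped to [-limit, limit]); the exact formula used by the WebGPU MoE kernel may differ.

```python
import numpy as np

def swiglu_with_limit(gate, linear, alpha=1.702, limit=np.inf):
    # Assumed clamping form: gate clamped from above, linear term clamped to [-limit, limit].
    g = np.minimum(gate, limit)
    lin = np.clip(linear, -limit, limit)
    return (lin + 1.0) * g * (1.0 / (1.0 + np.exp(-alpha * g)))

gate = np.array([-2.0, 0.5, 3.0])
linear = np.array([1.0, -1.0, 2.0])

# limit=inf is a no-op: min(x, inf) == x and clip(x, -inf, inf) == x.
print(swiglu_with_limit(gate, linear, limit=np.inf))

# limit=0 clamps every positive gate value and zeroes the linear term, so 0 is not a safe default.
print(swiglu_with_limit(gate, linear, limit=0.0))
```
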
35 changes: 33 additions & 2 deletions onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
@@ -395,6 +395,12 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
auto lhs_ptrs = std::shared_ptr<const void*[]>(new const void*[lhs_ptrs_k * lhs_ptrs_m],
std::default_delete<const void*[]>());

// Initialize all padding entries. For partial tiles (m < m_step),
// the kai LHS packing kernel may still read pointer entries beyond the logically
// filled 'm' positions. Leaving these uninitialized can cause non-deterministic
// reads and corrupt packed LHS data.
auto lhs_ptrs_ = lhs_ptrs.get();
std::fill(lhs_ptrs_, lhs_ptrs_ + (lhs_ptrs_k * lhs_ptrs_m), reinterpret_cast<const void*>(&pad_ptr[0]));

auto ih_out_size = ComputeConvOutSize(ih, kh, padding, 1);
auto iw_out_size = ComputeConvOutSize(iw, kw, padding, 1);
@@ -430,7 +436,6 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
};

size_t m_{0};
auto lhs_ptrs_ = lhs_ptrs.get();
for (size_t ih_ = 0; ih_ < ih_out_size; ih_ += sh) {
for (size_t iw_ = 0; iw_ < iw_out_size; iw_ += sw, ++m_) {
size_t k_{0};
@@ -460,7 +465,23 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
// figure out how many blocks needed to correctly fill padding
padsize = ((ci + padsize - 1) / padsize) * padsize;
}
static std::vector<float>pad_ptr(padsize, 0.f);

// pad_ptr must be at least 'ci' floats for padding pixels.
// Using a thread_local grow-only buffer to avoid cross-thread interference and ensure sizing is correct.
thread_local std::vector<float> pad_ptr;
const float* old_pad_ptr = pad_ptr.data();
bool has_pad_ptr_changed = false;

if (pad_ptr.size() < padsize) {
pad_ptr.resize(padsize, 0.f);
if (pad_ptr.data() != old_pad_ptr) {
has_pad_ptr_changed = true;
}
} else {
// Ensure any previously-used region remains zeroed (grow-only means it should already be zeros,
// but keep this explicit for safety).
std::fill(pad_ptr.begin(), pad_ptr.end(), 0.f);
}

LhsCacheKey key = {
ci, ih, iw,
@@ -481,6 +502,16 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
// Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions.
thread_local std::unordered_map<LhsCacheKey, std::shared_ptr<const void*[]>> lhs_ptrs_cache;

if (has_pad_ptr_changed)
{
// If the pad buffer was resized and a re-allocation has occurred, the cached lhs ptrs are invalid as they
// would be referencing the old pad buffer.
// See discussion in https://github.com/microsoft/onnxruntime/pull/27214.
// TODO(hasesh / JonathanC-ARM): A better approach would be to include the pad buffer address in the cache key
// or any other approach that would reduce unnecessary cache invalidations.
lhs_ptrs_cache.clear();
}

std::shared_ptr<const void*[]> lhs_ptrs;
if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {
lhs_ptrs = found->second;
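
The pattern above, a grow-only thread_local buffer plus a pointer cache that must be flushed whenever the buffer's storage moves, is easy to get subtly wrong. Below is a minimal Python analogue of the invariant only; names and structure are illustrative, not the MLAS code.

```python
import threading
import numpy as np

_tls = threading.local()

def get_pad_buffer(padsize):
    """Grow-only, per-thread zero buffer. Returns (buffer, storage_changed).

    storage_changed plays the role of has_pad_ptr_changed: it is True whenever growing
    the buffer replaced its underlying storage, invalidating any cached addresses.
    """
    old = getattr(_tls, "pad", None)
    if old is None or old.size < padsize:
        _tls.pad = np.zeros(padsize, dtype=np.float32)
        return _tls.pad, old is not None
    return old, False

# Cache of derived pointer tables keyed by problem shape (the C++ lhs_ptrs_cache is thread_local too).
_lhs_ptrs_cache = {}

def lhs_ptrs_for(key, padsize):
    pad, changed = get_pad_buffer(padsize)
    if changed:
        # Cached tables reference the old pad storage; drop them and rebuild.
        _lhs_ptrs_cache.clear()
    return _lhs_ptrs_cache.setdefault(key, {"pad_address": pad.ctypes.data, "key": key})
```
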
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h
@@ -99,6 +99,18 @@ struct ConvTransposeAttributes : public ConvAttributes {
" group: ", group);
}

// Bias shape validation (It should be a 1D tensor with size M)
// See https://github.com/microsoft/onnxruntime/issues/26144
if (B != nullptr) {
if (B->Shape().NumDimensions() != 1 || B->Shape()[0] != num_output_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Bias shape is not compatible with number of output channels."
" It should be a 1-D tensor with size num_output_channels(M).",
" Bias: ", B->Shape(),
" num_output_channels: ", num_output_channels);
}
}

TensorShapeVector kernel_shape;
ORT_RETURN_IF_ERROR(ComputeKernelShape(F_Shape, kernel_shape, is_nhwc));

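
For reference, the constraint the new check enforces, as a standalone sketch (not ORT code): in ONNX ConvTranspose, W has shape (C_in, M/group, kH, kW), so M = W.shape[1] * group, and B must be a 1-D tensor of size M.

```python
import numpy as np

def check_conv_transpose_bias(w_shape, bias, group=1):
    """Mirror of the new validation: bias must be a 1-D tensor with size M."""
    num_output_channels = w_shape[1] * group  # ONNX ConvTranspose: W is (C_in, M/group, kH, kW)
    if bias.ndim != 1 or bias.shape[0] != num_output_channels:
        raise ValueError(
            f"Bias shape {bias.shape} is not compatible with number of output channels "
            f"({num_output_channels}); it should be a 1-D tensor with size M."
        )

check_conv_transpose_bias((1, 1, 1, 5), np.array([1.0]))  # OK: M == 1
try:
    check_conv_transpose_bias((1, 1, 1, 5), np.array([1.0, 2.0]))  # wrong size, as in the new tests below
except ValueError as e:
    print(e)
```
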
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/cuda/llm/attention.cc
@@ -191,7 +191,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA).");
}

// TODO(titaiwang): Continue on these parameters
// Construct AttentionData to pass to QkvToContext
typedef typename ToCudaType<T>::MappedType CudaT;
onnxruntime::contrib::cuda::AttentionData<CudaT> data;
@@ -220,12 +219,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
}
data.qkv_format = contribop_parameters.qkv_format;

// TODO: Determine which kernel to use (Flash Attention, Memory Efficient Attention, etc.)
// For now, set flags to false and let QkvToContext use the unfused path
data.use_flash_attention = false;
data.use_memory_efficient_attention = false;
data.fused_runner = nullptr;
data.fused_cross_attention_kernel = nullptr;
data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused;

// Allocate workspace for Q, K, V processing and scratch buffer
const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data);
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cuda/nn/conv_transpose.cc
@@ -311,6 +311,18 @@ Status ConvTranspose<T, Layout>::UpdateState(OpKernelContext* context, bool dyna
" group: ", conv_transpose_attrs_.group);
}

// Bias shape validation (It should be a 1D tensor with size M)
// See https://github.com/microsoft/onnxruntime/issues/26144
if (B != nullptr) {
if (B->Shape().NumDimensions() != 1 || B->Shape()[0] != num_output_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Bias shape is not compatible with number of output channels."
" It should be a 1-D tensor with size num_output_channels(M).",
" Bias: ", B->Shape(),
" num_output_channels: ", num_output_channels);
}
}

TensorShapeVector kernel_shape;
ORT_RETURN_IF_ERROR(conv_transpose_attrs_.ComputeKernelShape(w_shape, kernel_shape, w_in_nhwc));

31 changes: 24 additions & 7 deletions onnxruntime/python/onnxruntime_inference_collection.py
@@ -521,8 +521,25 @@ def __init__(
def _create_inference_session(self, providers, provider_options, disabled_optimizers=None):
available_providers = C.get_available_providers()

# Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
if "TensorrtExecutionProvider" in available_providers:
# Validate that TensorrtExecutionProvider and NvTensorRTRTXExecutionProvider are not both specified
if providers:
has_tensorrt = any(
provider == "TensorrtExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
for provider in providers
)
has_tensorrt_rtx = any(
provider == "NvTensorRTRTXExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider")
for provider in providers
)
if has_tensorrt and has_tensorrt_rtx:
raise ValueError(
"Cannot enable both 'TensorrtExecutionProvider' and 'NvTensorRTRTXExecutionProvider' "
"in the same session."
)
# Tensorrt and TensorRT RTX can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
if "NvTensorRTRTXExecutionProvider" in available_providers:
if (
providers
and any(
@@ -531,15 +548,15 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
for provider in providers
)
and any(
provider == "TensorrtExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
provider == "NvTensorRTRTXExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider")
for provider in providers
)
):
self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
self._fallback_providers = ["CPUExecutionProvider"]
if "NvTensorRTRTXExecutionProvider" in available_providers:
elif "TensorrtExecutionProvider" in available_providers:
if (
providers
and any(
@@ -548,8 +565,8 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
for provider in providers
)
and any(
provider == "NvTensorRTRTXExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "NvExecutionProvider")
provider == "TensorrtExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
for provider in providers
)
):
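
A standalone mirror of the new provider check, to show what it accepts and rejects (illustrative helper, not part of the onnxruntime API); inside the session, the same condition raises the ValueError shown above during InferenceSession construction.

```python
def both_trt_eps_requested(providers):
    """Mirror of the new validation: True if both TensorRT-based EPs were requested."""
    def name(p):
        # providers may hold plain names or (name, options) tuples, as in the session code.
        return p[0] if isinstance(p, tuple) else p
    names = {name(p) for p in providers}
    return "TensorrtExecutionProvider" in names and "NvTensorRTRTXExecutionProvider" in names

print(both_trt_eps_requested(
    ["TensorrtExecutionProvider", ("NvTensorRTRTXExecutionProvider", {"device_id": 0})]))  # True -> rejected
print(both_trt_eps_requested(["CUDAExecutionProvider", "CPUExecutionProvider"]))           # False -> allowed
```
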
@@ -5,5 +5,5 @@ transformers==4.52.1
torch>=2.7.0
onnx==1.18.0
datasets>=2.8.0
protobuf==4.25.8
protobuf==5.29.6
psutil
68 changes: 68 additions & 0 deletions onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc
@@ -529,6 +529,74 @@ TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) {
kDmlExecutionProvider}); // TODO: Unskip when fixed #41968513
}

TEST(ConvTransposeTest, ConvTranspose_InvalidBiasShape_1) {
ConvTransposeOpAttributes attrs = {
vector<int64_t>{1, 5}, // kernel_shape
{}, // output_padding
vector<int64_t>{2, 1, 1, 14}, // output_shape
vector<int64_t>{0, 0, 0, 0}, // pads
vector<int64_t>{1, 1}, // strides
vector<int64_t>{1, 1}, // dilations
1, // group
"NOTSET" // auto_pad
};
vector<float> X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f};
vector<int64_t> X_shape = {2, 1, 1, 10};
vector<float> W = {1.0f, 2.0f, 3.0f, 2.0f, 1.0f};
vector<int64_t> W_shape = {1, 1, 1, 5};
vector<float> B = {1.0f, 2.0f}; // invalid bias shape, should be {1}
vector<int64_t> B_shape = {2}; // invalid bias shape, should be {1}
vector<int64_t> Y_shape = {2, 1, 1, 14};
vector<float> expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f,
11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f};
TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape,
OpTester::ExpectResult::kExpectFailure,
// Just ensure that it starts with the expected string.
"Bias shape is not compatible with number of output channels. "
"It should be a 1-D tensor with size num_output_channels(M).",
// The EP exclusions are along the same lines as ConvTranspose_InvalidKernelShape which
// also tests for invalid shapes. It also includes XnnPack which seems to have its own
// way of dealing with incorrectly shaped bias.
{kTensorrtExecutionProvider, kQnnExecutionProvider,
kDmlExecutionProvider, kXnnpackExecutionProvider,
kWebGpuExecutionProvider}); // Remove when https://github.com/microsoft/onnxruntime/issues/27210 is fixed
}

TEST(ConvTransposeTest, ConvTranspose_InvalidBiasShape_2) {
ConvTransposeOpAttributes attrs = {
vector<int64_t>{1, 5}, // kernel_shape
{}, // output_padding
vector<int64_t>{2, 1, 1, 14}, // output_shape
vector<int64_t>{0, 0, 0, 0}, // pads
vector<int64_t>{1, 1}, // strides
vector<int64_t>{1, 1}, // dilations
1, // group
"NOTSET" // auto_pad
};
vector<float> X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f};
vector<int64_t> X_shape = {2, 1, 1, 10};
vector<float> W = {1.0f, 2.0f, 3.0f, 2.0f, 1.0f};
vector<int64_t> W_shape = {1, 1, 1, 5};
vector<float> B = {1.0f, 2.0f};
vector<int64_t> B_shape = {1, 2}; // invalid bias rank (it should be 1-D)
vector<int64_t> Y_shape = {2, 1, 1, 14};
vector<float> expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f,
11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f};
TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape,
OpTester::ExpectResult::kExpectFailure,
// Just ensure that it starts with the expected string.
"Bias shape is not compatible with number of output channels. "
"It should be a 1-D tensor with size num_output_channels(M).",
// The EP exclusions are along the same lines as ConvTranspose_InvalidKernelShape which
// also tests for invalid shapes. It also includes XnnPack which seems to have its own
// way of dealing with incorrectly shaped bias.
{kTensorrtExecutionProvider, kQnnExecutionProvider,
kDmlExecutionProvider, kXnnpackExecutionProvider,
kWebGpuExecutionProvider}); // Remove when https://github.com/microsoft/onnxruntime/issues/27210 is fixed
}

TEST(ConvTransposeTest, ConvTranspose_onnx) {
ConvTransposeOpAttributes attrs = {
vector<int64_t>{3, 3}, // kernel_shape
1 change: 1 addition & 0 deletions onnxruntime/test/python/transformers/test_gqa.py
@@ -1775,6 +1775,7 @@ def test_flash_decode_parity(self):
del os.environ["ORT_DISABLE_FLASH_DECODE"]


@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.")
class TestGQARegressions(unittest.TestCase):
"""Specific regression tests for historical bugs."""
