From fba4117fd37db6c8903545cc7942b09749c6b4ab Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 3 Feb 2026 11:23:47 -0800 Subject: [PATCH 1/7] Fix Conv LHS packing padding/uninitialized ptrs V2 (#27215) ### Description Refer to V1 of the fix here: https://github.com/microsoft/onnxruntime/pull/27214 This PR includes all fixes from the V1 PR + logic to invalidate the lhs cache pointers in case the pad buffer's underlying buffer has changed due to a resize. The ARM team will look at potentially enhancing this logic after the 1.24.0 release. ### Motivation and Context Fix #26669 --- .../mlas/lib/kleidiai/convolve_kleidiai.cpp | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index 94332c9ed34bc..5f9d121232a27 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -395,6 +395,12 @@ static std::shared_ptr LhsPtrFill(const size_t ci, const size_t i auto lhs_ptrs = std::shared_ptr<const void*>(new const void*[lhs_ptrs_k * lhs_ptrs_m], std::default_delete<const void*[]>()); + // Initialize all padding entries. For partial tiles (m < m_step), + // the kai LHS packing kernel may still read pointer entries beyond the logically + // filled 'm' positions. Leaving these uninitialized can cause non-deterministic + // reads and corrupt packed LHS data. 
+ auto lhs_ptrs_ = lhs_ptrs.get(); + std::fill(lhs_ptrs_, lhs_ptrs_ + (lhs_ptrs_k * lhs_ptrs_m), reinterpret_cast<const void*>(&pad_ptr[0])); auto ih_out_size = ComputeConvOutSize(ih, kh, padding, 1); auto iw_out_size = ComputeConvOutSize(iw, kw, padding, 1); @@ -430,7 +436,6 @@ static std::shared_ptr LhsPtrFill(const size_t ci, const size_t i }; size_t m_{0}; - auto lhs_ptrs_ = lhs_ptrs.get(); for (size_t ih_ = 0; ih_ < ih_out_size; ih_ += sh) { for (size_t iw_ = 0; iw_ < iw_out_size; iw_ += sw, ++m_) { size_t k_{0}; @@ -460,7 +465,23 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s // figure out how many blocks needed to correctly fill padding padsize = ((ci + padsize - 1) / padsize) * padsize; } - static std::vector<float>pad_ptr(padsize, 0.f); + + // pad_ptr must be at least 'ci' floats for padding pixels. + // Using a thread_local grow-only buffer to avoid cross-thread interference and ensure sizing is correct. + thread_local std::vector<float> pad_ptr; + const float* old_pad_ptr = pad_ptr.data(); + bool has_pad_ptr_changed = false; + + if (pad_ptr.size() < padsize) { + pad_ptr.resize(padsize, 0.f); + if (pad_ptr.data() != old_pad_ptr) { + has_pad_ptr_changed = true; + } + } else { + // Ensure any previously-used region remains zeroed (grow-only means it should already be zeros, + // but keep this explicit for safety). + std::fill(pad_ptr.begin(), pad_ptr.end(), 0.f); + } LhsCacheKey key = { ci, ih, iw, @@ -481,6 +502,16 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s // Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions. thread_local std::unordered_map> lhs_ptrs_cache; + if (has_pad_ptr_changed) + { + // If the pad buffer was resized and a re-allocation has occurred, the cached lhs ptrs are invalid as they + // would be referencing the old pad buffer. + // See discussion in https://github.com/microsoft/onnxruntime/pull/27214. 
+ // TODO(hasesh / JonathanC-ARM): A better approach would be to include the pad buffer address in the cache key + // or any other approach that would reduce unnecessary cache invalidations. + lhs_ptrs_cache.clear(); + } + std::shared_ptr<const void*> lhs_ptrs; if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) { lhs_ptrs = found->second; From 4abba28369b5f8c0245db7762518150dc2ab255b Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 3 Feb 2026 14:25:11 -0500 Subject: [PATCH 2/7] Fix WebGPU MoE swiglu_limit (default to infinity) (#27221) ### Description According to https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmoe, > swiglu_limit : float The limit used to clamp in SwiGLU. No clamp when limit is not provided. However, currently, the default is set to 0, meaning we clamp to 0 if no limit is provided. ### Motivation and Context Fixes #27220. See there for bug description and reproduction. Hoping to get this in before 1.24.0 releases. 
cc @guschmue --- onnxruntime/contrib_ops/webgpu/moe/moe.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/moe/moe.h b/onnxruntime/contrib_ops/webgpu/moe/moe.h index 5e329dc12b5c9..332aa39a8d23e 100755 --- a/onnxruntime/contrib_ops/webgpu/moe/moe.h +++ b/onnxruntime/contrib_ops/webgpu/moe/moe.h @@ -3,6 +3,8 @@ #pragma once +#include <limits> + #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/webgpu_kernel.h" @@ -31,7 +33,7 @@ class MoE : public WebGpuKernel { activation_alpha_ = static_cast(info.GetAttrOrDefault("activation_alpha", 1.0)); activation_beta_ = static_cast(info.GetAttrOrDefault("activation_beta", 1.0)); swiglu_fusion_ = static_cast(info.GetAttrOrDefault("swiglu_fusion", 0)); - swiglu_limit_ = info.GetAttrOrDefault("swiglu_limit", 0); + swiglu_limit_ = info.GetAttrOrDefault("swiglu_limit", std::numeric_limits<float>::infinity()); k_ = static_cast(info.GetAttrOrDefault("k", 4)); normalize_routing_weights_ = info.GetAttrOrDefault("normalize_routing_weights", 0) == 1; use_sparse_mixer_ = info.GetAttrOrDefault("use_sparse_mixer", 0) == 1; From efabc055bb49d109700b9cccf38f35e1edd0e11c Mon Sep 17 00:00:00 2001 From: umangb-09 Date: Wed, 4 Feb 2026 00:55:38 +0530 Subject: [PATCH 3/7] Fix for #25145 (#26994) ### Description Fixed a fallback provider logic bug where creating an inference session could lead to losing GPU acceleration ### Motivation and Context Fixes the issue reported in [#25145](https://github.com/microsoft/onnxruntime/issues/25145) --- .../onnxruntime_inference_collection.py | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index ac1dd8b5a2ae7..1aa28cfd45873 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -521,8 +521,25 @@ def 
_create_inference_session(self, providers, provider_options, disabled_optimizers=None): available_providers = C.get_available_providers() - # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU. - if "TensorrtExecutionProvider" in available_providers: + # Validate that TensorrtExecutionProvider and NvTensorRTRTXExecutionProvider are not both specified + if providers: + has_tensorrt = any( + provider == "TensorrtExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider") + for provider in providers + ) + has_tensorrt_rtx = any( + provider == "NvTensorRTRTXExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider") + for provider in providers + ) + if has_tensorrt and has_tensorrt_rtx: + raise ValueError( + "Cannot enable both 'TensorrtExecutionProvider' and 'NvTensorRTRTXExecutionProvider' " + "in the same session." + ) + # Tensorrt and TensorRT RTX can fall back to CUDA if it's explicitly assigned. All others fall back to CPU. 
+ if "NvTensorRTRTXExecutionProvider" in available_providers: if ( providers and any( @@ -531,15 +548,15 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi for provider in providers ) and any( - provider == "TensorrtExecutionProvider" - or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider") + provider == "NvTensorRTRTXExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider") for provider in providers ) ): self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] else: self._fallback_providers = ["CPUExecutionProvider"] - if "NvTensorRTRTXExecutionProvider" in available_providers: + elif "TensorrtExecutionProvider" in available_providers: if ( providers and any( @@ -548,8 +565,8 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi for provider in providers ) and any( - provider == "NvTensorRTRTXExecutionProvider" - or (isinstance(provider, tuple) and provider[0] == "NvExecutionProvider") + provider == "TensorrtExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider") for provider in providers ) ): From 25a6fdcaab9901c6822aaec64bae9daf24d4fa68 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 3 Feb 2026 13:17:52 -0800 Subject: [PATCH 4/7] Specify attention-23 kernel and relax assertion in prepare qkv (#27217) This pull request updates the attention kernel selection logic and clarifies support for unidirectional (causal) attention in the CUDA attention implementation. The main changes focus on improving documentation, removing outdated comments, and explicitly setting the kernel type for better maintainability and clarity. Kernel selection and configuration improvements: * Explicitly set the `kernel_type` field to `AttentionKernel_Unfused` in the `AttentionData` structure to clarify which kernel is being used and improve future extensibility. 
Documentation and code clarity: * Added comments to clarify that unidirectional (causal) attention is supported by several attention kernel implementations, and that the TRT fused runner is only used for non-unidirectional cases, as enforced elsewhere. * Removed outdated TODO comments regarding parameter continuation and kernel selection, as these are now handled more explicitly in the code. [[1]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL194) [[2]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL223-R227) --- onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu | 4 +++- onnxruntime/core/providers/cuda/llm/attention.cc | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 7b37a3a4227b6..852f0bcaff5a2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -258,7 +258,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, assert(data.past_value == nullptr); assert(data.present_key == nullptr); assert(data.present_value == nullptr); - assert(!parameters.is_unidirectional); + // Note: is_unidirectional (causal) is supported by flash attention, memory efficient attention, + // cuDNN flash attention, and unfused kernel. TRT fused runner is only used when !is_unidirectional + // (enforced in MultiHeadAttention::ComputeInternal). 
assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data)); if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 99f297bba6444..3b7aebc2d4714 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -191,7 +191,6 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA)."); } - // TODO(titaiwang): Continue on these parameters // Construct AttentionData to pass to QkvToContext typedef typename ToCudaType::MappedType CudaT; onnxruntime::contrib::cuda::AttentionData data; @@ -220,12 +219,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } data.qkv_format = contribop_parameters.qkv_format; - // TODO: Determine which kernel to use (Flash Attention, Memory Efficient Attention, etc.) // For now, set flags to false and let QkvToContext use the unfused path data.use_flash_attention = false; data.use_memory_efficient_attention = false; data.fused_runner = nullptr; data.fused_cross_attention_kernel = nullptr; + data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused; // Allocate workspace for Q, K, V processing and scratch buffer const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); From 260a48c8696aae5e41f05dcf0e7553eece7709cc Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 3 Feb 2026 17:10:48 -0800 Subject: [PATCH 5/7] [CUDA] Run FlashAttention regression test only when FlashAttention is available (#27206) ### Description As title. Checking if FlashAttention exists check includes if torch has CUDA support, the system has the right device to run FlashAttention, etc. 
### Motivation and Context Fix Windows CUDA CI failures --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- onnxruntime/test/python/transformers/test_gqa.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 9cbe2a01698ae..e800c22f92efb 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -1775,6 +1775,7 @@ def test_flash_decode_parity(self): del os.environ["ORT_DISABLE_FLASH_DECODE"] +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") class TestGQARegressions(unittest.TestCase): """Specific regression tests for historical bugs.""" From 685895c5ceb5700141cabf80c1c0d5ea3f4c794b Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 3 Feb 2026 22:52:48 -0800 Subject: [PATCH 6/7] [CPU/CUDA] Add bias input validations for ConvTranspose (#27209) ### Description Takeaway from https://github.com/microsoft/onnxruntime/issues/26144 Resolve https://github.com/microsoft/onnxruntime/issues/26144 ### Motivation and Context Improve kernel input validation for ConvTranspose --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../cpu/nn/conv_transpose_attributes.h | 12 ++++ .../core/providers/cuda/nn/conv_transpose.cc | 12 ++++ .../cpu/nn/conv_transpose_op_test.cc | 68 +++++++++++++++++++ 3 files changed, 92 insertions(+) diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h index 973743d711359..d93a630788b1b 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h +++ 
b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h @@ -99,6 +99,18 @@ struct ConvTransposeAttributes : public ConvAttributes { " group: ", group); } + // Bias shape validation (It should be a 1D tensor with size M) + // See https://github.com/microsoft/onnxruntime/issues/26144 + if (B != nullptr) { + if (B->Shape().NumDimensions() != 1 || B->Shape()[0] != num_output_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Bias shape is not compatible with number of output channels." + " It should be a 1-D tensor with size num_output_channels(M).", + " Bias: ", B->Shape(), + " num_output_channels: ", num_output_channels); + } + } + TensorShapeVector kernel_shape; ORT_RETURN_IF_ERROR(ComputeKernelShape(F_Shape, kernel_shape, is_nhwc)); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 2972ae999adc4..28197c20af052 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -311,6 +311,18 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna " group: ", conv_transpose_attrs_.group); } + // Bias shape validation (It should be a 1D tensor with size M) + // See https://github.com/microsoft/onnxruntime/issues/26144 + if (B != nullptr) { + if (B->Shape().NumDimensions() != 1 || B->Shape()[0] != num_output_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Bias shape is not compatible with number of output channels." 
+ " It should be a 1-D tensor with size num_output_channels(M).", + " Bias: ", B->Shape(), + " num_output_channels: ", num_output_channels); + } + } + TensorShapeVector kernel_shape; ORT_RETURN_IF_ERROR(conv_transpose_attrs_.ComputeKernelShape(w_shape, kernel_shape, w_in_nhwc)); diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 198fa07ae4ed0..53c76b702529f 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -529,6 +529,74 @@ TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) { kDmlExecutionProvider}); // TODO: Unskip when fixed #41968513 } +TEST(ConvTransposeTest, ConvTranspose_InvalidBiasShape_1) { + ConvTransposeOpAttributes attrs = { + vector{1, 5}, // kernel_shape + {}, // output_padding + vector{2, 1, 1, 14}, // output_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + vector{1, 1}, // dilations + 1, // group + "NOTSET" // auto_pad + }; + vector X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f}; + vector X_shape = {2, 1, 1, 10}; + vector W = {1.0f, 2.0f, 3.0f, 2.0f, 1.0f}; + vector W_shape = {1, 1, 1, 5}; + vector B = {1.0f, 2.0f}; // invalid bias shape, should be {1} + vector B_shape = {2}; // invalid bias shape, should be {1} + vector Y_shape = {2, 1, 1, 14}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, + 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; + TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, + OpTester::ExpectResult::kExpectFailure, + // Just ensure that it starts with the expected string. + "Bias shape is not compatible with number of output channels. 
" + "It should be a 1-D tensor with size num_output_channels(M).", + // The EP exclusions are along the same lines as ConvTranspose_InvalidKernelShape which + // also tests for invalid shapes. It also includes XnnPack which seems to have its own + // way of dealing with incorrectly shaped bias. + {kTensorrtExecutionProvider, kQnnExecutionProvider, + kDmlExecutionProvider, kXnnpackExecutionProvider, + kWebGpuExecutionProvider}); // Remove when https://github.com/microsoft/onnxruntime/issues/27210 is fixed +} + +TEST(ConvTransposeTest, ConvTranspose_InvalidBiasShape_2) { + ConvTransposeOpAttributes attrs = { + vector{1, 5}, // kernel_shape + {}, // output_padding + vector{2, 1, 1, 14}, // output_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + vector{1, 1}, // dilations + 1, // group + "NOTSET" // auto_pad + }; + vector X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f}; + vector X_shape = {2, 1, 1, 10}; + vector W = {1.0f, 2.0f, 3.0f, 2.0f, 1.0f}; + vector W_shape = {1, 1, 1, 5}; + vector B = {1.0f, 2.0f}; + vector B_shape = {1, 2}; // invalid bias rank (it should be 1-D) + vector Y_shape = {2, 1, 1, 14}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, + 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; + TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, + OpTester::ExpectResult::kExpectFailure, + // Just ensure that it starts with the expected string. + "Bias shape is not compatible with number of output channels. " + "It should be a 1-D tensor with size num_output_channels(M).", + // The EP exclusions are along the same lines as ConvTranspose_InvalidKernelShape which + // also tests for invalid shapes. 
It also includes XnnPack which seems to have its own + // way of dealing with incorrectly shaped bias. + {kTensorrtExecutionProvider, kQnnExecutionProvider, + kDmlExecutionProvider, kXnnpackExecutionProvider, + kWebGpuExecutionProvider}); // Remove when https://github.com/microsoft/onnxruntime/issues/27210 is fixed +} + TEST(ConvTransposeTest, ConvTranspose_onnx) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape From 4f424ded34c0d07381552d0691b2bbfa9b60639f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 20:32:35 +0000 Subject: [PATCH 7/7] Bump protobuf in /onnxruntime/python/tools/transformers/models/llama Bumps [protobuf](https://github.com/protocolbuffers/protobuf) from 4.25.8 to 5.29.6. - [Release notes](https://github.com/protocolbuffers/protobuf/releases) - [Commits](https://github.com/protocolbuffers/protobuf/commits) --- updated-dependencies: - dependency-name: protobuf dependency-version: 5.29.6 dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- .../python/tools/transformers/models/llama/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index ee11227cd3acc..e62cf7344649e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -5,5 +5,5 @@ transformers==4.52.1 torch>=2.7.0 onnx==1.18.0 datasets>=2.8.0 -protobuf==4.25.8 +protobuf==5.29.6 psutil