Skip to content

Commit 9329b0d

Browse files
committed
Fix builds
1 parent e0affbb commit 9329b0d

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

onnxruntime/contrib_ops/cpu/cdist.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ DEFINE_KERNEL(float);
1919
DEFINE_KERNEL(double);
2020

2121
template <typename T>
22-
static void CalculateSqeuclidean(const Tensor& a, const Tensor& b, Tensor& c, concurrency::ThreadPool* threadpool) {
22+
static void CalculateSqeuclidean(const Tensor& a, const Tensor& b, Tensor& c, concurrency::ThreadPool* threadpool,
23+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) {
2324
// input shapes have already been validated
2425
const auto& shape_a = a.Shape().GetDims(); // {m, k}
2526
const auto& shape_b = b.Shape().GetDims(); // {n, k}
@@ -64,7 +65,8 @@ static void CalculateSqeuclidean(const Tensor& a, const Tensor& b, Tensor& c, co
6465
m, n, k,
6566
static_cast<T>(-2.), a_data, b_data, static_cast<T>(0.),
6667
c_data,
67-
threadpool);
68+
threadpool,
69+
mlas_backend_kernel_selector_config);
6870
#else
6971
// the performance of this isn't great as the eigen matmul is single threaded by default
7072
// if you're on x86 and care about performance try MKL first. if there's a good enough argument for optimizing this
@@ -114,7 +116,7 @@ common::Status CDist<T>::Compute(OpKernelContext* context) const {
114116
Tensor* C = context->Output(0, output_shape);
115117
T* output = C->MutableData<T>();
116118

117-
CalculateSqeuclidean<T>(*A, *B, *C, tp);
119+
CalculateSqeuclidean<T>(*A, *B, *C, tp, &mlas_backend_kernel_selector_config_);
118120
auto map_out = EigenVectorArrayMap<T>(output, narrow<size_t>(output_shape.Size()));
119121

120122
// because we use GEMM in CalculateSqeuclidean there's a slight chance a number extremely close to zero

onnxruntime/contrib_ops/cpu/cdist.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#include "core/common/common.h"
77
#include "core/framework/op_kernel.h"
8+
#include "core/session/onnxruntime_session_options_config_keys.h"
9+
#include "core/mlas/inc/mlas.h"
810

911
namespace onnxruntime {
1012
namespace contrib {
@@ -17,8 +19,13 @@ class CDist final : public OpKernel {
1719
enum class Mode { EUCLIDEAN,
1820
SQEUCLIDEAN } mode_;
1921

22+
23+
MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_;
24+
2025
public:
2126
CDist(const OpKernelInfo& info) : OpKernel(info) {
27+
mlas_backend_kernel_selector_config_.use_kleidiai =
28+
info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasDisableKleidiai) != "1";
2229
std::string metric;
2330
ORT_ENFORCE(info.GetAttr<std::string>("metric", &metric).IsOK());
2431
if (metric.compare("sqeuclidean") == 0)

0 commit comments

Comments
 (0)