@@ -19,7 +19,8 @@ DEFINE_KERNEL(float);
1919DEFINE_KERNEL (double );
2020
2121template <typename T>
22- static void CalculateSqeuclidean (const Tensor& a, const Tensor& b, Tensor& c, concurrency::ThreadPool* threadpool) {
22+ static void CalculateSqeuclidean (const Tensor& a, const Tensor& b, Tensor& c, concurrency::ThreadPool* threadpool,
23+ const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) {
2324 // input shapes have already been validated
2425 const auto & shape_a = a.Shape ().GetDims (); // {m, k}
2526 const auto & shape_b = b.Shape ().GetDims (); // {n, k}
@@ -64,7 +65,8 @@ static void CalculateSqeuclidean(const Tensor& a, const Tensor& b, Tensor& c, co
6465 m, n, k,
6566 static_cast <T>(-2 .), a_data, b_data, static_cast <T>(0 .),
6667 c_data,
67- threadpool);
68+ threadpool,
69+ mlas_backend_kernel_selector_config);
6870#else
6971 // the performance of this isn't great as the eigen matmul is single threaded by default
7072 // if you're on x86 and care about performance try MKL first. if there's a good enough argument for optimizing this
@@ -114,7 +116,7 @@ common::Status CDist<T>::Compute(OpKernelContext* context) const {
114116 Tensor* C = context->Output (0 , output_shape);
115117 T* output = C->MutableData <T>();
116118
117- CalculateSqeuclidean<T>(*A, *B, *C, tp);
119+ CalculateSqeuclidean<T>(*A, *B, *C, tp, &mlas_backend_kernel_selector_config_ );
118120 auto map_out = EigenVectorArrayMap<T>(output, narrow<size_t >(output_shape.Size ()));
119121
120122 // because we use GEMM in CalculateSqeuclidean there's a slight chance a number extremely close to zero
0 commit comments