Skip to content

Commit e04d660

Browse files
committed
Fix builds
1 parent 9579426 commit e04d660

File tree

5 files changed

+14
-4
lines changed

5 files changed

+14
-4
lines changed

onnxruntime/test/mlas/unittest/test_fgemm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class FgemmPackedContext<double, false> {
8787
data[i].alpha = alpha;
8888
data[i].beta = beta;
8989
}
90-
MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool, nullptr, nullptr);
90+
MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool, nullptr);
9191
}
9292
};
9393
#endif

orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
169169
skip_im2col ? Xdata + group_id * X_offset : col_buffer_data,
170170
1,
171171
dWdata + group_id * W_offset,
172-
tp);
172+
tp, &mlas_backend_kernel_selector_config_);
173173
}
174174
}
175175
if (dB) {
@@ -207,7 +207,7 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
207207
dYdata,
208208
0,
209209
col_buffer_data,
210-
tp);
210+
tp, &mlas_backend_kernel_selector_config_);
211211

212212
if (kernel_rank == 2) {
213213
math::Col2im<T, CPUMathUtil, StorageOrder::NCHW>(

orttraining/orttraining/training_ops/cpu/nn/conv_grad.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#include "core/framework/op_kernel.h"
77
#include "core/providers/cpu/nn/conv_attributes.h"
8+
#include "core/mlas/inc/mlas.h"
9+
#include "core/session/onnxruntime_session_options_config_keys.h"
810

911
namespace onnxruntime {
1012
namespace contrib {
@@ -13,6 +15,8 @@ template <typename T>
1315
class ConvGrad final : public OpKernel {
1416
public:
1517
explicit ConvGrad(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
18+
mlas_backend_kernel_selector_config_.use_kleidiai =
19+
info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasDisableKleidiai) != "1";
1620
}
1721

1822
Status Compute(OpKernelContext* context) const override;
@@ -22,6 +26,7 @@ class ConvGrad final : public OpKernel {
2226

2327
private:
2428
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConvGrad);
29+
MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_;
2530
};
2631

2732
} // namespace contrib

orttraining/orttraining/training_ops/cpu/op_gradients.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
148148
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
149149
math::Gemm<float>(CblasNoTrans, CblasNoTrans, n, d, 1, -1,
150150
scaledata, sum_multiplier_.data(), 1,
151-
dXdata, tp);
151+
dXdata, tp, &mlas_backend_kernel_selector_config_);
152152

153153
math::Mul<float, CPUMathUtil>(gsl::narrow_cast<int>(Y.Shape().Size()), dXdata, Ydata, dXdata, nullptr);
154154
}

orttraining/orttraining/training_ops/cpu/op_gradients.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#pragma once
55

66
#include "core/framework/op_kernel.h"
7+
#include "core/mlas/inc/mlas.h"
8+
#include "core/session/onnxruntime_session_options_config_keys.h"
79
#include <cctype>
810

911
namespace onnxruntime {
@@ -67,6 +69,8 @@ class SoftmaxGrad final : public OpKernel {
6769
opset_ = (node.OpType() == "SoftmaxGrad_13" || node.OpType() == "LogSoftmaxGrad_13") ? 13 : 1;
6870
axis_ = info.GetAttrOrDefault("axis", static_cast<int64_t>(opset_ < 13 ? 1 : -1));
6971
is_logsoftmaxgrad_ = node.OpType() == "LogSoftmaxGrad_13" || node.OpType() == "LogSoftmaxGrad";
72+
mlas_backend_kernel_selector_config_.use_kleidiai =
73+
info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasDisableKleidiai) != "1";
7074
}
7175

7276
Status Compute(OpKernelContext* context) const override;
@@ -76,6 +80,7 @@ class SoftmaxGrad final : public OpKernel {
7680
int64_t axis_;
7781
int opset_; // opset_ of the forward Softmax operator
7882
bool is_logsoftmaxgrad_;
83+
MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_;
7984
};
8085

8186
template <typename T>

0 commit comments

Comments
 (0)