
Commit b68029c

Fix builds

1 parent: d37b146

File tree: 5 files changed (+104, −34 lines)

5 files changed

+104
-34
lines changed

onnxruntime/core/providers/cpu/cpu_provider_shared.h

Lines changed: 1 addition & 0 deletions
@@ -290,6 +290,7 @@ struct EinsumTypedComputeProcessor {
   static void operator delete(void* p) { g_host_cpu.EinsumTypedComputeProcessor__operator_delete(reinterpret_cast<EinsumTypedComputeProcessor*>(p)); }
   static std::unique_ptr<EinsumTypedComputeProcessor> Create(OpKernelContext* context, AllocatorPtr allocator,
                                                              concurrency::ThreadPool* tp,
+                                                             const void* mlas_backend_config,
                                                              EinsumComputePreprocessor& einsum_compute_preprocessor,
                                                              void* einsum_cuda_assets);
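For orientation, a call site of the updated factory now has to supply the extra pointer. A minimal sketch of such a call, assuming the usual CPU Einsum kernel members (context, allocator, einsum_compute_preprocessor) are in scope; none of this surrounding code is part of the diff:

    // Hypothetical call site for the new signature; passes nullptr until a
    // real MLAS backend config is threaded through, matching the TODOs below.
    auto processor = EinsumTypedComputeProcessor::Create(
        context, allocator,
        context->GetOperatorThreadPool(),
        nullptr /*mlas_backend_config*/,
        einsum_compute_preprocessor,
        nullptr /*einsum_cuda_assets*/);  // CPU path: no CUDA assets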

orttraining/orttraining/training_ops/cpu/rnn/gru.cc

Lines changed: 3 additions & 1 deletion
@@ -41,7 +41,9 @@ Status GRUTraining<T>::Compute(OpKernelContext* context) const {
                          attributes_.activation_funcs.Entries()[1],
                          attributes_.clip,
                          context->GetOperatorThreadPool(),
-                         true /*training_mode*/);
+                         true /*training_mode*/,
+                         // TODO(hasesh): Pass through mlas backend config when available
+                         nullptr /*mlas_backend_kernel_selector_config*/);
   gru.Compute(gru_inputs.input,
               gru_inputs.sequence_lengths,
               attributes_.num_directions,
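Since the new trailing parameter is evidently not defaulted, every existing caller stops compiling until it is updated, which is presumably why this commit is titled "Fix builds". A self-contained illustration of that API-evolution effect (Gemmish and its arguments are invented for this sketch, not ORT code):

    #include <cstdio>

    // v2 signature: a trailing opaque config pointer with no default argument.
    int Gemmish(int m, int n, int k, const void* backend_config) {
      (void)backend_config;  // opaque placeholder, as at the ORT call sites
      return m * n * k;
    }

    int main() {
      // Gemmish(1, 2, 3);  // the old three-argument call no longer compiles
      std::printf("%d\n", Gemmish(1, 2, 3, nullptr));  // updated call prints 6
      return 0;
    }

Had the parameter been defaulted to nullptr, none of the call-site edits in this commit would have been needed; requiring an explicit argument forces each caller to acknowledge the new config.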

orttraining/orttraining/training_ops/cpu/rnn/gru_grad_compute.cc

Lines changed: 48 additions & 16 deletions
@@ -223,13 +223,17 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
      // ah = Xth * Wh^T + (rt (.) Ht-1h) * Rh^T + Wbh + Rbh
      // dL/drt = (dL/dah * Rh) (.) (Ht-1h) ---------- (5)
      ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, hidden_size_,
-                                      hidden_size_, alpha, grad_ah, Rh, weight_beta, grad_ar, thread_pool_);
+                                      hidden_size_, alpha, grad_ah, Rh, weight_beta, grad_ar, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);
      ElementwiseProduct(grad_ar, Htminus1, grad_ar, hidden_size_);
    } else {
      // ah = Xth * Wh^T + rt (.) (Ht-1h * Rh^T + Rbh) + Wbh
      // dL/drt = dL/dah (.) (Ht-1h * Rh^T + Rbh) ---------- (5)
      ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasTrans, 1, hidden_size_,
-                                      hidden_size_, alpha, Htminus1, Rh, weight_beta, grad_ar, thread_pool_);
+                                      hidden_size_, alpha, Htminus1, Rh, weight_beta, grad_ar, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);
      if (Rbh != nullptr)
        deepcpu::elementwise_sum1(Rbh, grad_ar, hidden_size_);
      ElementwiseProduct(grad_ar, grad_ah, grad_ar, hidden_size_);

@@ -258,22 +262,28 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
      float* grad_Xt = SafeRawPointer<T>(outputs.grad_input.begin() + X_offset,
                                         outputs.grad_input.end(), input_size_);
      ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, input_size_,
-                                      hidden_size_, alpha, grad_az, Wz, input_beta, grad_Xt, thread_pool_);
+                                      hidden_size_, alpha, grad_az, Wz, input_beta, grad_Xt, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);

      // ar = Xtr * Wr^T + Ht-1r * Rr^T + Wbr + Rbr
      // dL/dXtr = dL/dar * Wr ---------- (9)
      // [1, input_size_] = [1, hidden_size_] * [hidden_size_, input_size_]
      // M = 1, N = input_size_, K = hidden_size_
      input_beta = 1.0f;
      ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, input_size_,
-                                      hidden_size_, alpha, grad_ar, Wr, input_beta, grad_Xt, thread_pool_);
+                                      hidden_size_, alpha, grad_ar, Wr, input_beta, grad_Xt, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);

      // ah = Xth * Wh^T + (rt (.) Ht-1h) * Rh^T + Wbh + Rbh
      // dL/dXth = dL/dah * Wh ---------- (10)
      // [1, input_size_] = [1, hidden_size_] * [hidden_size_, input_size_]
      // M = 1, N = input_size_, K = hidden_size_
      ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, input_size_,
-                                      hidden_size_, alpha, grad_ah, Wh, input_beta, grad_Xt, thread_pool_);
+                                      hidden_size_, alpha, grad_ah, Wh, input_beta, grad_Xt, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);
    }

    if (grad_weights_required) {
@@ -287,7 +297,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
      const float* Xt = SafeRawPointer<const T>(inputs.input.begin() + X_offset,
                                                inputs.input.end(), input_size_);
      ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, input_size_,
-                                      1, alpha, grad_az, Xt, weight_beta, grad_Wz_local, thread_pool_);
+                                      1, alpha, grad_az, Xt, weight_beta, grad_Wz_local, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);
      // Note that the weight beta is always 0. So, we must accumulate ourselves.
      deepcpu::elementwise_sum1(grad_Wz_local, grad_Wz, hidden_size_ * input_size_);
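This hunk and the five that follow share one pattern: the GEMM computes a single timestep's outer product with weight_beta fixed at 0, so the result must be added into the running gradient by hand. A minimal sketch of that accumulation step, assuming deepcpu::elementwise_sum1(src, dst, n) performs dst += src as the surrounding code implies:

    #include <cstddef>

    // Adds one timestep's gradient into the running total; needed because the
    // GEMM above overwrites its output (beta == 0) instead of accumulating.
    void AccumulateInto(const float* step_grad, float* total_grad, std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) total_grad[i] += step_grad[i];
    }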

@@ -296,7 +308,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
      // [hidden_size_, input_size_] = [1, hidden_size_]^T * [1, input_size_]
      // M = hidden_size_, N = input_size_, K = 1
      ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, input_size_,
-                                      1, alpha, grad_ar, Xt, weight_beta, grad_Wr_local, thread_pool_);
+                                      1, alpha, grad_ar, Xt, weight_beta, grad_Wr_local, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);
      // Note that the weight beta is always 0. So, we must accumulate ourselves.
      deepcpu::elementwise_sum1(grad_Wr_local, grad_Wr, hidden_size_ * input_size_);

@@ -305,7 +319,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
      // [hidden_size_, input_size_] = [1, hidden_size_]^T * [1, input_size_]
      // M = hidden_size_, N = input_size_, K = 1
      ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, input_size_,
-                                      1, alpha, grad_ah, Xt, weight_beta, grad_Wh_local, thread_pool_);
+                                      1, alpha, grad_ah, Xt, weight_beta, grad_Wh_local, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);
      // Note that the weight beta is always 0. So, we must accumulate ourselves.
      deepcpu::elementwise_sum1(grad_Wh_local, grad_Wh, hidden_size_ * input_size_);
    }

@@ -316,7 +332,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
      // [hidden_size_, hidden_size_] = [1, hidden_size_]^T * [1, hidden_size_]
      // M = hidden_size_, N = hidden_size_, K = 1
      ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, hidden_size_,
-                                      1, alpha, grad_az, Htminus1, weight_beta, grad_Rz_local, thread_pool_);
+                                      1, alpha, grad_az, Htminus1, weight_beta, grad_Rz_local, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);
      // Note that the weight beta is always 0. So, we must accumulate ourselves.
      deepcpu::elementwise_sum1(grad_Rz_local, grad_Rz, hidden_size_ * hidden_size_);

@@ -325,7 +343,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
      // [hidden_size_, hidden_size_] = [1, hidden_size_]^T * [1, hidden_size_]
      // M = hidden_size_, N = hidden_size_, K = 1
      ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, hidden_size_,
-                                      1, alpha, grad_ar, Htminus1, weight_beta, grad_Rr_local, thread_pool_);
+                                      1, alpha, grad_ar, Htminus1, weight_beta, grad_Rr_local, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);
      // Note that the weight beta is always 0. So, we must accumulate ourselves.
      deepcpu::elementwise_sum1(grad_Rr_local, grad_Rr, hidden_size_ * hidden_size_);

@@ -336,7 +356,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
      // M = hidden_size_, N = hidden_size_, K = 1
      ElementwiseProduct(rt, Htminus1, rt_factor, hidden_size_);
      ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, hidden_size_,
-                                      1, alpha, grad_ah, rt_factor, weight_beta, grad_Rh_local, thread_pool_);
+                                      1, alpha, grad_ah, rt_factor, weight_beta, grad_Rh_local, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);
      // Note that the weight beta is always 0. So, we must accumulate ourselves.
      deepcpu::elementwise_sum1(grad_Rh_local, grad_Rh, hidden_size_ * hidden_size_);
    } else {

@@ -347,7 +369,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
      // M = hidden_size_, N = hidden_size_, K = 1
      ElementwiseProduct(grad_ah, rt, rt_factor, hidden_size_);
      ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, hidden_size_,
-                                      1, alpha, rt_factor, Htminus1, weight_beta, grad_Rh_local, thread_pool_);
+                                      1, alpha, rt_factor, Htminus1, weight_beta, grad_Rh_local, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);
      // Note that the weight beta is always 0. So, we must accumulate ourselves.
      deepcpu::elementwise_sum1(grad_Rh_local, grad_Rh, hidden_size_ * hidden_size_);
    }
@@ -402,14 +426,18 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
      // [1, hidden_size_] = [1, hidden_size_] * [hidden_size_, hidden_size_]
      // M = 1, N = hidden_size_, K = hidden_size_
      ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, hidden_size_,
-                                      hidden_size_, alpha, grad_az, Rz, recurrence_input_beta, grad_Ht, thread_pool_);
+                                      hidden_size_, alpha, grad_az, Rz, recurrence_input_beta, grad_Ht, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);

      // ar = Xtr * Wr^T + Ht-1r * Rr^T + Wbr + Rbr
      // dL/dHt-1r = dL/dar * Rr ---------- (26)
      // [1, hidden_size_] = [1, hidden_size_] * [hidden_size_, hidden_size_]
      // M = 1, N = hidden_size_, K = hidden_size_
      ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, hidden_size_,
-                                      hidden_size_, alpha, grad_ar, Rr, recurrence_input_beta, grad_Ht, thread_pool_);
+                                      hidden_size_, alpha, grad_ar, Rr, recurrence_input_beta, grad_Ht, thread_pool_,
+                                      // TODO(hasesh): Pass through mlas backend config when available
+                                      nullptr /*mlas_backend_kernel_selector_config*/);

      if (!linear_before_reset_) {
        // ah = Xth * Wh^T + (rt (.) Ht-1h) * Rh^T + Wbh + Rbh

@@ -421,7 +449,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
        // to store the intermediate result (making sure to clear the results in grad_ar before writing to it).
        recurrence_input_beta = 0.0f;
        ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, hidden_size_,
-                                        hidden_size_, alpha, grad_ah, Rh, recurrence_input_beta, grad_ar, thread_pool_);
+                                        hidden_size_, alpha, grad_ah, Rh, recurrence_input_beta, grad_ar, thread_pool_,
+                                        // TODO(hasesh): Pass through mlas backend config when available
+                                        nullptr /*mlas_backend_kernel_selector_config*/);
        deepcpu::elementwise_product(grad_ar, rt, grad_Ht, hidden_size_);
      } else {
        // ah = Xth * Wh^T + rt (.) (Ht-1h * Rh^T + Rbh) + Wbh

@@ -432,7 +462,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
        recurrence_input_beta = 1.0f;
        ElementwiseProduct(grad_ah, rt, rt_factor, hidden_size_);
        ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, hidden_size_,
-                                        hidden_size_, alpha, rt_factor, Rh, recurrence_input_beta, grad_Ht, thread_pool_);
+                                        hidden_size_, alpha, rt_factor, Rh, recurrence_input_beta, grad_Ht, thread_pool_,
+                                        // TODO(hasesh): Pass through mlas backend config when available
+                                        nullptr /*mlas_backend_kernel_selector_config*/);
      }
    }
  }
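Every edit in this file appends the same two lines to a ::onnxruntime::math::Gemm<float> call, which implies the helper's declaration gained a trailing config pointer. A reconstruction from the call sites (the actual declaration is not in this diff; the dimension and scalar types are assumptions):

    namespace onnxruntime {
    namespace math {
    // Sketch only: reconstructed from the updated call sites above.
    template <typename T>
    void Gemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b,
              ptrdiff_t m, ptrdiff_t n, ptrdiff_t k, float alpha,
              const T* a, const T* b, float beta, T* c,
              concurrency::ThreadPool* thread_pool,
              const void* mlas_backend_kernel_selector_config);
    }  // namespace math
    }  // namespace onnxruntime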

orttraining/orttraining/training_ops/cpu/rnn/lstm.cc

Lines changed: 3 additions & 1 deletion
@@ -45,7 +45,9 @@ Status LSTMTraining<T>::Compute(OpKernelContext* context) const {
                           attributes_.activation_funcs.Entries()[2],
                           attributes_.clip,
                           context->GetOperatorThreadPool(),
-                          true);
+                          true /*training_mode*/,
+                          // TODO(hasesh): Pass through mlas backend config when available
+                          nullptr /*mlas_backend_kernel_selector_config*/);

   lstm.Compute(lstm_inputs.input,
                lstm_inputs.sequence_lengths,
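If MLAS later exposes a real kernel-selector config to these kernels, the TODO sites in this commit would forward it in place of the nullptr placeholders. A hypothetical shape of that follow-up (every name below is invented for illustration and does not exist in this commit):

    // Hypothetical follow-up: obtain an opaque config once per kernel and pass
    // it wherever this commit currently passes nullptr.
    struct MlasBackendKernelSelectorConfig;  // opaque to the RNN training ops

    const MlasBackendKernelSelectorConfig* GetRnnBackendConfig() {
      return nullptr;  // placeholder until a real config source exists
    }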
