@@ -223,13 +223,17 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
       // ah = Xth * Wh^T + (rt (.) Ht-1h) * Rh^T + Wbh + Rbh
       // dL/drt = (dL/dah * Rh) (.) (Ht-1h) ---------- (5)
       ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, hidden_size_,
-                                       hidden_size_, alpha, grad_ah, Rh, weight_beta, grad_ar, thread_pool_);
+                                       hidden_size_, alpha, grad_ah, Rh, weight_beta, grad_ar, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);
       ElementwiseProduct(grad_ar, Htminus1, grad_ar, hidden_size_);
     } else {
       // ah = Xth * Wh^T + rt (.) (Ht-1h * Rh^T + Rbh) + Wbh
       // dL/drt = dL/dah (.) (Ht-1h * Rh^T + Rbh) ---------- (5)
       ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasTrans, 1, hidden_size_,
-                                       hidden_size_, alpha, Htminus1, Rh, weight_beta, grad_ar, thread_pool_);
+                                       hidden_size_, alpha, Htminus1, Rh, weight_beta, grad_ar, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);
       if (Rbh != nullptr)
         deepcpu::elementwise_sum1(Rbh, grad_ar, hidden_size_);
       ElementwiseProduct(grad_ar, grad_ah, grad_ar, hidden_size_);
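
For reference, the two reset-gate branches in the hunk above, transcribed from the
comments into standard notation ((.) in the code is the elementwise product, written
\odot here; the per-gate h suffixes on X_t and h_{t-1} are elided for brevity):

  % linear_before_reset == 0: the reset gate is applied before the recurrence matmul
  a_h = X_t W_h^T + (r_t \odot h_{t-1}) R_h^T + W_{bh} + R_{bh}
    \;\Rightarrow\; \frac{\partial L}{\partial r_t}
      = \left(\frac{\partial L}{\partial a_h} R_h\right) \odot h_{t-1}

  % linear_before_reset == 1: the recurrence matmul happens before the reset gate
  a_h = X_t W_h^T + r_t \odot \left(h_{t-1} R_h^T + R_{bh}\right) + W_{bh}
    \;\Rightarrow\; \frac{\partial L}{\partial r_t}
      = \frac{\partial L}{\partial a_h} \odot \left(h_{t-1} R_h^T + R_{bh}\right)
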
@@ -258,22 +262,28 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
       float* grad_Xt = SafeRawPointer<T>(outputs.grad_input.begin() + X_offset,
                                          outputs.grad_input.end(), input_size_);
       ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, input_size_,
-                                       hidden_size_, alpha, grad_az, Wz, input_beta, grad_Xt, thread_pool_);
+                                       hidden_size_, alpha, grad_az, Wz, input_beta, grad_Xt, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);

       // ar = Xtr * Wr^T + Ht-1r * Rr^T + Wbr + Rbr
       // dL/dXtr = dL/dar * Wr ---------- (9)
       // [1, input_size_] = [1, hidden_size_] * [hidden_size_, input_size_]
       // M = 1, N = input_size_, K = hidden_size_
       input_beta = 1.0f;
       ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, input_size_,
-                                       hidden_size_, alpha, grad_ar, Wr, input_beta, grad_Xt, thread_pool_);
+                                       hidden_size_, alpha, grad_ar, Wr, input_beta, grad_Xt, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);

       // ah = Xth * Wh^T + (rt (.) Ht-1h) * Rh^T + Wbh + Rbh
       // dL/dXth = dL/dah * Wh ---------- (10)
       // [1, input_size_] = [1, hidden_size_] * [hidden_size_, input_size_]
       // M = 1, N = input_size_, K = hidden_size_
       ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, input_size_,
-                                       hidden_size_, alpha, grad_ah, Wh, input_beta, grad_Xt, thread_pool_);
+                                       hidden_size_, alpha, grad_ah, Wh, input_beta, grad_Xt, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);
     }

     if (grad_weights_required) {
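
A minimal single-threaded sketch (a hypothetical helper, not the onnxruntime kernel)
of the BLAS-style contract the three calls above rely on: C = alpha * op(A) * op(B) +
beta * C. input_beta is presumably 0 on the first call (its initialization sits
outside this hunk), so grad_Xt is overwritten with the z-gate term; setting
input_beta = 1.0f then makes the r- and h-gate calls accumulate into the same buffer.

  #include <cstddef>

  // Hypothetical helper: C = alpha * A * B + beta * C for row-major MxK * KxN
  // (no transpose handling). Per BLAS convention, beta == 0 means C's prior
  // contents are ignored rather than multiplied.
  void naive_gemm(std::size_t M, std::size_t N, std::size_t K, float alpha,
                  const float* A, const float* B, float beta, float* C) {
    for (std::size_t m = 0; m < M; ++m) {
      for (std::size_t n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (std::size_t k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
        C[m * N + n] = alpha * acc + (beta == 0.0f ? 0.0f : beta * C[m * N + n]);
      }
    }
  }
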
@@ -287,7 +297,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
       const float* Xt = SafeRawPointer<const T>(inputs.input.begin() + X_offset,
                                                 inputs.input.end(), input_size_);
       ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, input_size_,
-                                       1, alpha, grad_az, Xt, weight_beta, grad_Wz_local, thread_pool_);
+                                       1, alpha, grad_az, Xt, weight_beta, grad_Wz_local, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);
       // Note that the weight beta is always 0. So, we must accumulate ourselves.
       deepcpu::elementwise_sum1(grad_Wz_local, grad_Wz, hidden_size_ * input_size_);

@@ -296,7 +308,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
       // [hidden_size_, input_size_] = [1, hidden_size_]^T * [1, input_size_]
       // M = hidden_size_, N = input_size_, K = 1
       ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, input_size_,
-                                       1, alpha, grad_ar, Xt, weight_beta, grad_Wr_local, thread_pool_);
+                                       1, alpha, grad_ar, Xt, weight_beta, grad_Wr_local, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);
       // Note that the weight beta is always 0. So, we must accumulate ourselves.
       deepcpu::elementwise_sum1(grad_Wr_local, grad_Wr, hidden_size_ * input_size_);

@@ -305,7 +319,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
       // [hidden_size_, input_size_] = [1, hidden_size_]^T * [1, input_size_]
       // M = hidden_size_, N = input_size_, K = 1
       ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, input_size_,
-                                       1, alpha, grad_ah, Xt, weight_beta, grad_Wh_local, thread_pool_);
+                                       1, alpha, grad_ah, Xt, weight_beta, grad_Wh_local, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);
       // Note that the weight beta is always 0. So, we must accumulate ourselves.
       deepcpu::elementwise_sum1(grad_Wh_local, grad_Wh, hidden_size_ * input_size_);
     }
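
Each weight-gradient Gemm above is a rank-1 update: with M = hidden_size_,
N = input_size_, K = 1 and CblasTrans on the first operand, it computes the outer
product grad_a^T * Xt. A hypothetical standalone equivalent of one such call plus the
elementwise_sum1 that follows (needed because weight_beta is always 0, so the Gemm
itself never accumulates):

  // Hypothetical helper: grad_W += grad_a^T * Xt, i.e. a [hidden_size, 1] x
  // [1, input_size] outer product accumulated into the running weight gradient
  // across time steps.
  void outer_product_accumulate(int hidden_size, int input_size,
                                const float* grad_a,  // [1, hidden_size]
                                const float* xt,      // [1, input_size]
                                float* grad_W) {      // [hidden_size, input_size]
    for (int i = 0; i < hidden_size; ++i)
      for (int j = 0; j < input_size; ++j)
        grad_W[i * input_size + j] += grad_a[i] * xt[j];
  }
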
@@ -316,7 +332,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
       // [hidden_size_, hidden_size_] = [1, hidden_size_]^T * [1, hidden_size_]
       // M = hidden_size_, N = hidden_size_, K = 1
       ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, hidden_size_,
-                                       1, alpha, grad_az, Htminus1, weight_beta, grad_Rz_local, thread_pool_);
+                                       1, alpha, grad_az, Htminus1, weight_beta, grad_Rz_local, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);
       // Note that the weight beta is always 0. So, we must accumulate ourselves.
       deepcpu::elementwise_sum1(grad_Rz_local, grad_Rz, hidden_size_ * hidden_size_);

@@ -325,7 +343,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
       // [hidden_size_, hidden_size_] = [1, hidden_size_]^T * [1, hidden_size_]
       // M = hidden_size_, N = hidden_size_, K = 1
       ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, hidden_size_,
-                                       1, alpha, grad_ar, Htminus1, weight_beta, grad_Rr_local, thread_pool_);
+                                       1, alpha, grad_ar, Htminus1, weight_beta, grad_Rr_local, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);
       // Note that the weight beta is always 0. So, we must accumulate ourselves.
       deepcpu::elementwise_sum1(grad_Rr_local, grad_Rr, hidden_size_ * hidden_size_);

@@ -336,7 +356,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
       // M = hidden_size_, N = hidden_size_, K = 1
       ElementwiseProduct(rt, Htminus1, rt_factor, hidden_size_);
       ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, hidden_size_,
-                                       1, alpha, grad_ah, rt_factor, weight_beta, grad_Rh_local, thread_pool_);
+                                       1, alpha, grad_ah, rt_factor, weight_beta, grad_Rh_local, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);
       // Note that the weight beta is always 0. So, we must accumulate ourselves.
       deepcpu::elementwise_sum1(grad_Rh_local, grad_Rh, hidden_size_ * hidden_size_);
     } else {
@@ -347,7 +369,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
       // M = hidden_size_, N = hidden_size_, K = 1
       ElementwiseProduct(grad_ah, rt, rt_factor, hidden_size_);
       ::onnxruntime::math::Gemm<float>(CblasTrans, CblasNoTrans, hidden_size_, hidden_size_,
-                                       1, alpha, rt_factor, Htminus1, weight_beta, grad_Rh_local, thread_pool_);
+                                       1, alpha, rt_factor, Htminus1, weight_beta, grad_Rh_local, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);
       // Note that the weight beta is always 0. So, we must accumulate ourselves.
       deepcpu::elementwise_sum1(grad_Rh_local, grad_Rh, hidden_size_ * hidden_size_);
     }
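
The two Rh branches above differ only in where rt enters the outer product (rt_factor
holds the elementwise product in each case):

  \frac{\partial L}{\partial R_h} \mathrel{+}=
  \begin{cases}
    \left(\dfrac{\partial L}{\partial a_h}\right)^{T} \left(r_t \odot h_{t-1}\right),
      & \text{linear\_before\_reset} = 0 \\[1ex]
    \left(\dfrac{\partial L}{\partial a_h} \odot r_t\right)^{T} h_{t-1},
      & \text{linear\_before\_reset} = 1
  \end{cases}
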
@@ -402,14 +426,18 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
       // [1, hidden_size_] = [1, hidden_size_] * [hidden_size_, hidden_size_]
       // M = 1, N = hidden_size_, K = hidden_size_
       ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, hidden_size_,
-                                       hidden_size_, alpha, grad_az, Rz, recurrence_input_beta, grad_Ht, thread_pool_);
+                                       hidden_size_, alpha, grad_az, Rz, recurrence_input_beta, grad_Ht, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);

       // ar = Xtr * Wr^T + Ht-1r * Rr^T + Wbr + Rbr
       // dL/dHt-1r = dL/dar * Rr ---------- (26)
       // [1, hidden_size_] = [1, hidden_size_] * [hidden_size_, hidden_size_]
       // M = 1, N = hidden_size_, K = hidden_size_
       ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, hidden_size_,
-                                       hidden_size_, alpha, grad_ar, Rr, recurrence_input_beta, grad_Ht, thread_pool_);
+                                       hidden_size_, alpha, grad_ar, Rr, recurrence_input_beta, grad_Ht, thread_pool_,
+                                       // TODO(hasesh): Pass through mlas backend config when available
+                                       nullptr /*mlas_backend_kernel_selector_config*/);

       if (!linear_before_reset_) {
         // ah = Xth * Wh^T + (rt (.) Ht-1h) * Rh^T + Wbh + Rbh
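
The z and r contributions to the recurrence-input gradient in the hunk above, in
standard notation (the z-gate comment with its equation number is cut off by the hunk
boundary, but by the surrounding pattern it mirrors (26)):

  \frac{\partial L}{\partial h_{t-1}} \mathrel{+}=
    \frac{\partial L}{\partial a_z} R_z + \frac{\partial L}{\partial a_r} R_r
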
@@ -421,7 +449,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
         // to store the intermediate result (making sure to clear the results in grad_ar before writing to it).
         recurrence_input_beta = 0.0f;
         ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, hidden_size_,
-                                         hidden_size_, alpha, grad_ah, Rh, recurrence_input_beta, grad_ar, thread_pool_);
+                                         hidden_size_, alpha, grad_ah, Rh, recurrence_input_beta, grad_ar, thread_pool_,
+                                         // TODO(hasesh): Pass through mlas backend config when available
+                                         nullptr /*mlas_backend_kernel_selector_config*/);
         deepcpu::elementwise_product(grad_ar, rt, grad_Ht, hidden_size_);
       } else {
         // ah = Xth * Wh^T + rt (.) (Ht-1h * Rh^T + Rbh) + Wbh
@@ -432,7 +462,9 @@ void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs, GRUGradOutp
         recurrence_input_beta = 1.0f;
         ElementwiseProduct(grad_ah, rt, rt_factor, hidden_size_);
         ::onnxruntime::math::Gemm<float>(CblasNoTrans, CblasNoTrans, 1, hidden_size_,
-                                         hidden_size_, alpha, rt_factor, Rh, recurrence_input_beta, grad_Ht, thread_pool_);
+                                         hidden_size_, alpha, rt_factor, Rh, recurrence_input_beta, grad_Ht, thread_pool_,
+                                         // TODO(hasesh): Pass through mlas backend config when available
+                                         nullptr /*mlas_backend_kernel_selector_config*/);
       }
     }
   }
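
Every call site in this change appends the same trailing argument, which suggests an
extended Gemm overload along these lines. This is a sketch only: the selector-config
parameter's type is an assumption (the TODOs note the mlas backend config is not yet
available), modeled here as an opaque pointer defaulted to nullptr so untouched call
sites would keep compiling.

  #include <cstddef>

  // Mirrors the standard CBLAS transpose enum used at the call sites.
  enum CBLAS_TRANSPOSE { CblasNoTrans = 111, CblasTrans = 112, CblasConjTrans = 113 };

  namespace onnxruntime {
  namespace concurrency { class ThreadPool; }
  namespace math {

  // Hypothetical extended overload targeted by this change; the trailing
  // parameter's real type is not shown in this diff and is assumed opaque here.
  template <typename T>
  void Gemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b,
            std::ptrdiff_t M, std::ptrdiff_t N, std::ptrdiff_t K, float alpha,
            const T* A, const T* B, float beta, T* C,
            concurrency::ThreadPool* thread_pool,
            const void* mlas_backend_kernel_selector_config = nullptr);

  }  // namespace math
  }  // namespace onnxruntime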