@@ -231,12 +231,18 @@ class SparseAdagradOp final : public Operator<CPUContext> {
231231 if (block_size != last_block_size_) {
232232 last_block_size_ = block_size;
233233 if (std::is_same<SIndex, std::int32_t >::value) {
234- kernel_i32_ = fbgemm::GenerateSparseAdaGrad<std::int32_t >(
235- block_size, /* rowwise=*/ false , /* prefetch=*/ 16 , weight_decay_);
234+ kernel_i32_ = fbgemm::GenerateSparseAdaGradNew<std::int32_t >(
235+ block_size,
236+ /* rowwise=*/ false ,
237+ /* prefetch=*/ 16 ,
238+ weight_decay_ != 0 .0f );
236239 } else {
237240 CAFFE_ENFORCE ((std::is_same<SIndex, std::int64_t >::value));
238- kernel_i64_ = fbgemm::GenerateSparseAdaGrad<std::int64_t >(
239- block_size, /* rowwise=*/ false , /* prefetch=*/ 16 , weight_decay_);
241+ kernel_i64_ = fbgemm::GenerateSparseAdaGradNew<std::int64_t >(
242+ block_size,
243+ /* rowwise=*/ false ,
244+ /* prefetch=*/ 16 ,
245+ weight_decay_ != 0 .0f );
240246 }
241247 }
242248
@@ -250,7 +256,8 @@ class SparseAdagradOp final : public Operator<CPUContext> {
250256 momentOut,
251257 reinterpret_cast <const std::int32_t *>(indices),
252258 epsilon_,
253- lr[0 ]);
259+ lr[0 ],
260+ weight_decay_);
254261 } else {
255262 num_rows_processed = kernel_i64_ (
256263 n,
@@ -260,7 +267,8 @@ class SparseAdagradOp final : public Operator<CPUContext> {
260267 momentOut,
261268 reinterpret_cast <const std::int64_t *>(indices),
262269 epsilon_,
263- lr[0 ]);
270+ lr[0 ],
271+ weight_decay_);
264272 }
265273 if (num_rows_processed < n) {
266274 CAFFE_ENFORCE_GE (
@@ -340,10 +348,10 @@ class SparseAdagradOp final : public Operator<CPUContext> {
340348
341349 protected:
342350 float epsilon_;
343- float weight_decay_;
351+ const float weight_decay_;
344352#if defined(USE_FBGEMM) && !defined(__NVCC__)
345- fbgemm::SparseAdaGradSignature<std::int32_t >::Type kernel_i32_;
346- fbgemm::SparseAdaGradSignature<std::int64_t >::Type kernel_i64_;
353+ fbgemm::SparseAdaGradSignature<std::int32_t >::NewType kernel_i32_;
354+ fbgemm::SparseAdaGradSignature<std::int64_t >::NewType kernel_i64_;
347355 std::int64_t last_block_size_{-1 };
348356#endif
349357
@@ -420,12 +428,18 @@ class RowWiseSparseAdagradOp final : public Operator<Context> {
420428 if (block_size != last_block_size_) {
421429 last_block_size_ = block_size;
422430 if (std::is_same<SIndex, std::int32_t >::value) {
423- kernel_i32_ = fbgemm::GenerateSparseAdaGrad<std::int32_t >(
424- block_size, /* rowwise=*/ true , /* prefetch=*/ 16 , weight_decay_);
431+ kernel_i32_ = fbgemm::GenerateSparseAdaGradNew<std::int32_t >(
432+ block_size,
433+ /* rowwise=*/ true ,
434+ /* prefetch=*/ 16 ,
435+ weight_decay_ != 0 .0f );
425436 } else {
426437 CAFFE_ENFORCE ((std::is_same<SIndex, std::int64_t >::value));
427- kernel_i64_ = fbgemm::GenerateSparseAdaGrad<std::int64_t >(
428- block_size, /* rowwise=*/ true , /* prefetch=*/ 16 , weight_decay_);
438+ kernel_i64_ = fbgemm::GenerateSparseAdaGradNew<std::int64_t >(
439+ block_size,
440+ /* rowwise=*/ true ,
441+ /* prefetch=*/ 16 ,
442+ weight_decay_ != 0 .0f );
429443 }
430444 }
431445
@@ -439,7 +453,8 @@ class RowWiseSparseAdagradOp final : public Operator<Context> {
439453 moment,
440454 reinterpret_cast <const std::int32_t *>(indices),
441455 epsilon_,
442- lr[0 ]);
456+ lr[0 ],
457+ weight_decay_);
443458 } else {
444459 num_rows_processed = kernel_i64_ (
445460 n,
@@ -449,7 +464,8 @@ class RowWiseSparseAdagradOp final : public Operator<Context> {
449464 moment,
450465 reinterpret_cast <const std::int64_t *>(indices),
451466 epsilon_,
452- lr[0 ]);
467+ lr[0 ],
468+ weight_decay_);
453469 }
454470
455471 if (num_rows_processed < n) {
@@ -527,10 +543,10 @@ class RowWiseSparseAdagradOp final : public Operator<Context> {
527543
528544 protected:
529545 float epsilon_;
530- float weight_decay_;
546+ const float weight_decay_;
531547#if defined(USE_FBGEMM) && !defined(__NVCC__)
532- fbgemm::SparseAdaGradSignature<std::int32_t >::Type kernel_i32_;
533- fbgemm::SparseAdaGradSignature<std::int64_t >::Type kernel_i64_;
548+ fbgemm::SparseAdaGradSignature<std::int32_t >::NewType kernel_i32_;
549+ fbgemm::SparseAdaGradSignature<std::int64_t >::NewType kernel_i64_;
534550 std::int64_t last_block_size_{-1 };
535551#endif
536552
0 commit comments