
Commit da033e0

jspark1105 authored and facebook-github-bot committed
[Caffe2] use new fbgemm sparse adagrad interface with temp name (pytorch#46089)
Summary: Pull Request resolved: pytorch#46089

Follow-up of D24195799

Test Plan: .

Reviewed By: dskhudia

Differential Revision: D24196753

fbshipit-source-id: 216512822cfb752984bb97bd229af9746e866eaa
1 parent 0ddcc0c commit da033e0
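
For context, the diff below swaps the kernel generator: the old fbgemm entry point baked the weight-decay coefficient into the JIT-generated kernel, while the temporarily named GenerateSparseAdaGradNew only records whether weight decay is enabled and takes the coefficient on every call. A minimal generation-side sketch of the two patterns follows; the include path is an assumption, and only calls visible in the diff are used.

#include <cstdint>
#include "fbgemm/Fbgemm.h" // assumed include; use whichever fbgemm header declares SparseAdaGradSignature

// Sketch only: contrasts the old and new generator calls for the int32 index case.
void GenerateBothKernels(int block_size, float weight_decay) {
  // Old interface: the weight-decay coefficient is fixed at generation time.
  fbgemm::SparseAdaGradSignature<std::int32_t>::Type old_kernel =
      fbgemm::GenerateSparseAdaGrad<std::int32_t>(
          block_size, /*rowwise=*/false, /*prefetch=*/16, weight_decay);

  // New (temp-name) interface: generation only needs to know whether weight
  // decay is used at all; the coefficient becomes a per-call argument.
  fbgemm::SparseAdaGradSignature<std::int32_t>::NewType new_kernel =
      fbgemm::GenerateSparseAdaGradNew<std::int32_t>(
          block_size,
          /*rowwise=*/false,
          /*prefetch=*/16,
          /*use_weight_decay=*/weight_decay != 0.0f);

  (void)old_kernel;
  (void)new_kernel;
}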

File tree

1 file changed: 34 additions, 18 deletions


caffe2/sgd/adagrad_op.h

Lines changed: 34 additions & 18 deletions
@@ -231,12 +231,18 @@ class SparseAdagradOp final : public Operator<CPUContext> {
     if (block_size != last_block_size_) {
       last_block_size_ = block_size;
       if (std::is_same<SIndex, std::int32_t>::value) {
-        kernel_i32_ = fbgemm::GenerateSparseAdaGrad<std::int32_t>(
-            block_size, /*rowwise=*/false, /*prefetch=*/16, weight_decay_);
+        kernel_i32_ = fbgemm::GenerateSparseAdaGradNew<std::int32_t>(
+            block_size,
+            /*rowwise=*/false,
+            /*prefetch=*/16,
+            weight_decay_ != 0.0f);
       } else {
         CAFFE_ENFORCE((std::is_same<SIndex, std::int64_t>::value));
-        kernel_i64_ = fbgemm::GenerateSparseAdaGrad<std::int64_t>(
-            block_size, /*rowwise=*/false, /*prefetch=*/16, weight_decay_);
+        kernel_i64_ = fbgemm::GenerateSparseAdaGradNew<std::int64_t>(
+            block_size,
+            /*rowwise=*/false,
+            /*prefetch=*/16,
+            weight_decay_ != 0.0f);
       }
     }

@@ -250,7 +256,8 @@ class SparseAdagradOp final : public Operator<CPUContext> {
           momentOut,
           reinterpret_cast<const std::int32_t*>(indices),
           epsilon_,
-          lr[0]);
+          lr[0],
+          weight_decay_);
     } else {
       num_rows_processed = kernel_i64_(
           n,
@@ -260,7 +267,8 @@ class SparseAdagradOp final : public Operator<CPUContext> {
           momentOut,
           reinterpret_cast<const std::int64_t*>(indices),
           epsilon_,
-          lr[0]);
+          lr[0],
+          weight_decay_);
     }
     if (num_rows_processed < n) {
       CAFFE_ENFORCE_GE(
@@ -340,10 +348,10 @@ class SparseAdagradOp final : public Operator<CPUContext> {

  protected:
   float epsilon_;
-  float weight_decay_;
+  const float weight_decay_;
 #if defined(USE_FBGEMM) && !defined(__NVCC__)
-  fbgemm::SparseAdaGradSignature<std::int32_t>::Type kernel_i32_;
-  fbgemm::SparseAdaGradSignature<std::int64_t>::Type kernel_i64_;
+  fbgemm::SparseAdaGradSignature<std::int32_t>::NewType kernel_i32_;
+  fbgemm::SparseAdaGradSignature<std::int64_t>::NewType kernel_i64_;
   std::int64_t last_block_size_{-1};
 #endif

@@ -420,12 +428,18 @@ class RowWiseSparseAdagradOp final : public Operator<Context> {
     if (block_size != last_block_size_) {
       last_block_size_ = block_size;
       if (std::is_same<SIndex, std::int32_t>::value) {
-        kernel_i32_ = fbgemm::GenerateSparseAdaGrad<std::int32_t>(
-            block_size, /*rowwise=*/true, /*prefetch=*/16, weight_decay_);
+        kernel_i32_ = fbgemm::GenerateSparseAdaGradNew<std::int32_t>(
+            block_size,
+            /*rowwise=*/true,
+            /*prefetch=*/16,
+            weight_decay_ != 0.0f);
       } else {
         CAFFE_ENFORCE((std::is_same<SIndex, std::int64_t>::value));
-        kernel_i64_ = fbgemm::GenerateSparseAdaGrad<std::int64_t>(
-            block_size, /*rowwise=*/true, /*prefetch=*/16, weight_decay_);
+        kernel_i64_ = fbgemm::GenerateSparseAdaGradNew<std::int64_t>(
+            block_size,
+            /*rowwise=*/true,
+            /*prefetch=*/16,
+            weight_decay_ != 0.0f);
       }
     }

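The RowWiseSparseAdagradOp hunks mirror the changes above; the only functional difference is /*rowwise=*/true, which makes the generated kernel keep a single AdaGrad history value per embedding row instead of one per element. A generation-side sketch under the same assumptions as above (temp-name API, assumed include path):

#include <cstdint>
#include "fbgemm/Fbgemm.h" // assumed include, as above

// Row-wise variant: the caller's moment buffer holds one float per parameter
// row rather than one per element; generation otherwise matches the dense case.
fbgemm::SparseAdaGradSignature<std::int64_t>::NewType MakeRowWiseKernel(
    int block_size, float weight_decay) {
  return fbgemm::GenerateSparseAdaGradNew<std::int64_t>(
      block_size,
      /*rowwise=*/true,
      /*prefetch=*/16,
      /*use_weight_decay=*/weight_decay != 0.0f);
}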
@@ -439,7 +453,8 @@ class RowWiseSparseAdagradOp final : public Operator<Context> {
           moment,
           reinterpret_cast<const std::int32_t*>(indices),
           epsilon_,
-          lr[0]);
+          lr[0],
+          weight_decay_);
     } else {
       num_rows_processed = kernel_i64_(
           n,
@@ -449,7 +464,8 @@ class RowWiseSparseAdagradOp final : public Operator<Context> {
           moment,
           reinterpret_cast<const std::int64_t*>(indices),
           epsilon_,
-          lr[0]);
+          lr[0],
+          weight_decay_);
     }

     if (num_rows_processed < n) {
@@ -527,10 +543,10 @@ class RowWiseSparseAdagradOp final : public Operator<Context> {

  protected:
   float epsilon_;
-  float weight_decay_;
+  const float weight_decay_;
 #if defined(USE_FBGEMM) && !defined(__NVCC__)
-  fbgemm::SparseAdaGradSignature<std::int32_t>::Type kernel_i32_;
-  fbgemm::SparseAdaGradSignature<std::int64_t>::Type kernel_i64_;
+  fbgemm::SparseAdaGradSignature<std::int32_t>::NewType kernel_i32_;
+  fbgemm::SparseAdaGradSignature<std::int64_t>::NewType kernel_i64_;
   std::int64_t last_block_size_{-1};
 #endif

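Taken together, the pattern after this change is: regenerate the kernel only when the block size changes (cached in kernel_i32_ / kernel_i64_), and pass lr and weight_decay on every invocation. The sketch below is illustrative, not the operator code: the include path and the kernel arguments that are not visible in the diff (the total parameter size and the param/grad pointers) are assumptions inferred from the surrounding operator.

#include <cstdint>
#include <vector>
#include "fbgemm/Fbgemm.h" // assumed include, as above

int main() {
  const int block_size = 4;  // embedding width
  const int num_rows = 2;    // rows touched by this sparse update
  const float epsilon = 1e-5f;
  const float lr = 0.01f;
  const float weight_decay = 1e-4f;

  // Toy parameter table with 8 rows, element-wise moments (non-rowwise case).
  std::vector<float> param(8 * block_size, 1.0f);
  std::vector<float> moment(8 * block_size, 0.0f);
  std::vector<float> grad(num_rows * block_size, 0.5f);
  std::vector<std::int32_t> indices = {1, 6};

  // Generated once per block size, as SparseAdagradOp caches in kernel_i32_.
  auto kernel = fbgemm::GenerateSparseAdaGradNew<std::int32_t>(
      block_size,
      /*rowwise=*/false,
      /*prefetch=*/16,
      /*use_weight_decay=*/weight_decay != 0.0f);

  // The weight-decay coefficient is now a call-time argument, matching the
  // `lr[0], weight_decay_` tail visible in the diff. Argument positions before
  // the moment pointer are assumed.
  const int rows_processed = kernel(
      num_rows,
      static_cast<std::int64_t>(param.size()), // assumed: total parameter size
      param.data(),                            // assumed: parameter pointer
      grad.data(),                             // assumed: gradient pointer
      moment.data(),
      indices.data(),
      epsilon,
      lr,
      weight_decay);

  return rows_processed == num_rows ? 0 : 1;
}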