diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index e4fae63ce7cd..933e37b67718 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -63,14 +63,16 @@ class Adam_Optimizer { _betta1_t = std::pow(_betta1, step); _betta2_t = std::pow(_betta2, step); } else { - _step++; - if (_step != step) { - _betta1_t = std::pow(_betta1, step); - _betta2_t = std::pow(_betta2, step); - _step = step; - } else { - _betta1_t *= _betta1; - _betta2_t *= _betta2; + if(step!=_step){ + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } } } } diff --git a/csrc/xpu/includes/cpu_adam.h b/csrc/xpu/includes/cpu_adam.h index 7bc0364c569d..7cc068027fde 100644 --- a/csrc/xpu/includes/cpu_adam.h +++ b/csrc/xpu/includes/cpu_adam.h @@ -69,14 +69,16 @@ class Adam_Optimizer { _betta1_t = std::pow(_betta1, step); _betta2_t = std::pow(_betta2, step); } else { - _step++; - if (_step != step) { - _betta1_t = std::pow(_betta1, step); - _betta2_t = std::pow(_betta2, step); - _step = step; - } else { + if(step!=_step){ + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { _betta1_t *= _betta1; _betta2_t *= _betta2; + } } } }