diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h
index e4fae63ce7cd..9163c8851372 100644
--- a/csrc/includes/cpu_adam.h
+++ b/csrc/includes/cpu_adam.h
@@ -57,22 +57,25 @@ class Adam_Optimizer {
     inline void IncrementStep(size_t step, float beta1, float beta2)
     {
         if (beta1 != _betta1 || beta2 != _betta2) {
-            _step = step;
             _betta1 = beta1;
             _betta2 = beta2;
-            _betta1_t = std::pow(_betta1, step);
-            _betta2_t = std::pow(_betta2, step);
         } else {
-            _step++;
-            if (_step != step) {
-                _betta1_t = std::pow(_betta1, step);
-                _betta2_t = std::pow(_betta2, step);
+            if (step == _step) { return; }
+            if (step == _step + 1) {
                 _step = step;
-            } else {
                 _betta1_t *= _betta1;
                 _betta2_t *= _betta2;
+                return;
             }
         }
+
+        // Recompute moments from the explicit step when:
+        // - this is the first observed step for the native optimizer state,
+        // - betas changed, or
+        // - step progression is non-sequential (resume/rollback/reset or skipped updates).
+        _step = step;
+        _betta1_t = std::pow(_betta1, step);
+        _betta2_t = std::pow(_betta2, step);
     }
     inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
     {
diff --git a/csrc/xpu/includes/cpu_adam.h b/csrc/xpu/includes/cpu_adam.h
index 7bc0364c569d..9ff124f41747 100644
--- a/csrc/xpu/includes/cpu_adam.h
+++ b/csrc/xpu/includes/cpu_adam.h
@@ -63,22 +63,25 @@ class Adam_Optimizer {
     inline void IncrementStep(size_t step, float beta1, float beta2)
     {
         if (beta1 != _betta1 || beta2 != _betta2) {
-            _step = step;
             _betta1 = beta1;
             _betta2 = beta2;
-            _betta1_t = std::pow(_betta1, step);
-            _betta2_t = std::pow(_betta2, step);
         } else {
-            _step++;
-            if (_step != step) {
-                _betta1_t = std::pow(_betta1, step);
-                _betta2_t = std::pow(_betta2, step);
+            if (step == _step) { return; }
+            if (step == _step + 1) {
                 _step = step;
-            } else {
                 _betta1_t *= _betta1;
                 _betta2_t *= _betta2;
+                return;
             }
         }
+
+        // Recompute moments from the explicit step when:
+        // - this is the first observed step for the native optimizer state,
+        // - betas changed, or
+        // - step progression is non-sequential (resume/rollback/reset or skipped updates).
+        _step = step;
+        _betta1_t = std::pow(_betta1, step);
+        _betta2_t = std::pow(_betta2, step);
     }
     inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
     {
diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py
index d83b1732e700..6032518ccd3a 100644
--- a/tests/unit/ops/adam/test_cpu_adam.py
+++ b/tests/unit/ops/adam/test_cpu_adam.py
@@ -312,3 +312,59 @@ def test_multiple_subgroups(self):
         optimizer.rollback_subgroup(0)
         assert optimizer.state[0]['step'] == 1, "Subgroup 0 step count should be decremented"
         assert optimizer.state[1]['step'] == 1, "Subgroup 1 step count should be unchanged"
+
+    def test_step_subgroup_same_step_idempotent_across_subgroups(self):
+        """Repeated same-step subgroup updates should remain bit-identical."""
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+        model_size = 128
+        steps = 4
+        base = torch.randn(model_size, device='cpu', dtype=torch.float32)
+        param_a = torch.nn.Parameter(base.clone())
+        param_b = torch.nn.Parameter(base.clone())
+
+        optimizer = DeepSpeedCPUAdam([param_a])
+        for logical_step in range(1, steps + 1):
+            grad = torch.randn(model_size, device='cpu', dtype=torch.float32)
+
+            optimizer.param_groups[0]['params'] = [param_a]
+            param_a.grad = grad.clone()
+            optimizer.step_subgroup(0)
+
+            optimizer.param_groups[0]['params'] = [param_b]
+            param_b.grad = grad.clone()
+            optimizer.step_subgroup(1)
+
+            assert optimizer.state[0]['step'] == logical_step
+            assert optimizer.state[1]['step'] == logical_step
+            assert torch.equal(param_a.data, param_b.data)
+            assert torch.equal(optimizer.state[0]['exp_avg'], optimizer.state[1]['exp_avg'])
+            assert torch.equal(optimizer.state[0]['exp_avg_sq'], optimizer.state[1]['exp_avg_sq'])
+
+    def test_step_same_step_idempotent_across_param_keys(self):
+        """Repeated optimizer.step() with swapped param keys should be deterministic."""
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+        model_size = 128
+        steps = 4
+        base = torch.randn(model_size, device='cpu', dtype=torch.float32)
+        param_a = torch.nn.Parameter(base.clone())
+        param_b = torch.nn.Parameter(base.clone())
+
+        optimizer = DeepSpeedCPUAdam([param_a])
+        for logical_step in range(1, steps + 1):
+            grad = torch.randn(model_size, device='cpu', dtype=torch.float32)
+
+            optimizer.param_groups[0]['params'] = [param_a]
+            param_a.grad = grad.clone()
+            optimizer.step()
+
+            optimizer.param_groups[0]['params'] = [param_b]
+            param_b.grad = grad.clone()
+            optimizer.step()
+
+            assert optimizer.state[param_a]['step'] == logical_step
+            assert optimizer.state[param_b]['step'] == logical_step
+            assert torch.equal(param_a.data, param_b.data)
+            assert torch.equal(optimizer.state[param_a]['exp_avg'], optimizer.state[param_b]['exp_avg'])
+            assert torch.equal(optimizer.state[param_a]['exp_avg_sq'], optimizer.state[param_b]['exp_avg_sq'])
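
The sketch below is not part of the patch; it is a minimal standalone illustration (a hypothetical `StepTracker` struct, not the DeepSpeed `Adam_Optimizer` class) of the new `IncrementStep` control flow shared by both headers. It checks that the cheap sequential path, which advances `_betta1_t`/`_betta2_t` by one multiplication, stays consistent with the `std::pow` recompute path taken for a changed beta, a skipped step, or a rollback.

```cpp
// Standalone sketch of the new IncrementStep flow (hypothetical StepTracker).
// Compile with, e.g.: g++ -std=c++17 sketch.cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>

struct StepTracker {
    size_t _step = 0;
    float _betta1 = 0.9f, _betta2 = 0.999f;
    float _betta1_t = 1.0f, _betta2_t = 1.0f;

    void IncrementStep(size_t step, float beta1, float beta2)
    {
        if (beta1 != _betta1 || beta2 != _betta2) {
            _betta1 = beta1;
            _betta2 = beta2;
        } else {
            if (step == _step) { return; }  // same logical step: state already up to date
            if (step == _step + 1) {        // sequential step: incremental update, no pow()
                _step = step;
                _betta1_t *= _betta1;
                _betta2_t *= _betta2;
                return;
            }
        }
        // First step, changed betas, or non-sequential progression: recompute exactly.
        _step = step;
        _betta1_t = std::pow(_betta1, step);
        _betta2_t = std::pow(_betta2, step);
    }
};

int main()
{
    StepTracker sequential, jumped;
    for (size_t s = 1; s <= 4; ++s) sequential.IncrementStep(s, 0.9f, 0.999f);
    jumped.IncrementStep(4, 0.9f, 0.999f);      // non-sequential: takes the recompute path
    sequential.IncrementStep(4, 0.9f, 0.999f);  // repeated step: no-op

    assert(sequential._step == 4 && jumped._step == 4);
    assert(std::fabs(sequential._betta1_t - jumped._betta1_t) < 1e-6f);
    assert(std::fabs(sequential._betta2_t - jumped._betta2_t) < 1e-6f);
    std::printf("betta1_t=%f betta2_t=%f\n", sequential._betta1_t, sequential._betta2_t);
    return 0;
}
```

This is the property the new tests exercise through `DeepSpeedCPUAdam`: repeating the same logical step (or replaying it after a rollback) no longer advances the bias-correction terms, so parameters and moments stay identical across subgroups or parameter keys, while the common sequential case still avoids a `pow` call per update.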