19 changes: 11 additions & 8 deletions csrc/includes/cpu_adam.h
@@ -57,22 +57,25 @@ class Adam_Optimizer {
     inline void IncrementStep(size_t step, float beta1, float beta2)
     {
         if (beta1 != _betta1 || beta2 != _betta2) {
-            _step = step;
             _betta1 = beta1;
             _betta2 = beta2;
-            _betta1_t = std::pow(_betta1, step);
-            _betta2_t = std::pow(_betta2, step);
         } else {
-            _step++;
-            if (_step != step) {
-                _betta1_t = std::pow(_betta1, step);
-                _betta2_t = std::pow(_betta2, step);
+            if (step == _step) { return; }
+            if (step == _step + 1) {
                 _step = step;
-            } else {
                 _betta1_t *= _betta1;
                 _betta2_t *= _betta2;
+                return;
             }
         }
+
+        // Recompute moments from the explicit step when:
+        // - this is the first observed step for the native optimizer state,
+        // - betas changed, or
+        // - step progression is non-sequential (resume/rollback/reset or skipped updates).
+        _step = step;
+        _betta1_t = std::pow(_betta1, step);
+        _betta2_t = std::pow(_betta2, step);
     }
     inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
     {
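For reference, here is a minimal Python model of the rewritten `IncrementStep` bookkeeping (the class name, defaults, and driver calls below are illustrative assumptions for this sketch; the authoritative logic is the C++ above). It shows the three paths: a repeated step is a strict no-op, the next sequential step updates the cached beta powers incrementally, and any other progression falls through to a full recompute from the explicit step.

```python
import math


class BetaPowerTracker:
    """Illustrative stand-in for the native optimizer's step/beta bookkeeping.

    Mirrors the control flow of the C++ IncrementStep above; the name, defaults,
    and initial values are assumptions for the sketch, not DeepSpeed API.
    """

    def __init__(self, beta1=0.9, beta2=0.999):
        self.beta1, self.beta2 = beta1, beta2
        self.step = 0
        self.beta1_t = 1.0  # beta1 ** step
        self.beta2_t = 1.0  # beta2 ** step

    def increment_step(self, step, beta1, beta2):
        if beta1 == self.beta1 and beta2 == self.beta2:
            if step == self.step:        # same logical step: leave cached powers untouched
                return
            if step == self.step + 1:    # normal sequential step: cheap incremental update
                self.step = step
                self.beta1_t *= self.beta1
                self.beta2_t *= self.beta2
                return
        # Betas changed or the step progression is non-sequential
        # (resume/rollback/reset or skipped updates): recompute from the explicit step.
        self.beta1, self.beta2 = beta1, beta2
        self.step = step
        self.beta1_t = math.pow(self.beta1, step)
        self.beta2_t = math.pow(self.beta2, step)


t = BetaPowerTracker()
t.increment_step(1, 0.9, 0.999)   # sequential from the initial state
t.increment_step(1, 0.9, 0.999)   # repeated step: no-op, powers stay bit-identical
t.increment_step(2, 0.9, 0.999)   # sequential: beta powers multiplied in place
t.increment_step(7, 0.9, 0.999)   # jump (e.g. after a resume): full recompute
assert t.beta1_t == math.pow(0.9, 7)
```

Keeping the repeated-step case a no-op means two callers that share a logical step observe bit-identical beta powers, which is what the new tests further below check.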
19 changes: 11 additions & 8 deletions csrc/xpu/includes/cpu_adam.h
@@ -63,22 +63,25 @@ class Adam_Optimizer {
     inline void IncrementStep(size_t step, float beta1, float beta2)
     {
         if (beta1 != _betta1 || beta2 != _betta2) {
-            _step = step;
             _betta1 = beta1;
             _betta2 = beta2;
-            _betta1_t = std::pow(_betta1, step);
-            _betta2_t = std::pow(_betta2, step);
         } else {
-            _step++;
-            if (_step != step) {
-                _betta1_t = std::pow(_betta1, step);
-                _betta2_t = std::pow(_betta2, step);
+            if (step == _step) { return; }
+            if (step == _step + 1) {
                 _step = step;
-            } else {
                 _betta1_t *= _betta1;
                 _betta2_t *= _betta2;
+                return;
             }
         }
+
+        // Recompute moments from the explicit step when:
+        // - this is the first observed step for the native optimizer state,
+        // - betas changed, or
+        // - step progression is non-sequential (resume/rollback/reset or skipped updates).
+        _step = step;
+        _betta1_t = std::pow(_betta1, step);
+        _betta2_t = std::pow(_betta2, step);
     }
     inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
     {
56 changes: 56 additions & 0 deletions tests/unit/ops/adam/test_cpu_adam.py
@@ -312,3 +312,59 @@ def test_multiple_subgroups(self):
         optimizer.rollback_subgroup(0)
         assert optimizer.state[0]['step'] == 1, "Subgroup 0 step count should be decremented"
         assert optimizer.state[1]['step'] == 1, "Subgroup 1 step count should be unchanged"
+
+    def test_step_subgroup_same_step_idempotent_across_subgroups(self):
+        """Repeated same-step subgroup updates should remain bit-identical."""
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+        model_size = 128
+        steps = 4
+        base = torch.randn(model_size, device='cpu', dtype=torch.float32)
+        param_a = torch.nn.Parameter(base.clone())
+        param_b = torch.nn.Parameter(base.clone())
+
+        optimizer = DeepSpeedCPUAdam([param_a])
+        for logical_step in range(1, steps + 1):
+            grad = torch.randn(model_size, device='cpu', dtype=torch.float32)
+
+            optimizer.param_groups[0]['params'] = [param_a]
+            param_a.grad = grad.clone()
+            optimizer.step_subgroup(0)
+
+            optimizer.param_groups[0]['params'] = [param_b]
+            param_b.grad = grad.clone()
+            optimizer.step_subgroup(1)
+
+            assert optimizer.state[0]['step'] == logical_step
+            assert optimizer.state[1]['step'] == logical_step
+            assert torch.equal(param_a.data, param_b.data)
+            assert torch.equal(optimizer.state[0]['exp_avg'], optimizer.state[1]['exp_avg'])
+            assert torch.equal(optimizer.state[0]['exp_avg_sq'], optimizer.state[1]['exp_avg_sq'])
+
+    def test_step_same_step_idempotent_across_param_keys(self):
+        """Repeated optimizer.step() with swapped param keys should be deterministic."""
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+        model_size = 128
+        steps = 4
+        base = torch.randn(model_size, device='cpu', dtype=torch.float32)
+        param_a = torch.nn.Parameter(base.clone())
+        param_b = torch.nn.Parameter(base.clone())
+
+        optimizer = DeepSpeedCPUAdam([param_a])
+        for logical_step in range(1, steps + 1):
+            grad = torch.randn(model_size, device='cpu', dtype=torch.float32)
+
+            optimizer.param_groups[0]['params'] = [param_a]
+            param_a.grad = grad.clone()
+            optimizer.step()
+
+            optimizer.param_groups[0]['params'] = [param_b]
+            param_b.grad = grad.clone()
+            optimizer.step()
+
+            assert optimizer.state[param_a]['step'] == logical_step
+            assert optimizer.state[param_b]['step'] == logical_step
+            assert torch.equal(param_a.data, param_b.data)
+            assert torch.equal(optimizer.state[param_a]['exp_avg'], optimizer.state[param_b]['exp_avg'])
+            assert torch.equal(optimizer.state[param_a]['exp_avg_sq'], optimizer.state[param_b]['exp_avg_sq'])
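The two new tests share the `idempotent` keyword in their names and can be run in isolation; a minimal sketch, assuming pytest and a DeepSpeed build with the CPU Adam op are available locally and no other test name contains that keyword:

```python
# Hypothetical local invocation, run from the repository root.
import pytest

exit_code = pytest.main(["-q", "tests/unit/ops/adam/test_cpu_adam.py", "-k", "idempotent"])
```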