92 changes: 53 additions & 39 deletions modules/util/optimizer/adafactor_extensions.py
@@ -14,7 +14,7 @@


@torch.no_grad()
def step_adafactor_parameter(self, p, group, i):
def step_adafactor_parameter(self, p, group, i, compile: bool):
    if p.grad is None:
        return
    grad = p.grad
@@ -50,55 +50,64 @@ def step_adafactor_parameter(self, p, group, i):
        else:
            state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    state["step"] += 1
    state["RMS"] = self._rms(p_data_fp32)
    lr = self._get_lr(group, state)

    beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
    update = (grad ** 2) + group["eps"][0]
    if factored:
        exp_avg_sq_row = state["exp_avg_sq_row"]
        exp_avg_sq_col = state["exp_avg_sq_col"]
    eps = group["eps"][0]

    exp_avg_sq_row = state.get("exp_avg_sq_row", None)
    exp_avg_sq_col = state.get("exp_avg_sq_col", None)
    exp_avg_sq = state.get("exp_avg_sq", None)
    clip_threshold = group["clip_threshold"]
    exp_avg = state.get("exp_avg", None)
    beta1 = group["beta1"]
    weight_decay = group["weight_decay"]

    @torch.compile(fullgraph=True, disable=not compile)
    def compiled_step():
        p_data_fp32 = p
        if p.dtype in {torch.float16, torch.bfloat16}:
            p_data_fp32 = p_data_fp32.float()
        rms = self._rms(p_data_fp32)
        update = (grad ** 2) + eps

        exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
        exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
        if factored:
            exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
            exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

        # Approximation of exponential moving average of square of gradient
        update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
        update.mul_(grad)
    else:
        exp_avg_sq = state["exp_avg_sq"]
            # Approximation of exponential moving average of square of gradient
            update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
            update.mul_(grad)
        else:
            exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
            update = exp_avg_sq.rsqrt().mul_(grad)

        exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
        update = exp_avg_sq.rsqrt().mul_(grad)
        update.div_((self._rms(update) / clip_threshold).clamp_(min=1.0))
        update.mul_(lr)

    update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
    update.mul_(lr)
        if use_first_moment:
            exp_avg.mul_(beta1).add_(update, alpha=(1 - beta1))
            update = exp_avg

    if use_first_moment:
        exp_avg = state["exp_avg"]
        exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
        update = exp_avg
        if weight_decay != 0:
            p_data_fp32.add_(p_data_fp32, alpha=(-weight_decay * lr))

    if group["weight_decay"] != 0:
        p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
        p_data_fp32.add_(-update)

    p_data_fp32.add_(-update)
        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
            copy_stochastic_(p, p_data_fp32)
        elif p.dtype in {torch.float16, torch.bfloat16}:
            p.copy_(p_data_fp32)
        else:
            assert p_data_fp32 is p

    if p.dtype == torch.bfloat16 and self.stochastic_rounding:
        copy_stochastic_(p, p_data_fp32)
    elif p.dtype in {torch.float16, torch.bfloat16}:
        p.copy_(p_data_fp32)
    else:
        assert p_data_fp32 is p
        return rms
    state["RMS"] = compiled_step()


@torch.no_grad()
def step_adafactor(self, closure=None):
def step_adafactor(self, compile: bool, closure=None):
    """
    Performs a single optimization step

@@ -112,12 +121,17 @@ def step_adafactor(self, closure=None):

    for group in self.param_groups:
        for i, p in enumerate(group["params"]):
            step_adafactor_parameter(self, p, group, i)
            step_adafactor_parameter(self, p, group, i, compile)

    return loss


def patch_adafactor(optimizer: Adafactor, stochastic_rounding: bool):
def patch_adafactor(optimizer: Adafactor, stochastic_rounding: bool, compile: bool=True):
    optimizer.stochastic_rounding = stochastic_rounding
    optimizer.step = step_adafactor.__get__(optimizer, Adafactor)
    optimizer.step_parameter = step_adafactor_parameter.__get__(optimizer, Adafactor)
    #lambdas don't work because of scheduler patching:
    def step(*args, **kwargs):
Contributor:
Should this and the below step_parameter function get decorators like @functools.wraps(step_adafactor) to make these wrappers (slightly more) seamless?
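A minimal sketch of what that suggestion could look like, assuming a module-level functools import (illustrative only, not code from this PR):

import functools

def patch_adafactor(optimizer: Adafactor, stochastic_rounding: bool, compile: bool=True):
    optimizer.stochastic_rounding = stochastic_rounding
    # functools.wraps copies __name__, __qualname__, __doc__ and __module__
    # from the wrapped function and sets __wrapped__, so the patched
    # optimizer.step still introspects as step_adafactor rather than as an
    # anonymous wrapper.
    @functools.wraps(step_adafactor)
    def step(*args, **kwargs):
        step_adafactor(*args, **kwargs, compile=compile)
    optimizer.step = step.__get__(optimizer, Adafactor)

    @functools.wraps(step_adafactor_parameter)
    def step_parameter(*args, **kwargs):
        step_adafactor_parameter(*args, **kwargs, compile=compile)
    optimizer.step_parameter = step_parameter.__get__(optimizer, Adafactor)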

Collaborator (Author):
can you elaborate? what's the problem with the current code?

        step_adafactor(*args, **kwargs, compile=compile)
    optimizer.step = step.__get__(optimizer, Adafactor)
    def step_parameter(*args, **kwargs):
        step_adafactor_parameter(*args, **kwargs, compile=compile)
    optimizer.step_parameter = step_parameter.__get__(optimizer, Adafactor)
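For context, a minimal usage sketch of the patched optimizer (assuming Adafactor is the optimizer class this module already imports, e.g. the Hugging Face implementation, and that model and batch are placeholder names; illustrative only, not part of this diff):

# Illustrative usage only; `model` and `batch` are placeholders for a
# torch.nn.Module that returns a scalar loss and its input batch.
optimizer = Adafactor(model.parameters(), lr=1e-3, relative_step=False, scale_parameter=False)
patch_adafactor(optimizer, stochastic_rounding=True, compile=True)

loss = model(batch)
loss.backward()
optimizer.step()   # dispatches to step_adafactor(..., compile=True), which runs
                   # each parameter update through the compiled_step closure
optimizer.zero_grad()

Passing compile=False instead keeps the eager path, since the inner compiled_step is decorated with disable=not compile.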