4 changes: 3 additions & 1 deletion deepspeed/runtime/bf16_optimizer.py
@@ -551,7 +551,9 @@ def wrapper(param, i, j):
def accumulate_hp_grads_and_remove_lp(*notneeded):
self.accumulate_hp_grads_and_remove_lp(param, i, j)

-self._grad_acc_hooks.append(register_grad_hook(param, accumulate_hp_grads_and_remove_lp))
+hook_handle = register_grad_hook(param, accumulate_hp_grads_and_remove_lp)
+if hook_handle is not None:
+    self._grad_acc_hooks.append(hook_handle)

wrapper(param, i, j)

4 changes: 3 additions & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -1289,7 +1289,9 @@ def reduce_partition_and_remove_grads(*notneeded):
non_leaf_params_requiring_grad) + leaf_module_count
self.update_hook_state_and_maybe_run_epilogue(current_expected)

-self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads))
+hook_handle = register_grad_hook(param, reduce_partition_and_remove_grads)
+if hook_handle is not None:
+    self._grad_acc_hooks.append(hook_handle)

if not z3_leaf_parameter(param):
wrapper(param)
4 changes: 3 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
@@ -1051,7 +1051,9 @@ def grad_handling_hook(*notneeded):
current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
self.update_hook_state_and_maybe_run_epilogue(current_expected)

-self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook))
+hook_handle = register_grad_hook(param, grad_handling_hook)
+if hook_handle is not None:
+    self._grad_acc_hooks.append(hook_handle)

wrapper(param, i)

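The same guarded-append pattern is applied at all three call sites above. A minimal sketch of why the check matters, not DeepSpeed code: `hooks`, `on_grad`, and `remove_all_hooks` are illustrative names, and the cleanup step is an assumption rather than something shown in this diff. Only real handles end up in the bookkeeping list, so a later pass that calls `.remove()` on every stored entry cannot hit a missing handle.

import torch

from deepspeed.utils.torch import register_grad_hook

hooks = []

def on_grad(*_notneeded):
    # Placeholder for the per-parameter gradient work done by the optimizers above.
    pass

params = [torch.nn.Parameter(torch.randn(2, 2)) for _ in range(3)]
for p in params:
    handle = register_grad_hook(p, on_grad)
    if handle is not None:  # mirror the guarded append introduced in this PR
        hooks.append(handle)

def remove_all_hooks():
    # Safe because every stored entry is a handle exposing .remove().
    for h in hooks:
        h.remove()
    hooks.clear()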
55 changes: 51 additions & 4 deletions deepspeed/utils/torch.py
@@ -3,10 +3,15 @@

# DeepSpeed Team

import logging

from packaging import version as pkg_version

import torch

logger = logging.getLogger(__name__)
_legacy_fallback_logged = False


def required_torch_version(min_version=None, max_version=None):
assert min_version or max_version, "Must provide a min_version or max_version argument"
@@ -22,13 +27,55 @@ def required_torch_version(min_version=None, max_version=None):
return True


def _log_legacy_grad_hook_fallback_once():
global _legacy_fallback_logged
if _legacy_fallback_logged:
return

logger.warning(
"Falling back to param.register_hook for gradient hook registration "
"because no grad accumulator node is available for this parameter."
)
_legacy_fallback_logged = True


def _get_grad_accumulator_for_legacy_hook(param):
# On older torch versions we rely on traversing the autograd edge created
# by expand_as to reach the parameter's AccumulateGrad node.
try:
param_tmp = param.expand_as(param)
except Exception:
return None

grad_fn = getattr(param_tmp, "grad_fn", None)
if grad_fn is None:
return None

next_functions = getattr(grad_fn, "next_functions", None)
if not next_functions:
return None

return next_functions[0][0]


def _register_param_hook_fallback(param, hook):
def wrapper(grad):
hook(param)
return grad

return param.register_hook(wrapper)


def register_grad_hook(param, hook):
if required_torch_version(min_version=2.1):
return param.register_post_accumulate_grad_hook(hook)
-else:
-    param_tmp = param.expand_as(param)
-    grad_acc = param_tmp.grad_fn.next_functions[0][0]
-    return grad_acc.register_hook(hook)
+
+grad_acc = _get_grad_accumulator_for_legacy_hook(param)
+if grad_acc is None:
+    _log_legacy_grad_hook_fallback_once()
+    return _register_param_hook_fallback(param, hook)
+
+return grad_acc.register_hook(hook)


def jit_script_compat(fn):
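For reference, a minimal usage sketch of the updated helper, assuming `torch` and this branch of DeepSpeed are installed; the tensor shape and loss are arbitrary. The callback takes `*args` because the legacy AccumulateGrad path passes gradient arguments while the other paths pass the parameter itself, matching the `*notneeded` signatures at the call sites.

import torch

from deepspeed.utils.torch import register_grad_hook

param = torch.nn.Parameter(torch.randn(4))
fired = []

def on_grad(*_notneeded):
    # Runs once the gradient for `param` has been produced during backward.
    fired.append(True)

handle = register_grad_hook(param, on_grad)

(param * 2).sum().backward()
assert len(fired) == 1

if handle is not None:  # all current paths return a handle; stay defensive anyway
    handle.remove()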
83 changes: 83 additions & 0 deletions tests/unit/runtime/test_register_grad_hook.py
@@ -0,0 +1,83 @@
import types

from deepspeed.utils import torch as ds_torch_utils


def test_register_grad_hook_uses_post_accumulate_hook(monkeypatch):
monkeypatch.setattr(ds_torch_utils, "required_torch_version", lambda **_kwargs: True)

recorded = {}

class DummyParam:

def register_post_accumulate_grad_hook(self, hook):
recorded["hook"] = hook
return "post_acc_handle"

handle = ds_torch_utils.register_grad_hook(DummyParam(), lambda *_args: None)

assert handle == "post_acc_handle"
assert "hook" in recorded


def test_register_grad_hook_uses_legacy_grad_accumulator(monkeypatch):
monkeypatch.setattr(ds_torch_utils, "required_torch_version", lambda **_kwargs: False)

recorded = {}

class DummyGradAccumulator:

def register_hook(self, hook):
recorded["hook"] = hook
return "grad_acc_handle"

grad_acc = DummyGradAccumulator()

class DummyParam:

def expand_as(self, _param):
return types.SimpleNamespace(
grad_fn=types.SimpleNamespace(
next_functions=((grad_acc, None), ),
))

def register_hook(self, _hook):
raise AssertionError("legacy param hook fallback should not be used")

handle = ds_torch_utils.register_grad_hook(DummyParam(), lambda *_args: None)

assert handle == "grad_acc_handle"
assert "hook" in recorded


def test_register_grad_hook_falls_back_when_grad_accumulator_missing(monkeypatch):
monkeypatch.setattr(ds_torch_utils, "required_torch_version", lambda **_kwargs: False)
monkeypatch.setattr(ds_torch_utils, "_legacy_fallback_logged", False)

recorded = {}
invoked = {}

class DummyParam:

def expand_as(self, _param):
return types.SimpleNamespace(grad_fn=None)

def register_hook(self, hook):
recorded["hook"] = hook
return "param_hook_handle"

param = DummyParam()

def on_grad(_param):
invoked["called"] = True
invoked["param"] = _param

handle = ds_torch_utils.register_grad_hook(param, on_grad)

assert handle == "param_hook_handle"
assert "hook" in recorded

grad = object()
assert recorded["hook"](grad) is grad
assert invoked["called"] is True
assert invoked["param"] is param