diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index cc5f9959f57a..3179f99006a2 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -551,7 +551,9 @@ def wrapper(param, i, j):
 
                         def accumulate_hp_grads_and_remove_lp(*notneeded):
                             self.accumulate_hp_grads_and_remove_lp(param, i, j)
 
-                        self._grad_acc_hooks.append(register_grad_hook(param, accumulate_hp_grads_and_remove_lp))
+                        hook_handle = register_grad_hook(param, accumulate_hp_grads_and_remove_lp)
+                        if hook_handle is not None:
+                            self._grad_acc_hooks.append(hook_handle)
 
                     wrapper(param, i, j)
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 4f2a19e7431a..53e81782bbb2 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -1289,7 +1289,9 @@ def reduce_partition_and_remove_grads(*notneeded):
                                 non_leaf_params_requiring_grad) + leaf_module_count
                             self.update_hook_state_and_maybe_run_epilogue(current_expected)
 
-                        self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads))
+                        hook_handle = register_grad_hook(param, reduce_partition_and_remove_grads)
+                        if hook_handle is not None:
+                            self._grad_acc_hooks.append(hook_handle)
 
                     if not z3_leaf_parameter(param):
                         wrapper(param)
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 005853ebffc0..d912ed96d8b7 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1051,7 +1051,9 @@ def grad_handling_hook(*notneeded):
 
                             current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
                             self.update_hook_state_and_maybe_run_epilogue(current_expected)
 
-                        self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook))
+                        hook_handle = register_grad_hook(param, grad_handling_hook)
+                        if hook_handle is not None:
+                            self._grad_acc_hooks.append(hook_handle)
 
                     wrapper(param, i)
diff --git a/deepspeed/utils/torch.py b/deepspeed/utils/torch.py
index 8bad8208b91c..38380e90941e 100644
--- a/deepspeed/utils/torch.py
+++ b/deepspeed/utils/torch.py
@@ -3,10 +3,15 @@
 
 # DeepSpeed Team
 
+import logging
+
 from packaging import version as pkg_version
 import torch
 
+logger = logging.getLogger(__name__)
+_legacy_fallback_logged = False
+
 
 def required_torch_version(min_version=None, max_version=None):
     assert min_version or max_version, "Must provide a min_version or max_version argument"
 
@@ -22,13 +27,55 @@ def required_torch_version(min_version=None, max_version=None):
     return True
 
 
+def _log_legacy_grad_hook_fallback_once():
+    global _legacy_fallback_logged
+    if _legacy_fallback_logged:
+        return
+
+    logger.warning(
+        "Falling back to param.register_hook for gradient hook registration "
+        "because no grad accumulator node is available for this parameter."
+    )
+    _legacy_fallback_logged = True
+
+
+def _get_grad_accumulator_for_legacy_hook(param):
+    # On older torch versions we rely on traversing the autograd edge created
+    # by expand_as to reach the parameter's AccumulateGrad node.
+    try:
+        param_tmp = param.expand_as(param)
+    except Exception:
+        return None
+
+    grad_fn = getattr(param_tmp, "grad_fn", None)
+    if grad_fn is None:
+        return None
+
+    next_functions = getattr(grad_fn, "next_functions", None)
+    if not next_functions:
+        return None
+
+    return next_functions[0][0]
+
+
+def _register_param_hook_fallback(param, hook):
+    def wrapper(grad):
+        hook(param)
+        return grad
+
+    return param.register_hook(wrapper)
+
+
 def register_grad_hook(param, hook):
     if required_torch_version(min_version=2.1):
         return param.register_post_accumulate_grad_hook(hook)
-    else:
-        param_tmp = param.expand_as(param)
-        grad_acc = param_tmp.grad_fn.next_functions[0][0]
-        return grad_acc.register_hook(hook)
+
+    grad_acc = _get_grad_accumulator_for_legacy_hook(param)
+    if grad_acc is None:
+        _log_legacy_grad_hook_fallback_once()
+        return _register_param_hook_fallback(param, hook)
+
+    return grad_acc.register_hook(hook)
 
 
 def jit_script_compat(fn):
diff --git a/tests/unit/runtime/test_register_grad_hook.py b/tests/unit/runtime/test_register_grad_hook.py
new file mode 100644
index 000000000000..5b15cfa1b17f
--- /dev/null
+++ b/tests/unit/runtime/test_register_grad_hook.py
@@ -0,0 +1,83 @@
+import types
+
+from deepspeed.utils import torch as ds_torch_utils
+
+
+def test_register_grad_hook_uses_post_accumulate_hook(monkeypatch):
+    monkeypatch.setattr(ds_torch_utils, "required_torch_version", lambda **_kwargs: True)
+
+    recorded = {}
+
+    class DummyParam:
+
+        def register_post_accumulate_grad_hook(self, hook):
+            recorded["hook"] = hook
+            return "post_acc_handle"
+
+    handle = ds_torch_utils.register_grad_hook(DummyParam(), lambda *_args: None)
+
+    assert handle == "post_acc_handle"
+    assert "hook" in recorded
+
+
+def test_register_grad_hook_uses_legacy_grad_accumulator(monkeypatch):
+    monkeypatch.setattr(ds_torch_utils, "required_torch_version", lambda **_kwargs: False)
+
+    recorded = {}
+
+    class DummyGradAccumulator:
+
+        def register_hook(self, hook):
+            recorded["hook"] = hook
+            return "grad_acc_handle"
+
+    grad_acc = DummyGradAccumulator()
+
+    class DummyParam:
+
+        def expand_as(self, _param):
+            return types.SimpleNamespace(
+                grad_fn=types.SimpleNamespace(
+                    next_functions=((grad_acc, None), ),
+                ))
+
+        def register_hook(self, _hook):
+            raise AssertionError("legacy param hook fallback should not be used")
+
+    handle = ds_torch_utils.register_grad_hook(DummyParam(), lambda *_args: None)
+
+    assert handle == "grad_acc_handle"
+    assert "hook" in recorded
+
+
+def test_register_grad_hook_falls_back_when_grad_accumulator_missing(monkeypatch):
+    monkeypatch.setattr(ds_torch_utils, "required_torch_version", lambda **_kwargs: False)
+    monkeypatch.setattr(ds_torch_utils, "_legacy_fallback_logged", False)
+
+    recorded = {}
+    invoked = {}
+
+    class DummyParam:
+
+        def expand_as(self, _param):
+            return types.SimpleNamespace(grad_fn=None)
+
+        def register_hook(self, hook):
+            recorded["hook"] = hook
+            return "param_hook_handle"
+
+    param = DummyParam()
+
+    def on_grad(_param):
+        invoked["called"] = True
+        invoked["param"] = _param
+
+    handle = ds_torch_utils.register_grad_hook(param, on_grad)
+
+    assert handle == "param_hook_handle"
+    assert "hook" in recorded
+
+    grad = object()
+    assert recorded["hook"](grad) is grad
+    assert invoked["called"] is True
+    assert invoked["param"] is param
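Illustrative usage sketch (not taken from the patch; the parameter and printing hook below are placeholder assumptions): on torch < 2.1 a call to register_grad_hook goes through the legacy AccumulateGrad path, or through the new param.register_hook fallback when no accumulator node can be reached, while on torch >= 2.1 it uses register_post_accumulate_grad_hook; the same call site works in all three cases.

    import torch
    from deepspeed.utils.torch import register_grad_hook

    param = torch.nn.Parameter(torch.randn(4))

    # The hook ignores its arguments, mirroring the *notneeded signature used in the patch.
    handle = register_grad_hook(param, lambda *notneeded: print("gradient accumulated"))

    loss = (param * 2).sum()
    loss.backward()  # the hook fires once the gradient has been accumulated into param.grad

    handle.remove()  # each registration path returns a removable handle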