[PyTorch] Fix backward compatibility with checkpoint API (#740)
* Fix backward compatibility with checkpoint API

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* review comments and fix lint

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman committed Apr 3, 2024
1 parent 297459b commit 35a8754
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions transformer_engine/pytorch/distributed.py
@@ -516,13 +516,40 @@ def checkpoint(
kwargs : dict
dictionary of string keys for keyword arguments to :attr:`function`.
"""
only_tensor_args = True
for arg in args:
if not isinstance(arg, torch.Tensor):
only_tensor_args = False
break

# Pop out te.distributed.checkpoint() arguments
global _USE_REENTRANT_ACTIVATION_RECOMPUTE
_USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
tp_group = kwargs.pop("tp_group", None)
get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)

# Ensure backward compatibility.
if not only_tensor_args:
warnings.warn(
"Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
"future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
"`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
DeprecationWarning, stacklevel=2,
)
assert len(args) > 3, "Incorrect number of arguments for deprecated `checkpoint` API."
assert (
isinstance(args[0], bool) and callable(args[1])
and isinstance(args[2], None | dist_group_type)
), "Incorrect arguments for deprecated `checkpoint` API."
for arg in args[3:]:
assert (
isinstance(arg, None | torch.Tensor)
), f"Expected tensor argument, found {type(arg)}."

distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3] # pylint: disable=unbalanced-tuple-unpacking
args = args[3:]

# Trigger the native PyTorch checkpoint if:
# 1. `function` is a `torch.nn.Module`
# AND
@@ -555,16 +582,6 @@ def checkpoint(
assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group

# Make sure at least one tensor input has `requires_grad=True`
input_requires_grad = False
for arg in args:
if isinstance(arg, torch.Tensor) and arg.requires_grad:
input_requires_grad = True
break
assert input_requires_grad, (
"`use_reentrant=True` requires at least one input tensor with `requires_grad=True`."
)

return _CheckpointFunction.apply(
function,
distribute_saved_activations,
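For illustration, here is a minimal sketch (not part of the commit) of the two call styles the shim above distinguishes. It assumes a CUDA build of Transformer Engine; the deprecated positional form is shown commented out because it additionally requires a real RNG-state tracker callable (for example Megatron-LM's get_cuda_rng_tracker), which is not set up here.

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.distributed import checkpoint

layer = te.Linear(128, 128)
inp = torch.randn(8, 128, device="cuda", requires_grad=True)

# New-style call: control arguments are passed as keywords.
out = checkpoint(
    layer,
    inp,
    distribute_saved_activations=False,
    get_rng_state_tracker=None,
    tp_group=None,
)

# Deprecated positional form kept working by this commit: (bool, callable,
# process group or None) ahead of the tensor inputs. It now emits a
# DeprecationWarning instead of breaking.
# out = checkpoint(layer, False, get_cuda_rng_tracker, None, inp)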
