Commit f4e6fce: docstring cleanup

Signed-off-by: Alp Dener <[email protected]>
denera committed Feb 16, 2024
1 parent: 879b50b
Showing 1 changed file with 18 additions and 8 deletions.
transformer_engine/pytorch/distributed.py (18 additions, 8 deletions)
@@ -451,32 +451,42 @@ def checkpoint(
     function: Callable
                pytorch module used to run the forward and backward passes using
                the specified :attr:`args` and :attr:`kwargs`.
-    distribute_saved_activations: bool
+    distribute_saved_activations: bool, default = False
                if set to `True` and `use_reentrant=True`, first tensor argument is distributed
                across the specified tensor parallel group (`tp_group`) before saving it for the
                backward pass. This has no effect when `use_reentrant=False`.
-    get_rng_state_tracker: `Callable`
+    get_rng_state_tracker: `Callable`, default = None
                python callable which returns an instance of :func:`CudaRNGStatesTracker`.
-    tp_group : ProcessGroup
-               tensor parallel process group. Required only when `distribute_saved_activations=True`
-               and `use_reentrant=True`.
+    tp_group : ProcessGroup, default = None
+               tensor parallel process group. Used only when `distribute_saved_activations=True`
+               and `use_reentrant=True`. If `None`, it falls back to the default group.
     use_reentrant : bool, default = True
-               perform checkpointing in reentrant mode
+               perform checkpointing in reentrant mode.
     args : tuple
                tuple of torch tensors for inputs to :attr:`function`.
     kwargs : dict
                dictionary of string keys for keyword arguments to :attr:`function`.
     """

     # Pop out te.distributed.checkpoint() arguments
     distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
     tp_group = kwargs.pop("tp_group", None)
     get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)

-    # Pop out and discard te.utils.checkpoint.checkpoint() arguments for wrapper compatibility
+    # Pop out te.utils.checkpoint.checkpoint() arguments for wrapper compatibility
     context_fn = kwargs.pop("context_fn", noop_context_fn)
     determinism_check = kwargs.pop("determinism_check", "default")
     debug = kwargs.pop("debug", False)
+    if 'transformer_engine' not in function.__class__.__module__:
+        # This is not a TE module so just pass everything to the PyTorch native checkpoint
+        return torch.utils.checkpoint.checkpoint(
+            function,
+            *args,
+            use_reentrant=use_reentrant,
+            context_fn=context_fn,
+            determinism_check=determinism_check,
+            debug=debug,
+            **kwargs
+        )
     del context_fn, determinism_check, debug

     use_reentrant = kwargs.pop("use_reentrant", True)
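
For reference, a minimal usage sketch of the arguments documented in the docstring above. The module, tensor shapes, and device are illustrative assumptions rather than anything taken from this commit; the import path follows the file being changed, transformer_engine/pytorch/distributed.py.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.distributed import checkpoint

    # Illustrative TE module and input tensor (assumed sizes).
    layer = te.Linear(1024, 1024).cuda()
    inp = torch.randn(8, 1024, device="cuda", requires_grad=True)

    # Keyword arguments documented in the diff above, shown with their defaults.
    out = checkpoint(
        layer,
        inp,
        distribute_saved_activations=False,  # when True (with use_reentrant=True), shards the first tensor arg across tp_group
        get_rng_state_tracker=None,          # callable returning a CudaRNGStatesTracker
        tp_group=None,                       # None falls back to the default process group
        use_reentrant=True,                  # reentrant checkpointing mode (the default)
    )
    out.sum().backward()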
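
The branch added in this change also lets the wrapper accept modules that are not part of Transformer Engine: when `function.__class__.__module__` does not contain 'transformer_engine', the call is forwarded to torch.utils.checkpoint.checkpoint together with its native arguments (context_fn, determinism_check, debug). A small sketch of that path, assuming a plain torch.nn module (illustrative):

    import torch
    from transformer_engine.pytorch.distributed import checkpoint

    # A plain PyTorch module: its class lives under torch.nn.modules.*, so the
    # 'transformer_engine' substring check fails and the call falls through to
    # torch.utils.checkpoint.checkpoint.
    mlp = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
    x = torch.randn(4, 64, device="cuda", requires_grad=True)

    out = checkpoint(mlp, x)  # torch-native kwargs such as context_fn would be forwarded too
    out.sum().backward()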
