
[PyTorch] Non-reentrant mode for activation recompute #670

Merged — 12 commits merged on Feb 24, 2024

Conversation

@denera (Collaborator) commented on Feb 15, 2024

The existing implementation of te.distributed.checkpoint() is hard-coded to mimic torch.utils.checkpoint.checkpoint(..., use_reentrant=True) and does not support the use_reentrant=False mode. The reentrant mode requires at least one input tensor to the forward pass to have requires_grad=True, which is not possible when the input to the checkpointed module is not a leaf node.
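For context, here is a minimal PyTorch-only illustration of that constraint; it uses a plain nn.Linear stand-in rather than a TE module, and the shapes are arbitrary:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)  # input does not require grad (e.g. produced under no_grad)

# Reentrant mode: the forward runs under no_grad, so with no grad-requiring
# input the checkpointed output is cut off from autograd entirely.
out = checkpoint(layer, x, use_reentrant=True)
print(out.requires_grad)  # False -> nothing can backprop through this segment

# Non-reentrant mode: gradients reach the parameters even though the input
# itself does not require grad.
out = checkpoint(layer, x, use_reentrant=False)
out.sum().backward()
print(layer.weight.grad is not None)  # True
```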

This PR implements support for use_reentrant=False using a pair of nested torch.autograd.graph.saved_tensors_hooks(pack, unpack) contexts. The logical sequence is as follows (a minimal sketch of the pattern appears after the list):

  • The outer pack_hook(x) intercepts the ctx.save_for_backward(x, ...) calls in the forward pass. Here, we discard the activation tensors we would normally save and replace their nodes in the computation graph with integer indices into a list of recomputed tensors (to be populated later).
  • The autograd engine triggers the outer unpack_hook(idx) to populate ctx.saved_tensors in the backward pass.
  • If the list of recomputed tensors is empty (idx == 0), the outer unpack_hook(idx) triggers the forward recompute. Within the recompute, an inner pack_hook(x) intercepts the ctx.save_for_backward(x, ...) calls to stash the detached activations into the recomputed tensors list.
  • Otherwise, if the activations have already been recomputed (idx >= 1), the outer unpack_hook(idx) simply returns the activation tensor at that index and clears it from the list of recomputed tensors.
  • The inner unpack_hook(idx) is never executed.
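Below is a minimal, self-contained sketch of this nested-hook pattern. The helper names (checkpoint_sketch, recomputed, etc.) are hypothetical, and the sketch omits the input detaching, RNG/FP8 state handling, and error checking that the real implementation needs:

```python
import torch


def checkpoint_sketch(module, *args):
    """Hypothetical illustration of the nested saved_tensors_hooks pattern."""
    recomputed = []   # activations stashed by the inner pack hook during recompute
    counter = [0]     # running index for tensors saved in the original forward

    def recompute():
        # Inner hooks: during the recompute, stash detached activations in the
        # same order they were originally saved. The inner unpack hook is never
        # reached because the recomputed graph's backward is never executed.
        def inner_pack(x):
            recomputed.append(x.detach())
            return x

        def inner_unpack(_):
            raise RuntimeError("inner unpack_hook should never be called")

        # The real implementation also detaches the inputs and restores the
        # RNG/FP8 state before re-running the forward; both are omitted here.
        with torch.enable_grad(), \
             torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
            module(*args)

    # Outer hooks: in the original forward, replace every tensor handed to
    # ctx.save_for_backward() with an integer index; in the backward, trigger
    # the recompute once and then serve tensors out of the recomputed list.
    def outer_pack(x):
        idx = counter[0]
        counter[0] += 1
        return idx            # the activation itself is discarded

    def outer_unpack(idx):
        if not recomputed:    # first unpack call triggers the recompute
            recompute()
        return recomputed[idx]

    with torch.autograd.graph.saved_tensors_hooks(outer_pack, outer_unpack):
        return module(*args)


# Usage: the input itself does not need requires_grad=True; gradients still
# reach the module parameters through the recomputed activations.
layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)
checkpoint_sketch(layer, x).sum().backward()
assert layer.weight.grad is not None
```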

@denera added the enhancement label (New feature or request) on Feb 15, 2024
@denera force-pushed the databricks/non-reentrant-checkpoint branch from 45ee94f to f4e6fce on February 16, 2024 15:26
…eckpointing in non-reentrant mode

Signed-off-by: Alp Dener <[email protected]>
@denera (Collaborator, Author) commented on Feb 17, 2024

/te-ci pytorch

@ksivaman (Member) commented

For future reference, the code is copied from native PyTorch non-reentrant mode here. @denera Is there a reason we've changed the names for some of the implementations? E.g., _checkpoint_hook → _CheckpointHook. If it's purely for naming, it might be better for us to make an exception here to be consistent with the reference implementation.

@denera (Collaborator, Author) commented on Feb 19, 2024

@ksivaman There's no particular reason for the naming beyond consistency with TE conventions (PascalCase for classes and snake_case for functions). I think making an exception here to remain consistent with the original PyTorch source is a good idea. I'll make the change.

@denera self-assigned this on Feb 21, 2024
@denera force-pushed the databricks/non-reentrant-checkpoint branch from 42eb675 to b1c4bb2 on February 21, 2024 01:39
@denera requested review from ksivaman and ptrendx on February 21, 2024 01:39
@denera (Collaborator, Author) commented on Feb 21, 2024

/te-ci pytorch

Comment on lines -140 to +144:

- @contextmanager
- def activation_recompute_forward(
-     activation_recompute: bool = False,
-     recompute_phase: bool = False,
- ) -> None:
+ class activation_recompute_forward(AbstractContextManager, ContextDecorator):
Review comment (Member):

Curious, why?

Reply (Collaborator, Author):

The native PyTorch non-reentrant checkpoint has an option to provide a context function that returns (forward_ctx, recompute_ctx) objects, which are then combined with the native _checkpoint_hook and _recomputation_hook within PyTorch. Re-implementing our activation_recompute_forward() as a class instead of a function made this workflow easier and cleaner for me.

Of course, this PR does not use the native PyTorch checkpoint, but I've confirmed in my limited testing that it does work when supplied with the right context function. The caveat is that you have to correctly set the RNG states for all the devices once at the beginning, and then make sure the modules never tamper with the RNG state themselves. The preserve_rng_state=True option in PyTorch's native checkpointing makes sure that the initially correct RNG states are preserved through the checkpoint and recompute.

I kept this out of the PR because:

  1. It does not work with TE modules/models that use CudaRNGStateManager. Supporting that requires TE to implement its own non-reentrant checkpoint, which I've done in this PR.
  2. TE modules that implement custom forward/backward ops do not benefit from the early stopping feature in the PyTorch checkpoint because they call ctx.save_for_backward() only once, with all the saved tensors passed in bulk. The internal bookkeeping for early stopping relies on this being called separately for each tensor that needs to be saved.

Since I already did this conversion, I left it in this PR in preparation for the future possibility that we might figure out a way to eliminate CudaRNGStateManager and perhaps restructure how we use ctx.save_for_backward() in custom forward/backward ops. That would let us completely get rid of TE's own checkpoint implementation and just use PyTorch's native API.
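For illustration, here is a hedged sketch of the context-function workflow described above, using PyTorch's native non-reentrant checkpoint. The context_fn argument is available in recent PyTorch releases; RecomputeContext and context_fn are hypothetical stand-ins for TE's activation_recompute_forward, not the actual TE code:

```python
from contextlib import AbstractContextManager, ContextDecorator

import torch
from torch.utils.checkpoint import checkpoint


class RecomputeContext(AbstractContextManager, ContextDecorator):
    """Hypothetical stand-in for a class-based activation-recompute context."""

    def __init__(self, recompute_phase: bool = False):
        self.recompute_phase = recompute_phase

    def __enter__(self):
        # e.g. flip global flags that tell the modules whether this is the
        # original forward pass or the recompute pass
        return self

    def __exit__(self, *exc):
        # restore whatever state __enter__ changed
        return None


def context_fn():
    # checkpoint() enters the first context around the original forward and
    # the second context around the recomputation in the backward pass.
    return RecomputeContext(recompute_phase=False), RecomputeContext(recompute_phase=True)


layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)
out = checkpoint(layer, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```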

@denera (Collaborator, Author) commented on Feb 22, 2024

/te-ci pytorch

@denera (Collaborator, Author) commented on Feb 23, 2024

@ksivaman @ptrendx The new non-reentrant checkpoint has a CI failure on test_numerics.py::test_gpt_full_activation_recompute with the float32 dtype, no fp8, and batch size 1 only. The exact same test passes with the float16 and bfloat16 types, and float32 passes the batch size 2 test too. Very odd, and unfortunately I haven't been able to reproduce it with manual testing on the same nodes either.

This particular test was actually one of several batch size 1 failures for activation recompute in previous CI runs. @ptrendx guessed that they may be due to nondeterminism from the bias-GELU fusion, and turning that off for non-reentrant checkpointing resolved all the failures except this specific one. So it looks like I'm still missing something here, but I haven't been able to narrow it down.

Any ideas what might be happening here?

ksivaman and others added 2 commits February 23, 2024 00:16
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
@denera (Collaborator, Author) commented on Feb 23, 2024

/te-ci pytorch

@ksivaman (Member) left a review:

LGTM

@ksivaman merged commit 82bc797 into NVIDIA:main on Feb 24, 2024
17 of 20 checks passed
@ksivaman added the 1.5.0 label on Feb 24, 2024
@ptrendx (Member) commented on Feb 24, 2024

For the record, I don't agree with merging this at this point. The change made to pass the test is not applicable to most workloads using recomputation, and the reason for the failure is not really understood. I was talking to @denera offline and we have some more ideas for debugging the underlying problem; please open a follow-up PR if you find the problem to be caused by the checkpointing logic.

@denera (Collaborator, Author) commented on Feb 24, 2024

@ptrendx I didn't reach out to @ksivaman fast enough after our talk to hold off on merging. Sorry!

Fortunately, the previous reentrant checkpoint is still the default option for TE checkpointing, so existing applications should not see any changes on their end.

In the meantime, I suspect that the pack/unpack hooks in the non-reentrant checkpoint cause the recompute to recover the wrong tensor whenever the compile cache limit triggers a recompile of the hooks. I see the recomputed tensor counter getting reset when it shouldn't be. I think I have a solution for that and will file a PR early next week.
