[PyTorch] Non-reentrant mode for activation recompute #670
Conversation
/te-ci pytorch
For future reference, the code is copied from native PyTorch non-reentrant mode here. @denera Is there a reason we've changed the names for some of the implementations?
@ksivaman There's no particular reason for the naming beyond consistency with TE conventions.
/te-ci pytorch
```diff
-@contextmanager
-def activation_recompute_forward(
-    activation_recompute: bool = False,
-    recompute_phase: bool = False,
-) -> None:
+class activation_recompute_forward(AbstractContextManager, ContextDecorator):
```
Curious, why?
The native PyTorch non-reentrant checkpoint has an option to provide a context function that returns (forward_ctx, recompute_ctx) objects, which are then combined with the native _checkpoint_hook and _recomputation_hook within PyTorch. Re-implementing our activation_recompute_forward() as a class instead of a function made this workflow easier/cleaner for me.
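For illustration, here is a hypothetical sketch (not the actual TE code) of what a class-based context manager for this purpose can look like. Because it is a class, the same type can simply be instantiated twice to produce the (forward_ctx, recompute_ctx) pair that the native non-reentrant checkpoint expects from its `context_fn` argument; the names `activation_recompute_ctx`, `context_fn`, and the global flag are made up for this example.

```python
from contextlib import AbstractContextManager, ContextDecorator

_ACTIVATION_RECOMPUTE_PHASE = False  # illustrative global flag modules could query


class activation_recompute_ctx(AbstractContextManager, ContextDecorator):
    """Flags whether we are inside the original forward or its recompute."""

    def __init__(self, recompute_phase: bool = False):
        self.recompute_phase = recompute_phase
        self._previous = None

    def __enter__(self):
        global _ACTIVATION_RECOMPUTE_PHASE
        self._previous = _ACTIVATION_RECOMPUTE_PHASE
        _ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the previous state; returning None propagates any exception.
        global _ACTIVATION_RECOMPUTE_PHASE
        _ACTIVATION_RECOMPUTE_PHASE = self._previous
        return None


def context_fn():
    # The shape PyTorch's non-reentrant checkpoint expects from context_fn:
    # a callable returning (forward_ctx, recompute_ctx).
    return activation_recompute_ctx(False), activation_recompute_ctx(True)
```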
Of course this PR does not use the native PyTorch checkpoint, but I've confirmed in my limited testing that it does work when supplied with the right context function. The caveat is that you have to set the RNG states correctly once at the beginning for all the devices, and then make sure the modules never tamper with the RNG state themselves. The preserve_rng_state=True option in PyTorch's native checkpointing makes sure that the initially correct RNG states are preserved through the checkpoint and recompute.
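As a hedged usage sketch only (this is not what the PR ships), a recent PyTorch version that accepts `context_fn` on `torch.utils.checkpoint.checkpoint` can be driven like this; `module` and `x` are placeholders and the trivial `context_fn` here could be replaced by one returning the class-based contexts sketched above.

```python
import contextlib
import torch
from torch.utils.checkpoint import checkpoint


def context_fn():
    # (forward_ctx, recompute_ctx); trivial placeholders for this example.
    return contextlib.nullcontext(), contextlib.nullcontext()


module = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())
x = torch.randn(4, 16)  # note: the input does not need requires_grad=True in this mode

out = checkpoint(
    module,
    x,
    use_reentrant=False,       # non-reentrant mode based on saved-tensor hooks
    preserve_rng_state=True,   # restore the initial RNG state for the recompute
    context_fn=context_fn,     # forward/recompute contexts
)
out.sum().backward()
```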
I kept this out of the PR because:
- It does not work with TE modules/models that use CudaRNGStateManager. Supporting that requires TE to implement its own non-reentrant checkpoint, which I've done in this PR.
- TE modules that implement custom forward/backward ops do not benefit from the early-stopping feature in the PyTorch checkpoint, because they call ctx.save_for_backward() only once, with all the saved tensors passed in bulk (see the sketch below). The internal bookkeeping for early stopping relies on this being called separately for each tensor that needs to be saved.
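To illustrate the "bulk" save pattern the second point refers to, here is a hypothetical custom autograd op (not a TE op; names and math are made up for this example): everything needed by backward is handed to ctx.save_for_backward() in a single call rather than incrementally as the forward progresses.

```python
import torch


class ToyFusedOp(torch.autograd.Function):
    """Hypothetical stand-in for a TE op with a custom forward/backward."""

    @staticmethod
    def forward(ctx, inp, weight):
        out = inp @ weight
        # Single bulk call: all tensors needed by backward are saved at once,
        # so hook-based bookkeeping sees them together rather than one by one.
        ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        grad_inp = grad_out @ weight.t()
        grad_weight = inp.t() @ grad_out
        return grad_inp, grad_weight
```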
Since I already did this conversion, I left it in this PR in preparation for the future possibility that we might figure out a way to eliminate CudaRNGStateManager and perhaps restructure how we use ctx.save_for_backward() in custom forward/backward ops. That would let us completely get rid of TE's own checkpoint implementation and just use PyTorch's native API.
/te-ci pytorch
@ksivaman @ptrendx The new non-reentrant checkpoint has a CI failure on one particular test. This test was actually one of several batch-size-1 failures for activation recompute in previous CI runs. @ptrendx guessed that they may be due to nondeterminism from the bias-GELU fusion, and turning that off for non-reentrant checkpointing resolved all the failures except this specific one. So it looks like I'm still missing something here, but I haven't been able to narrow it down. Any ideas what might be happening here?
/te-ci pytorch
LGTM
For the record - I don't agree with merging this at this point. The change made to pass the test is not applicable to most workloads using recomputation, and the reason for the failure is not really understood. I was talking to @denera offline and we have some more ideas for debugging the underlying problem - please open a follow-up PR if you find the problem to be caused by the checkpointing logic.
@ptrendx I didn't reach out to @ksivaman fast enough after our talk to hold off on merging. Sorry! Fortunately, the previous reentrant checkpoint is still the default option for TE checkpointing, so existing applications should not see any changes on their end. In the meantime, I suspect that the pack/unpack hooks in the non-reentrant checkpoint cause the recompute to recover the wrong tensor whenever the compile cache limit triggers a recompile of the hooks; I see the recomputed-tensor counter getting reset when it shouldn't be. I think I have a solution for that and will file a PR early next week.
The existing implementation of te.distributed.checkpoint() is hard-coded to mimic torch.utils.checkpoint.checkpoint(..., use_reentrant=True) and does not support the use_reentrant=False mode. This requires at least one input tensor to the forward pass to have requires_grad=True enabled, which is not possible when the input to the checkpointed module is not a leaf node.

This PR implements support for use_reentrant=False using a pair of nested torch.autograd.graph.saved_tensors_hooks(pack, unpack) contexts. The logical sequence is like this:

- pack_hook(x) intercepts the ctx.save_for_backward(x, ...) calls in the forward pass. Here, we discard the activation tensors we would normally save and replace their nodes in the computation graph with integer indexes into a list of recomputed tensors (to be recomputed later).
- In the backward pass, unpack_hook(idx) is called to populate ctx.saved_tensors.
- On the first call (idx == 0), the outer unpack_hook(idx) triggers the forward recompute. Within the recompute, an inner pack_hook(x) intercepts the ctx.save_for_backward(x, ...) calls to stash the detached activations into the recomputed tensors list.
- On subsequent calls (idx >= 1), the outer unpack_hook(idx) simply returns the activation tensor at the given index and clears it from the list of recomputed tensors.
- If the backward pass for a checkpointed module never runs, unpack_hook(idx) is never executed.
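As a rough illustration of this control flow only, here is a heavily simplified, hypothetical sketch. It is not the TE implementation: RNG/FP8 state handling, distributed concerns, and the clearing of recomputed tensors are omitted, and the helper name naive_non_reentrant_checkpoint is made up for this example.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def naive_non_reentrant_checkpoint(function, *args):
    """Toy non-reentrant checkpoint built on nested saved-tensor hooks."""
    saved_index = 0        # next index handed out by the outer pack hook
    recomputed = []        # activations stashed during the recompute pass
    recompute_done = False

    def outer_pack(tensor):
        # Forward pass: discard the activation, store only an integer index.
        nonlocal saved_index
        idx = saved_index
        saved_index += 1
        return idx

    def outer_unpack(idx):
        # Backward pass: the first lookup triggers a full recompute of the forward,
        # with an inner pack hook that stashes the detached activations in order.
        nonlocal recompute_done
        if not recompute_done:
            def inner_pack(tensor):
                recomputed.append(tensor.detach())
                return None

            def inner_unpack(_):
                raise RuntimeError("backward should never run inside the recompute")

            with torch.enable_grad(), saved_tensors_hooks(inner_pack, inner_unpack):
                function(*args)  # outputs discarded; we only want the stashed tensors
            recompute_done = True
        # Later lookups (idx >= 1) just return the already-recomputed activation.
        return recomputed[idx]

    with saved_tensors_hooks(outer_pack, outer_unpack):
        return function(*args)


# Usage sketch: the input does not need requires_grad=True.
layer = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU(), torch.nn.Linear(8, 8))
x = torch.randn(2, 8)
out = naive_non_reentrant_checkpoint(layer, x)
out.sum().backward()
```

The index/list alignment relies on the forward and the recompute saving tensors in the same order, which is why nondeterministic execution (or fusions that change what gets saved) can break this kind of bookkeeping.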