
[PyTorch] cuda graph support #575

Merged 8 commits into NVIDIA:main on Apr 12, 2024

Conversation

@ksivaman (Member) commented on Dec 22, 2023

This PR adds the following features (high-level):

  • A make_graphed_callables API similar to the PyTorch API, with some additional arguments for FP8 usage. Support for FP8 weight caching via the existing is_first_microbatch argument is also retained (see the usage sketch after this list).
  • Restructured amax reduction logic with a simpler design, handling the various parallelisms with minimal bookkeeping compared to the previous approach.
  • Forward and backward amaxes are reduced within the scope of the current iteration, fixing numerous checkpointing bugs and removing the need to save global buffers.
  • Support for nested/multiple FP8 autocast contexts with different recipes and distributed groups.
  • Amax reductions are module-independent and happen at the autocast level. This also resolves numerous bugs and enables support for MoE/LoRA-like models.
  • Redesign of Float8Tensor transposes so that they are persistent across graph capture. This also fixes use cases with vanilla optimizers (non-FP8 distopt).
  • The scaling inverses for weight tensors are no longer frozen when caching weights across microbatches.
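
A minimal usage sketch of the new API, assuming keyword arguments fp8_enabled and fp8_recipe as suggested by the description above (the exact signature may differ):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Sketch only: the fp8_enabled / fp8_recipe argument names are assumptions.
model = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

graphed_model = te.make_graphed_callables(
    model,                                # callable (or tuple of callables) to capture
    (sample_input,),                      # sample args used for warmup and graph capture
    fp8_enabled=True,                     # assumed: enable FP8 during capture/replay
    fp8_recipe=recipe.DelayedScaling(),   # assumed: FP8 scaling recipe
)

# Replaying the captured graph; amax reduction happens at the autocast level.
output = graphed_model(sample_input)
output.sum().backward()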

@ksivaman ksivaman marked this pull request as draft December 22, 2023 14:08
@timmoon10 timmoon10 self-requested a review March 11, 2024 22:29
Comment on lines 543 to 549
def _reset_caches(self) -> None:
"""Reset cached values

Should be called after any in-place operation.

"""
self._transpose = None
@timmoon10 (Collaborator) commented on Mar 11, 2024

Removing the automatic cache clearing makes using the transpose cache a much more manual and dangerous process. Consider something like:

matmul(x, w.transpose(0, 1))    # may populate/use the transpose cache
w -= learning_rate * w.grad     # in-place update invalidates the cached transpose
matmul(x, w.transpose(0, 1))    # without automatic clearing, this can see a stale transpose

Previously we could just set update_cache="lazy". Now there needs to be manual logic to figure out the microbatch step, or else it will provide stale values.

@ksivaman (Member, Author) replied:

In this example, caching is not used, so a fresh transpose will be computed each time.

@ksivaman (Member, Author) replied:

If caching is used, it is reasonable to expect the user to know when to reuse a cached value and when to force a recompute. This is consistent with our design of the is_first_microbatch argument to the forward pass of the module APIs (see the sketch below).
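
For reference, a hedged sketch of the existing is_first_microbatch weight-caching pattern referred to here; the training loop around it is purely illustrative:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_autocast

layer = te.Linear(1024, 1024).cuda()
num_steps, num_microbatches = 2, 4

for step in range(num_steps):
    for mb in range(num_microbatches):
        inp = torch.randn(32, 1024, device="cuda")
        with fp8_autocast(enabled=True):
            # The caller signals when cached FP8 weights may be reused (mb > 0)
            # versus recomputed (first microbatch of the step).
            out = layer(inp, is_first_microbatch=(mb == 0))
        out.sum().backward()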

@ksivaman (Member, Author) replied:

Note: we use two arguments, cache and update_cache, to support this logic; a hypothetical sketch follows.
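
A hypothetical illustration of how a two-argument transpose cache could behave; the argument names cache and update_cache come from the comment above, while the class and method below are invented for illustration and are not the actual Float8Tensor implementation:

import torch

class CachedTransposeTensor:
    """Illustrative only: not the real Float8Tensor implementation."""

    def __init__(self, data: torch.Tensor):
        self._data = data
        self._transpose = None  # cached transpose; persistent unless explicitly updated

    def transpose_cached(self, cache: bool = False, update_cache: bool = False):
        # update_cache forces a fresh transpose to be computed and stored;
        # cache reuses the stored transpose when one exists.
        if update_cache or (cache and self._transpose is None):
            self._transpose = self._data.t().contiguous()
        if cache:
            return self._transpose
        return self._data.t()

w = CachedTransposeTensor(torch.randn(4, 4))
wt = w.transpose_cached(cache=True, update_cache=True)  # e.g. first microbatch: compute and cache
wt2 = w.transpose_cached(cache=True)                    # later microbatches: reuse the cached value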

@timmoon10 (Collaborator) commented on Mar 12, 2024

I think we're overfitting to the Linear weight use-case. For example, in #707 I want to pass Float8Tensors between ops as inputs or dgrads:

class DbiasCastTranspose:
    def backward(self, dy):
        db = dy.sum(dim=0)
        dx = cast_transpose(dy)  # Creates Float8Tensor with transpose cache
        return dx, db

class FP8Linear:  # Part of FP8 attention
    def backward(self, dy):
        if not isinstance(dy, Float8Tensor):
            dy = Float8Tensor.to_float8(dy)
        dx = Float8Tensor(...)  # No transpose cache
        fp8_gemm(w.transpose()._data, dy.transpose()._data, out=dx._data)
        dw = fp8_gemm(x, dy)
        return dx, dw

FP8Linear has no idea where its input came from. Maybe it's from DbiasCastTranspose (Float8Tensor with cached transpose), FP8Linear (Float8Tensor without cached transpose), or a non-FP8 op. Our current approach with lazy transpose caching gives us a lot of flexibility and I think we should abandon it only when really necessary.

I suppose this is not precisely relevant since it doesn't involve in-place operations, but it is a more general statement about the design of Float8Tensor.
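
To make the flexibility point concrete, here is a hypothetical compute-or-reuse helper (invented for illustration, not an existing Transformer Engine function) that a consuming op could use regardless of where its input came from:

import torch

def get_transposed_data(t):
    """Illustrative only: return transposed raw data, reusing a cache when present.

    Works for a Float8Tensor-like object (with _data and an optional _transpose
    cache) as well as a plain torch.Tensor.
    """
    cached = getattr(t, "_transpose", None)
    if cached is not None:
        return cached                      # producer already cached the transpose
    data = getattr(t, "_data", t)          # raw data for FP8 tensors, or the tensor itself
    return data.t().contiguous()           # otherwise compute the transpose on the fly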

@ksivaman ksivaman marked this pull request as ready for review March 23, 2024 05:42
@ksivaman ksivaman changed the title [WIP ] PyTorch FP8 cuda graphs [PyTorch] cuda graph support Mar 23, 2024
@ksivaman (Member, Author): /te-ci

@ksivaman (Member, Author): /te-ci pytorch

@ksivaman (Member, Author): /te-ci pytorch

@timmoon10 (Collaborator) commented on Mar 27, 2024

#735 has some improvements to the Float8Tensor transpose function, which should reduce the divergence with #707. If there are no issues, we should merge that branch into this PR.

@ksivaman ksivaman marked this pull request as draft March 27, 2024 23:59
@ptrendx ptrendx added the 1.6.0 label Apr 2, 2024
transformer_engine/pytorch/fp8.py (review thread: outdated, resolved)
transformer_engine/pytorch/fp8.py (review thread: resolved)
transformer_engine/pytorch/module/base.py (review thread: resolved)
transformer_engine/pytorch/graph.py (review thread: outdated, resolved)
tests/pytorch/test_cuda_graphs.py (review thread: resolved)
@ksivaman (Member, Author) commented on Apr 9, 2024: /te-ci pytorch

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
@ksivaman (Member, Author): /te-ci pytorch

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
@ksivaman (Member, Author): /te-ci pytorch

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
@ksivaman (Member, Author): /te-ci pytorch

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
@ksivaman (Member, Author): /te-ci pytorch

@ksivaman ksivaman merged commit 73f8d90 into NVIDIA:main Apr 12, 2024
18 of 20 checks passed
@ksivaman ksivaman mentioned this pull request Apr 30, 2024
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 9, 2024
* FP8 cuda graphs

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>

* Fix numerics

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* exclude torch compile from numerics tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* More numerics fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix CI

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* rm fusion from unfused path

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 15, 2024
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 16, 2024
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 23, 2024