[PyTorch] cuda graph support #575
Conversation
```python
def _reset_caches(self) -> None:
    """Reset cached values

    Should be called after any in-place operation.

    """
    self._transpose = None
```
Removing the automatic cache clearing makes using the transpose cache a much more manual and dangerous process. Consider something like:

```python
matmul(x, w.transpose(0, 1))
w -= learning_rate * w.grad
matmul(x, w.transpose(0, 1))
```

Previously we could just set update_cache="lazy". Now there needs to be manual logic to figure out the microbatch step, or else it will provide stale values.
In this example, caching is not used, so a fresh transpose will be computed each time.
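For concreteness, here is a toy sketch of the staleness risk being discussed when caching is enabled. Plain tensors and a hand-rolled cache stand in for Float8Tensor; none of the names below are the real API:

```python
import torch

w = torch.randn(4, 4)
x = torch.randn(4, 4)
grad = torch.randn(4, 4)
learning_rate = 0.1

cached_wT = None  # toy stand-in for a transpose cache


def w_transpose(use_cache: bool):
    """Return w's transpose, optionally reusing the toy cache."""
    global cached_wT
    if not use_cache:
        return w.transpose(0, 1).contiguous()
    if cached_wT is None:
        cached_wT = w.transpose(0, 1).contiguous()
    return cached_wT


y1 = x @ w_transpose(use_cache=True)
w -= learning_rate * grad             # in-place update: the cached transpose is now stale
y2 = x @ w_transpose(use_cache=True)  # silently reuses the stale transpose unless
                                      # the caller remembers to reset cached_wT
```

With automatic clearing, the in-place update would invalidate the cache and the second call would recompute; without it, the invalidation becomes the caller's responsibility.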
If caching is used, it is reasonable to expect the user to know when to reuse a cached value and when to force a recompute. This is consistent with our design of the is_first_microbatch argument to the forward pass of the module APIs.
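For reference, the existing module-level pattern looks roughly like this; te.Linear, fp8_autocast, and the is_first_microbatch argument are existing Transformer Engine APIs, while the shapes and loop are only illustrative:

```python
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
microbatches = [torch.randn(32, 1024, device="cuda") for _ in range(4)]

for step, mb in enumerate(microbatches):
    with te.fp8_autocast(enabled=True):
        # Cast (and transpose) the FP8 weights only on the first microbatch;
        # later microbatches reuse the cached FP8 copies.
        out = layer(mb, is_first_microbatch=(step == 0))
```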
Note: we use two arguments, cache and update_cache, to support this logic.
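A minimal sketch of how the two arguments could compose, using a toy class rather than the actual Float8Tensor implementation (the exact merged signature may differ):

```python
import torch


class _ToyFP8Tensor:
    """Toy stand-in to illustrate the cache / update_cache interaction."""

    def __init__(self, data: torch.Tensor):
        self._data = data
        self._transpose = None

    def transpose(self, cache: bool = False, update_cache: bool = False):
        if not cache:
            # Caching disabled: always return a freshly computed transpose.
            return self._data.transpose(0, 1).contiguous()
        if update_cache or self._transpose is None:
            # Refresh on request, or the first time the cache is needed.
            self._transpose = self._data.transpose(0, 1).contiguous()
        return self._transpose
```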
I think we're overfitting to the Linear weight use-case. For example, in #707 I want to pass Float8Tensors between ops as inputs or dgrads:

```python
class DbiasCastTranspose:
    def backward(self, dy):
        db = dy.sum(dim=0)
        dx = cast_transpose(dy)  # Creates Float8Tensor with transpose cache
        return dx, db

class FP8Linear:  # Part of FP8 attention
    def backward(self, dy):
        if not isinstance(dy, Float8Tensor):
            dy = Float8Tensor.to_float8(dy)
        dx = Float8Tensor(...)  # No transpose cache
        fp8_gemm(w.transpose()._data, dy.transpose()._data, out=dx._data)
        dw = fp8_gemm(x, dy)
        return dx, dw
```

FP8Linear has no idea where its input came from. Maybe it's from DbiasCastTranspose (a Float8Tensor with a cached transpose), FP8Linear (a Float8Tensor without a cached transpose), or a non-FP8 op. Our current approach with lazy transpose caching gives us a lot of flexibility, and I think we should abandon it only when really necessary.
I suppose this is not precisely relevant since it doesn't involve in-place operations, but it's a more general comment about the design of Float8Tensor.
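A hedged sketch of the consumer-side flexibility being argued for: with lazy caching, an op like FP8Linear can ask for a transpose without knowing whether the producer already cached one. The helper name and attributes below are illustrative only:

```python
def _lazy_transpose(t):
    """Return t's transpose, filling t._transpose only if it is still empty.

    Works the same whether t arrived with a pre-filled cache (e.g. from a
    cast-transpose producer) or with no cache at all.
    """
    if getattr(t, "_transpose", None) is None:
        # Stand-in for the FP8 transpose kernel.
        t._transpose = t._data.transpose(0, 1).contiguous()
    return t._transpose
```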
transformer_engine/common/include/transformer_engine/cast_transpose_noop.h
/te-ci
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
* FP8 cuda graphs
* Fix numerics
* exclude torch compile from numerics tests
* More numerics fixes
* Fix tests
* Fix CI
* rm fusion from unfused path

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
This PR adds the following features (high-level):

- A make_graphed_callables API, similar to the PyTorch API, with some additional arguments for FP8 usage (see the usage sketch below). Support for fp8 weight caching via the existing is_first_microbatch argument is also retained.
- Float8Tensor, which makes the transposes persistent for graph capture. Also fixes use cases for the vanilla optimizers (non fp8-distopt).
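A rough usage sketch of the new API follows. te.TransformerLayer and DelayedScaling are existing Transformer Engine APIs; the FP8-related keyword arguments to make_graphed_callables follow the description above but should be checked against the merged signature:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.TransformerLayer(1024, 4096, 16).cuda()
sample_input = torch.randn(8, 2, 1024, device="cuda", requires_grad=True)

# Capture the layer into a CUDA graph with FP8 enabled inside the capture.
graphed_layer = te.make_graphed_callables(
    layer,
    (sample_input,),
    fp8_enabled=True,             # assumption: toggle for FP8 execution during capture/replay
    fp8_recipe=DelayedScaling(),  # assumption: FP8 recipe argument
)

out = graphed_layer(sample_input)
out.sum().backward()
```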