-
Notifications
You must be signed in to change notification settings - Fork 352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update FP8 scale-inverse in kernels with FP8 output #1083
Update FP8 scale-inverse in kernels with FP8 output #1083
Conversation
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
/te-ci |
Signed-off-by: Tim Moon <[email protected]>
/te-ci |
Use quantization scaling factor in ONNX quantize op. Signed-off-by: Tim Moon <[email protected]>
/te-ci |
Signed-off-by: Tim Moon <[email protected]>
/te-ci |
fp8_meta["scaling_fwd"].scale_inv, | ||
tex.FP8FwdTensors.GEMM1_INPUT, | ||
inputmat_scale_inv, | ||
0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just remove this item?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly to keep the API backward-compatible. LayerNormMLP
is still storing scale-invs in the fp8_meta
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I follow - this particular call is from internal autograd function, so we should be able to change its API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fp8_gemm
is used differently in Linear
and LayerNormMLP
: Linear
constructs a new scale-inv tensor, LayerNormMLP
still uses the fp8_meta
's scale-inv and requires an offset. I avoided touching the more complicated logic in LayerNormMLP
and attention to keep this PR simple.
Signed-off-by: Tim Moon <[email protected]>
/te-ci |
Signed-off-by: Tim Moon <[email protected]>
/te-ci |
* Perform scale-inv update in cast-transpose kernels Signed-off-by: Tim Moon <[email protected]> * Perform scale-inv update in cast and activation kernels Signed-off-by: Tim Moon <[email protected]> * Perform sclae-inv update in LayerNorm and RMSNorm kernels Signed-off-by: Tim Moon <[email protected]> * Perform scale-inv update after FP8 GEMMs Signed-off-by: Tim Moon <[email protected]> * Fuse casts and scale-inv updates in linear module Signed-off-by: Tim Moon <[email protected]> * Fuse casts and scale-inv updates in layernorm-linear module Signed-off-by: Tim Moon <[email protected]> * Simplify kernel to update FP8 scale-inv Signed-off-by: Tim Moon <[email protected]> * Fix typos Signed-off-by: Tim Moon <[email protected]> * Debug amax update in layernorm kernels Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Debug test failures Signed-off-by: Tim Moon <[email protected]> * Debug ONNX export Use quantization scaling factor in ONNX quantize op. Signed-off-by: Tim Moon <[email protected]> * Review suggestion from @ptrendx Signed-off-by: Tim Moon <[email protected]> * Debug mismatched dtypes Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Tim Moon <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: beinggod <[email protected]>
This is to accommondate the behavior change from: NVIDIA/TransformerEngine#1083
Description
We currently treat the FP8 scale-inverse (the dequantization scaling factor) as part of the FP8 recipe, along with the FP8 scale (the quantization scaling factor) and the absmax history. However, this is uncomfortable because any change to the FP8 recipe will invalidate the corresponding FP8 data. We work around this by creating copies of the scale-invs whenever there might be a recipe update, e.g. in between the forward and backward passes of the linear layer:
TransformerEngine/transformer_engine/pytorch/module/linear.py
Line 318 in 6717554
This adds non-trivial CPU overhead (I estimate ~20% for the PyTorch linear layer forward pass on an L40).
A better approach is to treat the scale-inv as part of the FP8 data, something that should be output along with the FP8 bits and should never change independently of the FP8 bits. The FP8 recipe tells us how we want to cast into FP8, while the scale-inv tells us how to convert back to higher precision. Note that this generalizes nicely to block-scaling schemes, where the scale-inv tensor may be large and must be packaged with the data during communication.
This PR makes initial work toward this scheme by including scale-inv updates in most of the kernels with FP8 output: casting, activations, LayerNorm, RMSNorm. It doesn't seem that cuBLAS supports this, so I've added a small kernel that is launched after FP8 GEMMs. I have not attempted to propagate this change into Userbuffers or attention. I've also updated the PyTorch
Linear
andLayerNormLinear
modules to avoid maintaining extra copies of the scale-inv and I see a 1.12x speedup in theLinear
forward pass.I'm a little apprehensive since this is technically a breaking change. Every time we generate FP8 values we will overwrite the FP8 recipe scale-inv. I have a hard time imagining why we would ever use a stale FP8 scale-inv though if the FP8 data has already been overwritten.
Type of change
Changes
Linear
moduleLayerNormLinear
moduleChecklist: