Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update FP8 scale-inverse in kernels with FP8 output #1083

Merged
merged 20 commits into from
Aug 21, 2024

Conversation

timmoon10
Copy link
Collaborator

@timmoon10 timmoon10 commented Aug 7, 2024

Description

We currently treat the FP8 scale-inverse (the dequantization scaling factor) as part of the FP8 recipe, along with the FP8 scale (the quantization scaling factor) and the absmax history. However, this is uncomfortable because any change to the FP8 recipe will invalidate the corresponding FP8 data. We work around this by creating copies of the scale-invs whenever there might be a recipe update, e.g. in between the forward and backward passes of the linear layer:

fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,

This adds non-trivial CPU overhead (I estimate ~20% for the PyTorch linear layer forward pass on an L40).

A better approach is to treat the scale-inv as part of the FP8 data, something that should be output along with the FP8 bits and should never change independently of the FP8 bits. The FP8 recipe tells us how we want to cast into FP8, while the scale-inv tells us how to convert back to higher precision. Note that this generalizes nicely to block-scaling schemes, where the scale-inv tensor may be large and must be packaged with the data during communication.

This PR makes initial work toward this scheme by including scale-inv updates in most of the kernels with FP8 output: casting, activations, LayerNorm, RMSNorm. It doesn't seem that cuBLAS supports this, so I've added a small kernel that is launched after FP8 GEMMs. I have not attempted to propagate this change into Userbuffers or attention. I've also updated the PyTorch Linear and LayerNormLinear modules to avoid maintaining extra copies of the scale-inv and I see a 1.12x speedup in the Linear forward pass.

I'm a little apprehensive since this is technically a breaking change. Every time we generate FP8 values we will overwrite the FP8 recipe scale-inv. I have a hard time imagining why we would ever use a stale FP8 scale-inv though if the FP8 data has already been overwritten.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refractor

Changes

  • Update FP8 scale-inverse in cast-transpose kernels
  • Update FP8 scale-inverse in cast and activation kernels
  • Update FP8 scale-inverse in LayerNorm and RMSNorm kernels
  • Update FP8 scale-inverse after FP8 GEMMs
  • Avoid unnecessary FP8 scale-inverse copies in PyTorch Linear module
  • Avoid unnecessary FP8 scale-inverse copies in PyTorch LayerNormLinear module

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10
Copy link
Collaborator Author

/te-ci

@timmoon10 timmoon10 requested a review from denera August 7, 2024 03:03
Signed-off-by: Tim Moon <[email protected]>
@timmoon10
Copy link
Collaborator Author

/te-ci

timmoon10 and others added 2 commits August 8, 2024 17:14
Use quantization scaling factor in ONNX quantize op.

Signed-off-by: Tim Moon <[email protected]>
@timmoon10
Copy link
Collaborator Author

/te-ci

@timmoon10
Copy link
Collaborator Author

/te-ci

fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
inputmat_scale_inv,
0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just remove this item?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly to keep the API backward-compatible. LayerNormMLP is still storing scale-invs in the fp8_meta.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow - this particular call is from internal autograd function, so we should be able to change its API.

Copy link
Collaborator Author

@timmoon10 timmoon10 Aug 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fp8_gemm is used differently in Linear and LayerNormMLP: Linear constructs a new scale-inv tensor, LayerNormMLP still uses the fp8_meta's scale-inv and requires an offset. I avoided touching the more complicated logic in LayerNormMLP and attention to keep this PR simple.

@timmoon10
Copy link
Collaborator Author

/te-ci

@timmoon10
Copy link
Collaborator Author

/te-ci

@timmoon10 timmoon10 merged commit 8e3561b into NVIDIA:main Aug 21, 2024
31 checks passed
BeingGod pushed a commit to BeingGod/TransformerEngine that referenced this pull request Aug 30, 2024
* Perform scale-inv update in cast-transpose kernels

Signed-off-by: Tim Moon <[email protected]>

* Perform scale-inv update in cast and activation kernels

Signed-off-by: Tim Moon <[email protected]>

* Perform sclae-inv update in LayerNorm and RMSNorm kernels

Signed-off-by: Tim Moon <[email protected]>

* Perform scale-inv update after FP8 GEMMs

Signed-off-by: Tim Moon <[email protected]>

* Fuse casts and scale-inv updates in linear module

Signed-off-by: Tim Moon <[email protected]>

* Fuse casts and scale-inv updates in layernorm-linear module

Signed-off-by: Tim Moon <[email protected]>

* Simplify kernel to update FP8 scale-inv

Signed-off-by: Tim Moon <[email protected]>

* Fix typos

Signed-off-by: Tim Moon <[email protected]>

* Debug amax update in layernorm kernels

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug test failures

Signed-off-by: Tim Moon <[email protected]>

* Debug ONNX export

Use quantization scaling factor in ONNX quantize op.

Signed-off-by: Tim Moon <[email protected]>

* Review suggestion from @ptrendx

Signed-off-by: Tim Moon <[email protected]>

* Debug mismatched dtypes

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: beinggod <[email protected]>
wenchenvincent added a commit to ROCm/TransformerEngine that referenced this pull request Dec 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants