
[PyTorch] Adjusted the logic of MHA and DPA to enable speculative decoding #668

Merged · 14 commits merged into NVIDIA:main from pr_inference_params on Mar 6, 2024

Conversation

@Oleg-Goncharov (Collaborator) commented Feb 15, 2024:

This PR modifies the logic of MHA and DPA to use InferenceParams (a KV-cache) and extends the unfused softmax to support rectangular score matrices, enabling speculative decoding.

  • Added test cases to test_numerics.py that check whether the outputs of TransformerLayer and MultiheadAttention produced in a single full-sequence run match those produced incrementally with the KV-cache (a sketch follows this list)
  • The longer the input sequence, the larger the divergence between the outputs, due to accumulation of numerical error over the forward pass
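Below is a minimal sketch (not taken from the PR) of the behaviour the new tests exercise: a TransformerLayer run once over a full sequence is compared against the same layer run chunk by chunk through the KV-cache, where a chunk of several tokens stands in for draft tokens being verified during speculative decoding, so the unfused softmax sees a rectangular (chunk_len × total_kv_len) score matrix. The InferenceParams import path, its (max_batch_size, max_sequence_length) constructor, the sequence_len_offset attribute, and the inference_params forward argument are assumptions about Transformer Engine's post-merge PyTorch API.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention import InferenceParams  # assumed import path

hidden, ffn, heads = 256, 1024, 4
seq_len, batch, chunk = 32, 2, 4  # chunk > 1 mimics verifying several draft tokens at once

layer = te.TransformerLayer(hidden, ffn, heads).cuda().eval()  # default causal self-attention, sbhd layout
x = torch.randn(seq_len, batch, hidden, device="cuda")

with torch.no_grad():
    # Reference: one full-sequence forward pass.
    full_out = layer(x)

    # Incremental: feed the sequence in chunks, reusing the KV-cache between calls.
    params = InferenceParams(max_batch_size=batch, max_sequence_length=seq_len)  # assumed signature
    outs = []
    for start in range(0, seq_len, chunk):
        outs.append(layer(x[start:start + chunk], inference_params=params))
        params.sequence_len_offset += chunk  # assumed attribute: advance the cache offset
    incremental_out = torch.cat(outs, dim=0)

# The two paths should agree up to numerical error; as noted above, the error grows
# with sequence length, so a loose tolerance is used here.
torch.testing.assert_close(full_out, incremental_out, atol=1e-2, rtol=1e-2)
```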

@Oleg-Goncharov added the enhancement (New feature or request) label on Feb 16, 2024
@Oleg-Goncharov changed the title from "[Pytorch] Adjusted the logic of MHA and DPA to enable speculative decoding" to "[PyTorch] Adjusted the logic of MHA and DPA to enable speculative decoding" on Feb 16, 2024
@Oleg-Goncharov (Collaborator, Author) commented Feb 16, 2024:

/te-ci

@ptrendx (Member) commented Feb 16, 2024:

/te-ci pytorch

Resolved review threads:
  • transformer_engine/pytorch/softmax.py — 2 threads (outdated)
  • transformer_engine/pytorch/attention.py — 6 threads (5 outdated)
@timmoon10 timmoon10 self-requested a review February 17, 2024 01:08
Oleg-Goncharov and others added 3 commits February 19, 2024 17:47
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
@timmoon10 (Collaborator) left a comment:


LGTM, pending CI.

@timmoon10 (Collaborator):

/te-ci pytorch

Signed-off-by: Oleg Goncharov <[email protected]>
@timmoon10 (Collaborator):

/te-ci pytorch

@ksivaman (Member):

/te-ci pytorch

@ksivaman (Member) commented Mar 6, 2024:

/te-ci pytorch

@ksivaman ksivaman merged commit b459ccc into NVIDIA:main Mar 6, 2024
20 checks passed
@Oleg-Goncharov Oleg-Goncharov deleted the pr_inference_params branch March 6, 2024 20:40
Labels: enhancement (New feature or request) · 4 participants