Replace transformer apply_rotary_pos_emb with triton version #21

Conversation

Qubitium

This is a port of https://github.com/qwopqwop200/GPTQ-for-LLaMa/pull/221/files by @aljungberg to this repo.

In my own limited testing on a 4090 with a 4-bit + group-size:512 + true-sequential 30B model, I saw about an 8-10% speed-up in new tokens/s (excluding the prompt), depending on input size. I did not see any adverse effects or drastic change in output versus the non-Triton rotary.
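For reference, the eager function this PR replaces looks roughly like the Hugging Face transformers implementation of that era (a sketch, not the exact code in this repo):

```python
import torch

def rotate_half(x):
    # Split the last dimension in half and swap the halves, negating the second.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # Gather the cos/sin rows for the current positions and broadcast over heads.
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```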

Added a note that tl.libdevice is being deprecated and refactored into tl.math in Triton 2.1.0. I tried to add dynamic switching code, but the Triton JIT does not allow this.

@fpgaminer Please test this branch for kernel math accuracy vs main.

@aljungberg

aljungberg commented May 16, 2023

Note that in my testing there was little benefit to auto-tuning this kernel -- for any input size and model parameter count, roughly the same settings given here were the best or very close. That said, the one thing I didn't vary was hardware, so there could be other GPU models that benefit from a different block size.
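For anyone who does want to experiment with different block sizes, auto-tuning in Triton looks roughly like this (a hypothetical, simplified kernel for illustration, not the one in this PR):

```python
import triton
import triton.language as tl

# Hypothetical: let Triton benchmark a few block sizes, keyed on problem size.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}, num_warps=2),
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
    ],
    key=["n_elements"],
)
@triton.jit
def scale_kernel(x_ptr, cos_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    c = tl.load(cos_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * c, mask=mask)
```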

> I did not see any adverse effects or drastic change in output versus the non-Triton rotary.

This is one of those rare cases where faster isn't worse. These embeddings are (slightly) more accurate than the original!

@Qubitium
Author

@aljungberg I tried to change as little as possible for the port and noticed a diff where

https://github.com/qwopqwop200/GPTQ-for-LLaMa/pull/221/files#diff-6e5c6a701250dbeadf3480830a752403c9e485c8e093bc5977af4319ff12c53cR160

attn_weights is always returned as None in your def forward change.

Since I do not know why, I played it safe and kept the original non-None code. Is attn_weights never used by the forward caller, or could it never be non-None there? Thanks.

@fpgaminer
Owner

I recall using just @torch.jit.script on the HEAD transformers apply_rotary_pos_emb and getting a similar 10% speed-up. Do we know if the Triton implementation beats that? If the performance is similar I'd rather lean on PyTorch to do the optimization rather than implementing custom Triton kernels for everything, which require on-going maintenance.
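For comparison, the TorchScript route is roughly just decorating the eager functions (a simplified sketch, assuming cos/sin are already indexed by position; not the exact code that was used):

```python
from typing import Tuple

import torch

@torch.jit.script
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension, negating the second half.
    half = x.shape[-1] // 2
    x1 = x.narrow(-1, 0, half)
    x2 = x.narrow(-1, half, half)
    return torch.cat([-x2, x1], dim=-1)

@torch.jit.script
def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # cos/sin are assumed already gathered for the current positions and
    # broadcastable against q/k.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```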

I also haven't tried torch.compile on a quantized model yet; not sure if it handles Triton kernel calls. That might have a similar benefit if it works.

@aljungberg

> @aljungberg I tried to change as little as possible for the port and noticed a diff where
>
> https://github.com/qwopqwop200/GPTQ-for-LLaMa/pull/221/files#diff-6e5c6a701250dbeadf3480830a752403c9e485c8e093bc5977af4319ff12c53cR160
>
> attn_weights is always returned as None in your def forward change.
>
> Since I do not know why, I played it safe and kept the original non-None code. Is attn_weights never used by the forward caller, or could it never be non-None there? Thanks.

Ah, yeah, that's not related to my change. I just noticed attn_weights was either None or undefined before my change. Probably whoever made the switch to torch-optimised attention (x = F.scaled_dot_product_attention(...)) made a mistake. But that wasn't what I was working on, so I just eliminated the exception.
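For context on why it ends up None: F.scaled_dot_product_attention only returns the attention output, never the softmax weights, so that code path has nothing to hand back for attn_weights. A minimal sketch:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Fused attention returns only the output tensor, so callers that expect
# per-head attention weights get None from this code path.
x = F.scaled_dot_product_attention(q, k, v, is_causal=True)
attn_weights = None
```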

@Qubitium
Author

Qubitium commented May 17, 2023

Comparison using a single input/sample for inference across all the variations discussed here:

Difference in new tokens/s on a 30B 4-bit group-size:512 model:

  1. main/baseline
  2. triton rotary (this PR): +6.4%
  3. main + torch.compile: +0.4%
  4. main + @torch.jit.script on apply_rotary: +1.5%
  5. main + @torch.jit.script on apply_rotary & rotate_half: +1.6%
  6. main + torch.compile + @torch.jit.script on apply_rotary & rotate_half: +2.0%

So both torch.compile and @torch.jit.script do improve main, but each gives only a small incremental improvement; combining them all yields a net ~2%.

  1. triton rotary (this PR) + torch.compile: +6.4%
  2. triton rotary (this PR) + torch.compile + openai/triton 2.1.0 head: +7.5% (requires changing tl.libdevice to tl.math; see the sketch below)
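The Triton 2.1.0 change is a one-line namespace swap inside the kernel; a hypothetical example (not the exact rotary kernel from this PR):

```python
import triton
import triton.language as tl

@triton.jit
def cos_kernel(theta_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    theta = tl.load(theta_ptr + offs, mask=mask)
    # Triton 2.0.x:
    # c = tl.libdevice.cos(theta)
    # Triton 2.1.0+ (libdevice functions moved under tl.math):
    c = tl.math.cos(theta)
    tl.store(out_ptr + offs, c, mask=mask)
```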
