-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace transformer apply_rotary_pos_emb with triton version #21
base: main
Are you sure you want to change the base?
Conversation
Note that in my testing there was little benefit to auto-tuning this kernel -- for any input size and model parameter count, roughly the same settings given here were the best or very close. That said, the one thing I didn't vary was hardware, so there could be other GPU models that benefit from a different block size.
This is one of those rare cases where faster isn't worse. These embeddings are (slightly) more accurate than the original! |
@aljungberg I tried to change as little as possible for the port and noticed a diff where
Since I do not know why, I played it safe and keep the non-null original code. Is the attn_weights never used by forward caller or it was not possible to be not null? Thanks. |
I recall using just I also haven't tried |
Ah, yeah that's not related to my change. I just noticed |
Comparison using just a single input/sample for inference on all the variations discussed here: diff in new tokens per/s on 30b 4bit group:512 model
So both compile and @torch.jit.script do improve main but each has very small incremental improvement. Need to combine all for net 2%.
|
This is a port of https://github.com/qwopqwop200/GPTQ-for-LLaMa/pull/221/files by @aljungberg to this repo.
On my 4090 4bit + group-size:512 + true-sequential 30b model inference test I saw about 8-10% speed up for new tokens/s (excluding prompt) in my own limited testing depending on input size. Did not see any adverse effects or drastic change in output versus non-triton rotary.
Added note that
tl.libdevice
is getting deprecated and refractored totl.math
in triton 2.1.0. I tried to add dynamic switching code but triton JIT does not allow this.@fpgaminer Please test this branch for kernel math accuracy vs main.