`int4_weight_only` Slows Down `torch.nn.Linear` for Llama2 7B Shapes (#1606)

Comments
cc @HDCharles as @jerryzh168 is OOO

cc @mostafaelhoushi What hardware, PyTorch, and torchao versions are you using? On my H100 on the nightlies, for (4096, 11008) I see a speedup:
@jcaip I am running on an NVIDIA A100-SXM4-80GB and using the following libraries:

Let me draft another script that benchmarks a Hugging Face model. I am worried that, even with the speedup you see for a single kernel on the H100 nightly, there may be no speedup for the full Llama2 7B model.
I would definitely recommend using the latest nightlies to test.
Thanks @jcaip. I installed the torch nightly and torchao nightly. When benchmarking a single linear layer `(4096, 11008)` on the same A100, I get:

and when benchmarking the Llama2 7B model, I get:
I also see slowdowns on my A100, not sure of the exact cause. Maybe there were some changes to the int4 kernel in core? I also see you're running without compile, but I don't think that should make a difference for a single linear layer...
Just to add another data point: I am quite new to the library (I only came across it yesterday while trying to quantize a large model for inference), but I also noticed quite a slowdown after quantization. I tried the benchmarking script shared above, and these are the numbers I get, also on an H100:

# bs, input_dim, output_dim = 1, 4096, 11008
Baseline: 0.036133759021759033 ms
Quantized: 0.0388044810295105 ms

# bs, input_dim, output_dim = 16, 4096, 11008
Baseline: 0.03751264095306397 ms
Quantized: 0.05798175811767578 ms
On H100 I am seeing:

On nightly:
No torch.compile; the only change to your script was using this timer function: https://github.com/drisspg/transformer_nuggets/blob/46127c65fa72c338fb600dd0373cb7fc36bd9613/transformer_nuggets/utils/benchmark.py#L55 which is the closest thing I have found to what NCU would report for kernel time, i.e. ignoring any CPU overhead.
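For reference, the linked helper is not reproduced here, but the same idea (GPU kernel time with CPU launch overhead excluded) can be approximated by summing device-side durations from the PyTorch profiler; the function below is an illustrative sketch, not the transformer_nuggets implementation:

```python
import torch
from torch.profiler import profile, ProfilerActivity

def approx_kernel_time_ms(fn, warmup=10, iters=50):
    """Average GPU-side time per call, summed from profiler kernel events,
    so CPU launch overhead between kernels is not counted."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
    # self_cuda_time_total is reported in microseconds; newer PyTorch versions
    # also expose the same value as self_device_time_total.
    total_us = sum(e.self_cuda_time_total for e in prof.key_averages())
    return total_us / iters / 1000.0
```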
Thanks @drisspg. Indeed, when I tried that timing function, I got big speedups. I want to verify: if it ignores any CPU overhead, does that mean it doesn't measure end-to-end speedup of execution? Would it be fair to use it to measure speedup for models rather than individual kernels?
@mostafaelhoushi That is a really good question and, like all great questions, it kind of depends on what you care about. Often we are looking at changes to individual kernels, so the function above is very helpful in that case. But if what you really care about is wall-clock time, I use https://github.com/drisspg/transformer_nuggets/blob/8b0a671b7b30cc7e186edd654f9c1565251b9b97/transformer_nuggets/utils/benchmark.py#L44 which captures the CPU overhead and manually adds the cuda syncs. Obviously the best thing to measure is the thing you actually care about, but these act as proxies. I find that just looking at the PyTorch traces is also really helpful.
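For completeness, the wall-clock style of measurement described above (CPU overhead included, with explicit synchronization so queued GPU work actually finishes) can be sketched as follows; this is a generic sketch, not the linked helper:

```python
import time
import torch

def wall_clock_time_ms(fn, warmup=10, iters=50):
    """End-to-end wall-clock time per call, including CPU overhead."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()  # make sure warmup work is done before timing
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait for all queued GPU work to finish
    return (time.perf_counter() - start) * 1000.0 / iters
```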
I think it's likely due to the L2 cache. It has happened to everyone before when doing microbenchmarks 😅 (seems like inspired by
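A common way to control for a warm L2 cache in microbenchmarks is to flush it between timed iterations by writing a large scratch buffer. The sketch below is illustrative; the 256 MiB buffer size is an assumption chosen to exceed the roughly 40-50 MB L2 on A100/H100:

```python
import torch

# Scratch buffer used to evict previous data from the GPU L2 cache.
cache_flusher = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda")

def flush_l2():
    cache_flusher.zero_()  # touching the whole buffer evicts prior contents of L2

def bench_cold_cache_ms(fn, iters=20):
    """Time fn with CUDA events, flushing L2 before each iteration."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(iters):
        flush_l2()
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    return sum(times) / len(times)
```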
I have created a small script to benchmark int4 weight-only quantization on A100 GPUs, with inputs that have batch size 1 and sequence length 1.

When I test weight shapes that exist in Llama2 7B, I actually get a slowdown:

When I use a really large shape that doesn't exist in Llama2 7B, I do get some speedup:

This is strange because gpt-fast uses a similar int4 quantization and gets a 2x speedup on Llama2 7B.
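The original script is not reproduced above; a minimal sketch of such a single-layer benchmark, assuming the torchao `quantize_` / `int4_weight_only` API (import paths may differ between releases) and a bfloat16 CUDA linear layer, might look like this:

```python
import torch
from torch.utils.benchmark import Timer
# torchao API as of recent releases; the exact import path may vary by version.
from torchao.quantization import quantize_, int4_weight_only

def bench_ms(linear, x, label):
    # torch.utils.benchmark.Timer handles CUDA synchronization for us.
    t = Timer(stmt="linear(x)", globals={"linear": linear, "x": x})
    print(f"{label}: {t.timeit(100).mean * 1e3:.4f} ms")

# One of the Llama2 7B MLP weight shapes; batch size 1, seqlen 1.
bs, in_features, out_features = 1, 4096, 11008
x = torch.randn(bs, in_features, dtype=torch.bfloat16, device="cuda")

baseline = torch.nn.Linear(in_features, out_features, bias=False,
                           dtype=torch.bfloat16, device="cuda")
bench_ms(baseline, x, "Baseline")

quantized = torch.nn.Linear(in_features, out_features, bias=False,
                            dtype=torch.bfloat16, device="cuda")
quantize_(quantized, int4_weight_only())  # int4 weight-only quantization in place
bench_ms(quantized, x, "Quantized")
```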