
int4_weight_only Slows Down torch.nn.Linear for Llama2 7B Shapes #1606

mostafaelhoushi opened this issue Jan 23, 2025 · 13 comments

@mostafaelhoushi

I have created a small script to benchmark int4 quantization on A100 GPUs, with inputs that have batch size 1 and seqlen 1.
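Roughly, the benchmark looks like this (a minimal sketch of my setup, not the exact script; the dtype, shapes, and iteration counts here are illustrative):

```python
# Minimal sketch: time a single nn.Linear before/after int4_weight_only quantization.
import copy
import torch
from torchao.quantization import quantize_, int4_weight_only

def bench_ms(fn, iters=100, warmup=10):
    # Warm up, then time with CUDA events; returns mean ms per call.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

input_dim, output_dim = 4096, 4096
x = torch.randn(1, 1, input_dim, dtype=torch.bfloat16, device="cuda")  # batch size 1, seqlen 1
linear = torch.nn.Linear(input_dim, output_dim, bias=False, dtype=torch.bfloat16, device="cuda")

qlinear = copy.deepcopy(linear)
quantize_(qlinear, int4_weight_only())  # int4 weight-only quantization from torchao

print(f"Baseline:\t{bench_ms(lambda: linear(x))} ms")
print(f"Quantized:\t{bench_ms(lambda: qlinear(x))} ms")
```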

When I test weight shapes that exist in Llama2 7B, I actually get a slowdown:

# input_dim, output_dim = 4096, 4096
Baseline:       0.023313920497894287 ms
Quantized:      0.08300095558166504 ms
# input_dim, output_dim = 4096, 11008
Baseline:       0.06082496166229248 ms
Quantized:      0.08460960388183594 ms
# input_dim, output_dim = 11008, 4096
Baseline:       0.059748477935791015 ms
Quantized:      0.09495231628417969 ms

When I use a really large shape that doesn't exist in Llama2 7B, I do get some speedup:

# input_dim, output_dim = 11008, 11008
Baseline:       0.14746272087097168 ms
Quantized:      0.09298111915588379 ms

This is strange because gpt-fast uses a similar int4 quantization and gets a 2x speedup on Llama2 7B.

@vkuzo
Contributor

vkuzo commented Jan 23, 2025

cc @HDCharles as @jerryzh168 is OOO

@jcaip
Contributor

jcaip commented Jan 23, 2025

cc @mostafaelhoushi What hardware, PyTorch, and torchao versions are you using?

On my H100 on the nightlies, I see a speedup (for 4096, 11008):

Baseline:       0.04709856033325195 ms
Quantized:      0.041566081047058105 ms

@mostafaelhoushi
Author

@jcaip I am running on NVIDIA A100-SXM4-80GB and using the following libraries:

pytorch-triton==3.0.0+dedb7bdf33
torch==2.5.1+cu121

Let me draft another script that benchmarks a Hugging Face model. I am worried that even with the kernel-level speedup you see on the H100 nightly, we may get no end-to-end speedup for Llama2 7B.

@jcaip
Contributor

jcaip commented Jan 23, 2025

I would definitely recommend using the latest nightlies to test

@mostafaelhoushi
Author

mostafaelhoushi commented Jan 23, 2025

Thanks @jcaip
So the script to benchmark Llama2 7B is here.

I installed the torch nightly and the torchao nightly. When benchmarking a single linear layer (4096, 11008) on the same A100, I get:

Baseline:       0.06053599834442139 ms
Quantized:      0.09977215766906739 ms

and when benchmarking the Llama2 7B model I get:

Baseline:       Model: 40.05844970703125 ms, MLP Layer: 0.18576160430908203 ms
Quantized:      Model: 60.4076806640625 ms, MLP Layer: 0.38929054260253904 ms
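
For context, the model-level benchmark is roughly the following (a hypothetical sketch, assuming the Hugging Face meta-llama/Llama-2-7b-hf checkpoint; not the exact script):

```python
# Hypothetical sketch: quantize a Hugging Face Llama2 7B and time a full forward pass.
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_, int4_weight_only

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
).cuda().eval()
input_ids = torch.randint(0, model.config.vocab_size, (1, 1), device="cuda")  # bs 1, seqlen 1

@torch.no_grad()
def bench_ms(iters=20, warmup=5):
    # Time the whole model forward with CUDA events; returns mean ms per call.
    for _ in range(warmup):
        model(input_ids)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        model(input_ids)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

baseline = bench_ms()
quantize_(model, int4_weight_only())  # quantize the linear layers in place
print(f"Baseline:\t{baseline} ms\nQuantized:\t{bench_ms()} ms")
```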

@jcaip
Contributor

jcaip commented Jan 23, 2025

I also see slowdowns on my A100; I'm not sure of the exact cause. Maybe there were some changes to the int4 kernel in core? I also see you're running without torch.compile, but I don't think that should make a difference for a single linear layer ...

@joanPlepi

Just to add another data point: I am quite new to the library (I came across it yesterday while trying to quantize a large model for inference), but I also noticed quite a slowdown after quantization.

I tried the benchmarking script shared above, and these are the numbers I get on an H100 as well.

# bs, input_dim, output_dim = 1, 4096, 11008
Baseline:       0.036133759021759033 ms
Quantized:      0.0388044810295105 ms

# bs, input_dim, output_dim = 16, 4096, 11008
Baseline:       0.03751264095306397 ms
Quantized:      0.05798175811767578 ms

@drisspg
Contributor

drisspg commented Jan 25, 2025

On H100 I am seeing the following.
On the 2.6 release:

Baseline:       58.58537037037032 us
Quantized:      21.989801687764025 us

On Nightly

Baseline:       58.74621212121216 us
Quantized:      21.948585227273114 us

@mostafaelhoushi
Author

Thanks @drisspg. You got those speedups without adding any torch.compile() statements?
I am interested in getting the speedup you got in eager mode.
I'm not sure how you got that ~3x speedup on H100 while @jcaip didn't.

@drisspg
Contributor

drisspg commented Jan 25, 2025

No torch.compile; the only change to your script was using this timer function: https://github.com/drisspg/transformer_nuggets/blob/46127c65fa72c338fb600dd0373cb7fc36bd9613/transformer_nuggets/utils/benchmark.py#L55

which is the closest I have found to what NCU would report for kernel time, so it ignores any CPU overhead.
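
If you don't have that helper handy, triton.testing.do_bench is similar in spirit (event-based timing with an L2 flush between reps), e.g. something like this (a sketch, not the same code as the linked timer):

```python
# Sketch: kernel-time-oriented measurement via triton.testing.do_bench.
import torch
from triton.testing import do_bench
from torchao.quantization import quantize_, int4_weight_only

input_dim, output_dim = 4096, 11008
x = torch.randn(1, 1, input_dim, dtype=torch.bfloat16, device="cuda")
linear = torch.nn.Linear(input_dim, output_dim, bias=False, dtype=torch.bfloat16, device="cuda")
baseline_ms = do_bench(lambda: linear(x))  # returns time per call in ms

quantize_(linear, int4_weight_only())
quantized_ms = do_bench(lambda: linear(x))
print(f"Baseline:\t{baseline_ms} ms\nQuantized:\t{quantized_ms} ms")
```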

@mostafaelhoushi
Author

Thanks @drisspg. Indeed, when I tried that timing function, I got big speedups.
It also led to big speedups when I measured the whole model.

I want to verify: if it ignores any CPU overhead, does that mean it doesn't measure the end-to-end speedup of execution? Would it be fair to use it to measure speedup for models rather than individual kernels?

@drisspg
Contributor

drisspg commented Jan 27, 2025

@mostafaelhoushi That is a really good question, and like all great questions it kind of depends on what you care about. Often we are looking at changes to individual kernels, so the above function is very helpful in that case. But if what you really care about is wall-clock time, I use https://github.com/drisspg/transformer_nuggets/blob/8b0a671b7b30cc7e186edd654f9c1565251b9b97/transformer_nuggets/utils/benchmark.py#L44

which should capture the CPU overhead and manually adds the CUDA syncs. Obviously the best thing to measure is the thing you actually care about, but these act as proxies. I find that just looking at the PyTorch traces is also really helpful.
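
For example, something along these lines with the stock torch.utils.benchmark.Timer (a sketch, not the linked helper):

```python
# Sketch: wall-clock timing that includes CPU/dispatch overhead.
# torch.utils.benchmark.Timer handles the CUDA synchronization for you.
import torch
import torch.utils.benchmark as benchmark

def wall_clock_ms(fn, iters=100):
    t = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    return t.timeit(iters).mean * 1e3  # Measurement.mean is in seconds

# e.g. compare wall_clock_ms(lambda: linear(x)) vs. wall_clock_ms(lambda: qlinear(x))
```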

@gau-nernst
Collaborator

I think it's likely due to the L2 cache. It has happened to everyone doing microbenchmarks before 😅

https://github.com/pytorch/pytorch/blob/0144613e6ff6e018ca41085d1509dcceb80987f7/torch/_inductor/utils.py#L150-L158

(it seems to be inspired by triton.testing.do_bench, or maybe it's the other way around)
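
The gist of that linked code is zeroing a large buffer between timed calls so the weights don't sit in L2; a hand-rolled sketch of the same trick (the buffer size is an assumption, pick something larger than your GPU's L2):

```python
# Minimal sketch of the L2-cache-flush trick used by do_bench-style timers.
import torch

_l2_flush = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda")

def bench_cold_cache_ms(fn, iters=100):
    times = []
    for _ in range(iters):
        _l2_flush.zero_()  # evict weights/activations from L2 before each timed call
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return sum(times) / len(times)  # mean ms per call with a cold L2
```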
