Speed Comparison: BitLinear and nn.Linear #118
Hi @ZiqingChang, I guess the overhead comes from the input quantization and output rescaling. We can improve the performance through torch.compile, but that update is not applied in this version. To test the efficient BitNet layer you may want to check out the branch at https://github.com/LeiWang1999/vllm-bitblas/tree/bitblas-intg, where we provide a more efficient BitLinear.
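For illustration, a minimal sketch of wrapping such a layer with torch.compile (the import path and constructor signature below are assumptions based on the integration referenced later in this thread, not the branch's actual API):

```python
# Hedged sketch: wrap a BitLinear layer with torch.compile so the elementwise
# activation-quant and output-rescale ops around the matmul can be fused.
import torch
from utils_quant import BitLinear  # assumed import (BitBLAS/integration/BitNet/utils_quant.py)

layer = BitLinear(1024, 1024, bias=False).cuda().half()  # assumed constructor signature
compiled_layer = torch.compile(layer)                    # fuses the surrounding elementwise kernels

x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
with torch.no_grad():
    y = compiled_layer(x)  # first call triggers compilation; time subsequent calls
```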
Hi @LeiWang1999, thanks for your quick reply. I was just wondering whether BitLinearBitBLAS will speed up models where the majority of linear layers have inputs < 1024; based on the computational cost of the overhead for small inputs, I'm not certain. I ask because I tested BitLinearBitBLAS on my model, where most of the linear layers have inputs and outputs < 1024, and the inference time is 2x slower than the same model with nn.Linear.
Hi @ZiqingChang, absolutely, it should.
Hi @LeiWang1999, can you provide an example of how to test the speed of your more efficient BitLinear? I tried torch.compile, but it does not seem to speed up the BitBLAS BitLinear.
The results are:
@ZiqingChang, sorry I missed this message! That's strange; would you mind providing a complete script so that we can reproduce it? :)
Hello @LeiWang1999, yes, here is the full script.
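For illustration only (a hedged sketch, not the script from this comment; the BitLinear import and constructor signature are assumptions), a minimal CUDA-event timing harness for a single layer might look like:

```python
import torch
from utils_quant import BitLinear  # assumed import from the BitNet integration

@torch.no_grad()
def ms_per_forward(layer, x, iters=100):
    # Warm up (also triggers compilation if the layer was wrapped with torch.compile).
    for _ in range(10):
        layer(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        layer(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per forward

# Small shape matching the profile below (N=256, K=512).
x = torch.randn(16, 512, device="cuda", dtype=torch.float16)
ref = torch.nn.Linear(512, 256, bias=False, device="cuda", dtype=torch.float16)
bit = BitLinear(512, 256, bias=False).cuda().half()  # assumed constructor signature

print("nn.Linear          :", ms_per_forward(ref, x), "ms")
print("BitLinear          :", ms_per_forward(bit, x), "ms")
print("compiled BitLinear :", ms_per_forward(torch.compile(bit), x), "ms")
```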
Here is some info from profiling:

nn.Linear

Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- -------- -------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
60.1 456110 105 4343.9 4320.0 4319 5120 88.9 void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_64x64_16x6_tn_align4>(T1::Params)
32.9 250178 105 2382.6 2368.0 2367 2496 19.4 void splitKreduce_kernel<(int)32, (int)16, int, float, float, float, float, (bool)1, (bool)0, (bool…
2.0 14816 2 7408.0 7408.0 7360 7456 67.9 void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
0.9 6688 3 2229.3 2304.0 1984 2400 217.8 void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(a…
0.8 5984 2 2992.0 2992.0 2304 3680 973.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::AbsFunctor<float>, at::detail::A…
0.5 4161 2 2080.5 2080.5 2080 2081 0.7 void at::native::vectorized_elementwise_kernel<(int)4, at::native::reciprocal_kernel_cuda(at::Tenso…
0.5 3969 2 1984.5 1984.5 1984 1985 0.7 void at::native::vectorized_elementwise_kernel<(int)4, at::native::AUnaryFunctor<float, float, floa…
0.5 3936 1 3936.0 3936.0 3936 3936 0.0 void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIterator…
0.5 3488 1 3488.0 3488.0 3488 3488 0.0 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
0.4 2911 1 2911.0 2911.0 2911 2911 0.0 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
0.3 2401 1 2401.0 2401.0 2401 2401 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::round_kernel_cuda(at::TensorIter…
0.3 2400 1 2400.0 2400.0 2400 2400 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(a…
0.3 2241 1 2241.0 2241.0 2241 2241 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctorOnSelf_add<signed cha…

BitBLASLinear

Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- -------- -------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
58.4 446394 210 2125.7 2128.0 1951 2463 159.0 triton_
34.6 264545 105 2519.5 2528.0 2495 2624 19.5 matmul_n256k512_i8xi2_simt_opt_m_16
1.9 14784 2 7392.0 7392.0 7360 7424 45.3 void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
0.9 6689 3 2229.7 2304.0 1985 2400 217.3 void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(a…
0.8 6304 2 3152.0 3152.0 2304 4000 1199.3 void at::native::vectorized_elementwise_kernel<(int)4, at::native::AbsFunctor<float>, at::detail::A…
0.6 4225 2 2112.5 2112.5 2112 2113 0.7 void at::native::vectorized_elementwise_kernel<(int)4, at::native::reciprocal_kernel_cuda(at::Tenso…
0.5 3967 2 1983.5 1983.5 1983 1984 0.7 void at::native::vectorized_elementwise_kernel<(int)4, at::native::AUnaryFunctor<float, float, floa…
0.5 3936 1 3936.0 3936.0 3936 3936 0.0 void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIterator…
0.5 3456 1 3456.0 3456.0 3456 3456 0.0 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
0.4 2912 1 2912.0 2912.0 2912 2912 0.0 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
0.3 2400 1 2400.0 2400.0 2400 2400 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(a…
0.3 2400 1 2400.0 2400.0 2400 2400 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::round_kernel_cuda(at::TensorIter…
0.3 2208 1 2208.0 2208.0 2208 2208 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctorOnSelf_add<signed cha…

It looks like the shape you're benchmarking is too small (especially N and K; for M, the smaller it is, the bigger the gain), which makes the overhead of activation quantization and dequantization non-negligible.
The kernel time is similar, so it must be something slow in the CPU runtime; we should dig further.
I tested on a big shape and the performance is normal, though it looks like there is still some CPU runtime overhead.
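One way to dig into that CPU-side overhead (a sketch only; a plain nn.Linear is used as a placeholder, swap in the BitLinear variant under test) is to compare self CPU time against CUDA kernel time with torch.profiler:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Placeholder layer; replace with BitLinear / BitLinearBitBLAS built as in the snippets above.
layer = torch.nn.Linear(512, 256, bias=False, device="cuda", dtype=torch.float16)
x = torch.randn(16, 512, device="cuda", dtype=torch.float16)

with torch.no_grad():
    # Warm up so compilation / autotuning does not pollute the trace.
    for _ in range(10):
        layer(x)
    torch.cuda.synchronize()

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(100):
            layer(x)
        torch.cuda.synchronize()

# High self CPU time with comparatively low CUDA time points at Python / launch
# overhead rather than the kernels themselves.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=15))
```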
For my model, the input and output sizes of the Linear modules are all <= 1024, so given the small shapes, can BitBLAS actually speed up the inference time?
@ZiqingChang, from the kernel side it absolutely should, but we need to understand why the torch BitNet linear forward is introducing some extra overhead.
Hello,
I measured the time of your BitLinear and BitLinearBitBLAS against nn.Linear, and it seems that for smaller in_features and out_features they are slower than nn.Linear. Is there a solution for this?
I used the quant_utils from your BitNet integration: https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/utils_quant.py
My GPU is an NVIDIA GeForce RTX 3090.
Here are the testing results:
Thanks in advance for your reply.