Speed Comparison: BitLinear and nn.Linear #118

ZiqingChang opened this issue Aug 1, 2024 · 11 comments

@ZiqingChang

Hello,

I measured the runtime of your BitLinear and BitLinearBitBLAS against nn.Linear, and for smaller input_features and output_features both are slower than nn.Linear. Is there a solution for this?

I used the utils_quant module from your BitNet integration: https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/utils_quant.py

My GPU is an NVIDIA GeForce RTX 3090.

import time
import torch
import torch.nn as nn
from torch.autograd import Variable

# from bitlinear import BitLinear, BitLinearBitBLAS
from utils_quant import BitLinear, BitLinearBitBLAS

# Function to measure computation time
def measure_time(layer, input_tensor, num_runs=100):
    with torch.no_grad():
        # Warm up
        for _ in range(100):
            _ = layer(input_tensor)
        torch.cuda.synchronize()  # make sure warm-up work has finished before timing

        start_time = time.time()
        for _ in range(num_runs):
            _ = layer(input_tensor)
        torch.cuda.synchronize()  # wait for queued kernels to finish before stopping the clock
        end_time = time.time()
        
        avg_time = (end_time - start_time) / num_runs
    return avg_time

# Test parameters
input_features = 512
output_features = 256
batch_size = 8

# input_features = 1024
# output_features = 512
# batch_size = 32

# input_features = 10240
# output_features = 5120
# batch_size = 32

# input_features = 20480
# output_features = 10240
# batch_size = 32


# Create random input tensor
input_tensor = torch.randn(batch_size, input_features).cuda()

# Initialize layers
nn_linear_layer = nn.Linear(input_features, output_features).cuda()
bit_linear_layer = BitLinear(input_features, output_features).cuda()

bitblas_linear_layer = BitLinearBitBLAS.from_bit_linear(bit_linear_layer)

# Measure computation time
num_runs = 100
nn_linear_time = measure_time(nn_linear_layer, input_tensor, num_runs)
bit_linear_time = measure_time(bit_linear_layer, input_tensor, num_runs)
bitblas_linear_time = measure_time(bitblas_linear_layer, input_tensor, num_runs)

print('input_features, output_features, batch_size: ', input_features, output_features, batch_size)
print(f"Average computation time for nn.Linear: {nn_linear_time * 1000:.4f} ms")
print(f"Average computation time for fp32 simulated BitLinear: {bit_linear_time * 1000:.4f} ms")
print(f"Average computation time Bitblas BitLinear: {bitblas_linear_time * 1000:.4f} ms")


Here are the testing results:

input_features, output_features, batch_size:  512 256 8
Average computation time for nn.Linear: 0.0230 ms
Average computation time for fp32 simulated BitLinear: 0.3450 ms
Average computation time Bitblas BitLinear: 0.3091 ms

input_features, output_features, batch_size:  1024 512 32
Average computation time for nn.Linear: 0.0265 ms
Average computation time for fp32 simulated BitLinear: 0.3427 ms
Average computation time Bitblas BitLinear: 0.3137 ms

input_features, output_features, batch_size:  10240 5120 32
Average computation time for nn.Linear: 0.5421 ms
Average computation time for fp32 simulated BitLinear: 6.3314 ms
Average computation time Bitblas BitLinear: 0.3170 ms

input_features, output_features, batch_size:  20480 10240 32
Average computation time for nn.Linear: 2.1726 ms
Average computation time for fp32 simulated BitLinear: 25.2509 ms
Average computation time Bitblas BitLinear: 0.5633 ms

Thanks for your reply in advance.
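
For kernels this small, host-side timestamps are easily dominated by launch latency and by warm-up work that has not yet drained. A minimal sketch of an event-based alternative to the measure_time helper above, using CUDA events (not part of the original script; layer and input_tensor are assumed to be defined as in the script):

import torch

def measure_time_events(layer, input_tensor, num_runs=100):
    # Time the forward pass with CUDA events; a sketch, not taken from the issue.
    with torch.no_grad():
        for _ in range(100):              # warm up
            _ = layer(input_tensor)
        torch.cuda.synchronize()          # drain warm-up work before timing

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_runs):
            _ = layer(input_tensor)
        end.record()
        torch.cuda.synchronize()          # wait for the end event to complete
    return start.elapsed_time(end) / num_runs  # milliseconds per run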

@LeiWang1999

Hi @ZiqingChang, I guess the overhead comes from input quantization and output rescaling. We can improve the performance through torch.compile, but that update is not applied in this version. To test the efficient BitNet layer, you may want to check out the branch at https://github.com/LeiWang1999/vllm-bitblas/tree/bitblas-intg, where we provide a more efficient BitLinear.
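
For reference, the input quantization and output rescaling mentioned here are roughly of the following form (a sketch modeled on BitNet-style per-token absmax int8 activation quantization, not copied from utils_quant.py; the names are illustrative):

import torch

def activation_quant_int8(x: torch.Tensor):
    # Per-token absmax scaling to int8, roughly the pattern BitNet-style layers use.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_q = (x * scale).round().clamp(-128, 127).to(torch.int8)
    return x_q, scale

def rescale_output(y: torch.Tensor, act_scale: torch.Tensor, w_scale: torch.Tensor):
    # Undo the activation and weight scaling after the low-precision matmul.
    return y.float() / (act_scale * w_scale)

Each forward call adds several small elementwise kernels (abs, max, round, clamp, divide) on top of the matmul itself, which is why the overhead becomes visible when the matmul is tiny.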

@ZiqingChang

Hi @LeiWang1999, thanks for your quick reply. I was just wondering: will BitLinearBitBLAS speed up models whose linear layers mostly have inputs < 1024? Based on the overhead I measured for small inputs, I'm not certain.

I ask because I tested BitLinearBitBLAS on my model, where most of the linear layers have inputs and outputs < 1024, and its inference time is 2x slower than the same model with nn.Linear.

@LeiWang1999

Hi @ZiqingChang, absolutely, it should.

@ZiqingChang

Hi @LeiWang1999, can you provide an example of how to test the speed of your more efficient BitLinear? I tried torch.compile, but it does not seem to speed up the BitBLAS BitLinear:


opt_nn_linear_layer = torch.compile(nn_linear_layer, mode="reduce-overhead")
opt_bit_linear_layer = torch.compile(bit_linear_layer, mode="reduce-overhead")
opt_bitblas_linear_layer = torch.compile(bitblas_linear_layer, mode="reduce-overhead")

torch.set_float32_matmul_precision('high')

# Measure computation time
num_runs = 100
opt_nn_linear_time = measure_time(opt_nn_linear_layer, input_tensor, num_runs)
opt_bit_linear_time = measure_time(opt_bit_linear_layer, input_tensor, num_runs)
opt_bitblas_linear_time = measure_time(opt_bitblas_linear_layer, input_tensor, num_runs)

print('*' * 50)
print(f"Average computation time for compiled nn.Linear: {opt_nn_linear_time * 1000:.4f} ms")
print(f"Average computation time for compiled fp32 simulated BitLinear: {opt_bit_linear_time * 1000:.4f} ms")
print(f"Average computation time for compiled Bitblas BitLinear: {opt_bitblas_linear_time * 1000:.4f} ms")


The results are:

input_features, output_features, batch_size:  512 256 8
**************************************************
Average computation time for nn.Linear: 0.0221 ms
Average computation time for fp32 simulated BitLinear: 0.3338 ms
Average computation time for Bitblas BitLinear: 0.3025 ms
**************************************************
Average computation time for compiled nn.Linear: 0.1108 ms
Average computation time for compiled fp32 simulated BitLinear: 0.1114 ms
Average computation time for compiled Bitblas BitLinear: 0.3155 ms
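
One possible reason torch.compile helps nn.Linear and the fp32 simulated BitLinear but not the BitBLAS layer is that the custom BitBLAS kernel call is opaque to the compiler and causes graph breaks, so the surrounding quantization and rescaling ops cannot be fused. A sketch of how one might check this (assuming a recent PyTorch 2.x, where torch._dynamo.explain reports graph breaks; layer and tensor names are from the script above):

import torch

explanation = torch._dynamo.explain(bitblas_linear_layer)(input_tensor)
print("graph breaks:", explanation.graph_break_count)
print(explanation.break_reasons)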

@LeiWang1999

@ZiqingChang, sorry I missed this message! It's strange; would you mind providing a complete script so we can reproduce it? :)

@ZiqingChang

Hello @LeiWang1999, yes, here is the full script.

import time
import torch
import torch.nn as nn
from torch.autograd import Variable

# from bitlinear import BitLinear, BitLinearBitBLAS
from utils_quant import BitLinear, BitLinearBitBLAS

# Function to measure computation time
def measure_time(layer, input_tensor, num_runs=100):
    with torch.no_grad():
        # Warm up
        for _ in range(100):
            _ = layer(input_tensor)
        torch.cuda.synchronize()  # make sure warm-up work has finished before timing

        start_time = time.time()
        for _ in range(num_runs):
            _ = layer(input_tensor)
        torch.cuda.synchronize()  # wait for queued kernels to finish before stopping the clock
        end_time = time.time()
        
        avg_time = (end_time - start_time) / num_runs
    return avg_time

# Test parameters
input_features = 512
output_features = 256
batch_size = 8

# Create random input tensor
input_tensor = torch.randn(batch_size, input_features).cuda()

# Initialize layers
nn_linear_layer = nn.Linear(input_features, output_features).cuda()
bit_linear_layer = BitLinear(input_features, output_features).cuda()

bitblas_linear_layer = BitLinearBitBLAS.from_bit_linear(bit_linear_layer)

# Measure computation time
num_runs = 100
nn_linear_time = measure_time(nn_linear_layer, input_tensor, num_runs)
bit_linear_time = measure_time(bit_linear_layer, input_tensor, num_runs)
bitblas_linear_time = measure_time(bitblas_linear_layer, input_tensor, num_runs)

print('input_features, output_features, batch_size: ', input_features, output_features, batch_size)
print(f"Average computation time for nn.Linear: {nn_linear_time * 1000:.4f} ms")
print(f"Average computation time for fp32 simulated BitLinear: {bit_linear_time * 1000:.4f} ms")
print(f"Average computation time Bitblas BitLinear: {bitblas_linear_time * 1000:.4f} ms")


opt_nn_linear_layer = torch.compile(nn_linear_layer, mode="reduce-overhead")
opt_bit_linear_layer = torch.compile(bit_linear_layer, mode="reduce-overhead")
opt_bitblas_linear_layer = torch.compile(bitblas_linear_layer, mode="reduce-overhead")

torch.set_float32_matmul_precision('high')

# Measure computation time
num_runs = 100
opt_nn_linear_time = measure_time(opt_nn_linear_layer, input_tensor, num_runs)
opt_bit_linear_time = measure_time(opt_bit_linear_layer, input_tensor, num_runs)
opt_bitblas_linear_time = measure_time(opt_bitblas_linear_layer, input_tensor, num_runs)

print('*' * 50)
print(f"Average computation time for compiled nn.Linear: {opt_nn_linear_time * 1000:.4f} ms")
print(f"Average computation time for compiled fp32 simulated BitLinear: {opt_bit_linear_time * 1000:.4f} ms")
print(f"Average computation time for compiled Bitblas BitLinear: {opt_bitblas_linear_time * 1000:.4f} ms")


@LeiWang1999

Here is some info from profiling:

nn.Linear

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     60.1           456110        105    4343.9    4320.0      4319      5120         88.9  void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_64x64_16x6_tn_align4>(T1::Params)                
     32.9           250178        105    2382.6    2368.0      2367      2496         19.4  void splitKreduce_kernel<(int)32, (int)16, int, float, float, float, float, (bool)1, (bool)0, (bool…
      2.0            14816          2    7408.0    7408.0      7360      7456         67.9  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
      0.9             6688          3    2229.3    2304.0      1984      2400        217.8  void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(a…
      0.8             5984          2    2992.0    2992.0      2304      3680        973.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::AbsFunctor<float>, at::detail::A…
      0.5             4161          2    2080.5    2080.5      2080      2081          0.7  void at::native::vectorized_elementwise_kernel<(int)4, at::native::reciprocal_kernel_cuda(at::Tenso…
      0.5             3969          2    1984.5    1984.5      1984      1985          0.7  void at::native::vectorized_elementwise_kernel<(int)4, at::native::AUnaryFunctor<float, float, floa…
      0.5             3936          1    3936.0    3936.0      3936      3936          0.0  void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIterator…
      0.5             3488          1    3488.0    3488.0      3488      3488          0.0  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      0.4             2911          1    2911.0    2911.0      2911      2911          0.0  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      0.3             2401          1    2401.0    2401.0      2401      2401          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::round_kernel_cuda(at::TensorIter…
      0.3             2400          1    2400.0    2400.0      2400      2400          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(a…
      0.3             2241          1    2241.0    2241.0      2241      2241          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctorOnSelf_add<signed cha…

BitBLASLinear

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     58.4           446394        210    2125.7    2128.0      1951      2463        159.0  triton_                                                                                             
     34.6           264545        105    2519.5    2528.0      2495      2624         19.5  matmul_n256k512_i8xi2_simt_opt_m_16                                                                 
      1.9            14784          2    7392.0    7392.0      7360      7424         45.3  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
      0.9             6689          3    2229.7    2304.0      1985      2400        217.3  void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(a…
      0.8             6304          2    3152.0    3152.0      2304      4000       1199.3  void at::native::vectorized_elementwise_kernel<(int)4, at::native::AbsFunctor<float>, at::detail::A…
      0.6             4225          2    2112.5    2112.5      2112      2113          0.7  void at::native::vectorized_elementwise_kernel<(int)4, at::native::reciprocal_kernel_cuda(at::Tenso…
      0.5             3967          2    1983.5    1983.5      1983      1984          0.7  void at::native::vectorized_elementwise_kernel<(int)4, at::native::AUnaryFunctor<float, float, floa…
      0.5             3936          1    3936.0    3936.0      3936      3936          0.0  void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIterator…
      0.5             3456          1    3456.0    3456.0      3456      3456          0.0  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      0.4             2912          1    2912.0    2912.0      2912      2912          0.0  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      0.3             2400          1    2400.0    2400.0      2400      2400          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(a…
      0.3             2400          1    2400.0    2400.0      2400      2400          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::round_kernel_cuda(at::TensorIter…
      0.3             2208          1    2208.0    2208.0      2208      2208          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctorOnSelf_add<signed cha…

Looks like the shape you're benchmarking is too small (especially N and K; and although the gain is bigger when M is smaller), which makes the overhead of activation dequantization and quantization non-negligible.
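
To put the shape in perspective, a back-of-the-envelope estimate (illustrative numbers, not taken from the profile above):

# Hypothetical estimate for the smallest benchmarked shape.
M, N, K = 8, 256, 512              # batch, output_features, input_features
flops = 2 * M * N * K              # ~2.1e6 FLOPs for the matmul
peak_fp32 = 35.6e12                # rough RTX 3090 FP32 peak, FLOP/s (assumption)
print(f"ideal math time: {flops / peak_fp32 * 1e6:.3f} us")  # ~0.06 us

The profiled kernels each take on the order of 2 to 4 us, so kernel launch overhead and the extra quantization and rescaling kernels, not the matmul itself, set the runtime at this size.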

@LeiWang1999

The kernel times are similar, so it must be something slow in the CPU-side runtime; we should dig further.
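
One way to confirm that, separating CPU-side (Python and launch) time from GPU kernel time, is the standard torch.profiler API; a sketch, assuming the layer and input from the benchmark script:

import torch
from torch.profiler import profile, ProfilerActivity

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        _ = bitblas_linear_layer(input_tensor)
    torch.cuda.synchronize()

# Compare self CPU time against CUDA time; a large gap points at host-side overhead.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=15))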

@LeiWang1999

I tested on a big shape and the performance is normal, though it looks like there is still some CPU runtime overhead.

BitBLAS Tuning done, appended operator to global_operator_cache.
input_features, output_features, batch_size:  16384 16384 1
Average computation time for nn.Linear: 0.6810 ms
Average computation time for fp32 simulated BitLinear: 12.7460 ms
Average computation time Bitblas BitLinear: 0.1746 ms

@ZiqingChang

For my model, the input and output sizes of the Linear modules are all <= 1024, so given such small shapes, can BitBLAS actually speed up inference?

@LeiWang1999

For my model, the input and output sizes of the Linear modules are all <= 1024, so given such small shapes, can BitBLAS actually speed up inference?

@ZiqingChang, from the kernel side it absolutely should, but we need to understand why the torch BitNet linear forward introduces some extra overhead.
