diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index fa2ebda56..0d4d93831 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -236,7 +236,7 @@ fn mul_mat_via_q8_1(
     if y.len() != y_rows * y_cols {
         crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
     }
-    if x_cols != y_cols {
+    if x_cols != y_rows {
         crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
     }
     let k = x_cols;
@@ -245,7 +245,7 @@ fn mul_mat_via_q8_1(
     let y_size_in_bytes =
         k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
     let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
-    quantize_q8_1(y, &mut y_q8_1, k, y_rows, dev)?;
+    quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
 
     let kernel_name = match dtype {
         GgmlDType::Q4_0 => "mul_mat_q4_0",
@@ -261,9 +261,9 @@ fn mul_mat_via_q8_1(
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(x_rows * y_rows).w()? };
+    let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
     let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: ((x_rows as u32 + 127) / 128, (y_rows as u32 + 63) / 64, 1),
+        grid_dim: ((x_rows as u32 + 127) / 128, (y_cols as u32 + 63) / 64, 1),
         block_dim: (WARP_SIZE as u32, 4, 1),
         shared_mem_bytes: 0,
     };
@@ -274,9 +274,9 @@ fn mul_mat_via_q8_1(
         /* dst */ &dst,
         /* ncols_x */ x_cols as i32,
         /* nrows_x */ x_rows as i32,
-        /* ncols_y */ k_padded as i32,
-        /* nrows_y */ y_rows as i32,
-        /* nrows_dst */ y_rows as i32,
+        /* ncols_y */ y_cols as i32,
+        /* nrows_y */ k_padded as i32,
+        /* nrows_dst */ x_rows as i32,
     );
     unsafe { func.launch(cfg, params) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
@@ -451,8 +451,8 @@ impl QCudaStorage {
                 self.dtype,
                 /* x_rows */ n,
                 /* x_cols */ k,
-                /* y_rows */ m,
-                /* y_cols */ k,
+                /* y_rows */ k,
+                /* y_cols */ m,
                 self.device(),
             )?
         };
@@ -548,8 +548,8 @@ mod test {
             /* dtype */ GgmlDType::Q4_0,
             /* x_rows */ 4,
             /* x_cols */ ncols,
-            /* y_rows */ 4,
-            /* y_cols */ ncols,
+            /* y_rows */ ncols,
+            /* y_cols */ 4,
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
@@ -558,8 +558,8 @@ mod test {
         assert_eq!(vs[0], 347604.0);
         assert_eq!(vs[1], 888153.06);
         // TODO: This is wrong.
-        assert_eq!(vs[4], 347604.0);
-        assert_eq!(vs[5], 888153.06);
+        assert_eq!(vs[4], 869780.7);
+        assert_eq!(vs[5], 2483145.0);
         Ok(())
     }
 }
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu
index 71fa8f977..7f2581c12 100644
--- a/candle-kernels/src/quantized.cu
+++ b/candle-kernels/src/quantized.cu
@@ -71,8 +71,6 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
     return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
 }
 
-#define CUDA_USE_TENSOR_CORES
-
 #define WARP_SIZE 32
 
 #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
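
For reference, a minimal shape sketch of the convention this change enforces in `mul_mat_via_q8_1`: `x` is the quantized matrix of shape `(x_rows, x_cols)`, `y` is the unquantized matrix of shape `(y_rows, y_cols)`, the shared inner dimension is `k = x_cols = y_rows`, and the output holds `x_rows * y_cols` values. The helper below is hypothetical (`mul_mat_dims` is not part of candle) and only checks dimensions; it is a sketch of the bookkeeping, not the kernel itself.

```rust
/// Hypothetical helper (not part of candle): validates the dimension
/// convention used by `mul_mat_via_q8_1` after this fix and returns the
/// output shape.
fn mul_mat_dims(
    x_rows: usize,
    x_cols: usize,
    y_rows: usize,
    y_cols: usize,
) -> Result<(usize, usize), String> {
    // The inner dimension must match: k = x_cols = y_rows.
    if x_cols != y_rows {
        return Err(format!(
            "unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}"
        ));
    }
    // dst has x_rows * y_cols elements, matching the new
    // `dev.alloc::<f32>(x_rows * y_cols)` and the launch grid that tiles
    // x_rows by y_cols.
    Ok((x_rows, y_cols))
}

fn main() {
    // Illustrative sizes; mirrors the fixed argument order at the
    // `QCudaStorage` call site: x = (n, k), y = (k, m).
    let (n, k, m) = (4, 256, 4);
    assert_eq!(mul_mat_dims(n, k, k, m), Ok((n, m)));
}
```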