diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index c310d7668..f7f5b68ac 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -152,9 +152,9 @@ impl QMetalStorage {
         // We always use a single batch dimension and stack all the tensors in the batch on the
         // second dimension as the implementation in candle-metal-kernels doesn't handle batch
         // properly.
-        let (b, m) = match dst_shape.len() {
-            3 => (1, dst_shape[0] * dst_shape[1]),
-            2 => (1, dst_shape[0]),
+        let m = match dst_shape.len() {
+            3 => dst_shape[0] * dst_shape[1],
+            2 => dst_shape[0],
             n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
         };
         let last_k = dst_shape.pop().unwrap();
@@ -166,18 +166,23 @@ impl QMetalStorage {
         let device = storage.device().clone();
         let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
         let command_buffer = device.command_buffer()?;
-        candle_metal_kernels::call_quantized_matmul_t(
-            device.device(),
-            &command_buffer,
-            device.kernels(),
-            self.dtype.into(),
-            (b, m, n, k),
-            storage.buffer(),
-            layout.start_offset() * storage.dtype().size_in_bytes(),
-            &self.buffer,
-            &dst,
-        )
-        .map_err(MetalError::from)?;
+        // In some cases it would be better to use the mm variant, though it has its drawbacks
+        // around memory alignment.
+        for batch_id in 0..m {
+            candle_metal_kernels::call_quantized_matmul_mv_t(
+                device.device(),
+                &command_buffer,
+                device.kernels(),
+                self.dtype.into(),
+                (1, 1, n, k),
+                storage.buffer(),
+                (layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
+                &self.buffer,
+                batch_id * n * DType::F32.size_in_bytes(),
+                &dst,
+            )
+            .map_err(MetalError::from)?;
+        }
         let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
         Ok((dst_storage, dst_shape))
     }
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 78108127d..e05797a23 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1699,7 +1699,7 @@ pub enum GgmlDType {
 }
 
 #[allow(clippy::too_many_arguments)]
-pub fn call_quantized_matmul_t(
+pub fn call_quantized_matmul_mv_t(
     device: &Device,
     command_buffer: &CommandBufferRef,
     kernels: &Kernels,
@@ -1708,7 +1708,8 @@ pub fn call_quantized_matmul_t(
     lhs: &Buffer,
     lhs_offset: usize,
     rhs: &Buffer,
-    output: &Buffer,
+    dst_offset: usize,
+    dst: &Buffer,
 ) -> Result<(), MetalKernelError> {
     // Everything is in reverse
     let ne00 = k as i64;
@@ -1748,8 +1749,9 @@
         }
         GgmlDType::Q2K => {
             // Fixing a bug in Metal for GGML
-            let nth0 = 4;
-            let nth1 = 8;
+            // https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576
+            let nth0 = 2;
+            let nth1 = 32;
             let align = 4;
             (nth0, nth1, align)
         }
@@ -1821,7 +1823,7 @@
         (
             rhs,
             (lhs, lhs_offset),
-            output,
+            (dst, dst_offset),
             ne00,
             ne01,
             ne02,
@@ -1840,10 +1842,9 @@
             r3
         )
     );
-    encoder.set_threadgroup_memory_length(0, 8192);
     encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
     encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(dst, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
    encoder.end_encoding();
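
Note on the dispatch above (not part of the patch): `call_quantized_matmul_mv_t` runs a matrix-vector kernel, so the caller loops over the `m` output rows and issues one matvec per row with shape `(1, 1, n, k)`. Each dispatch therefore needs two byte offsets: the lhs row starts `batch_id * k` elements past `layout.start_offset()`, and the dst row starts `batch_id * n` f32 elements into the shared output buffer. Below is a minimal sketch of that offset arithmetic; `mv_offsets` is a hypothetical helper and the shapes are illustrative, neither appears in the patch.

    // Byte offsets for the batch_id-th matvec dispatch: the lhs advances by k
    // elements of its own dtype per row, the dst by n f32 elements per row,
    // mirroring the `(layout.start_offset() + batch_id * k)` and
    // `batch_id * n * DType::F32.size_in_bytes()` expressions in the patch.
    fn mv_offsets(batch_id: usize, start: usize, k: usize, n: usize, lhs_elem_size: usize) -> (usize, usize) {
        let lhs_offset = (start + batch_id * k) * lhs_elem_size;
        let dst_offset = batch_id * n * std::mem::size_of::<f32>(); // dst is always f32
        (lhs_offset, dst_offset)
    }

    fn main() {
        // With 4-byte f32 inputs, k = 4096, n = 11008 and no start offset,
        // row 3 reads lhs bytes from 3 * 4096 * 4 onward and writes dst
        // bytes from 3 * 11008 * 4 onward.
        let (lhs, dst) = mv_offsets(3, 0, 4096, 11008, 4);
        assert_eq!(lhs, 3 * 4096 * 4);
        assert_eq!(dst, 3 * 11008 * 4);
    }

This is also why the first hunk drops the `b` binding: with one dispatch per row, the batch and row dimensions collapse to `(1, 1)` in the shape tuple and only `m` (the loop bound) is needed.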