Skip to content

Commit

Permalink
Handle multiple dimensions in metal QMM + two fixes. (#2097)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Apr 20, 2024
1 parent 9215e9c commit dd78422
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 22 deletions.
35 changes: 20 additions & 15 deletions candle-core/src/quantized/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ impl QMetalStorage {
// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
let (b, m) = match dst_shape.len() {
3 => (1, dst_shape[0] * dst_shape[1]),
2 => (1, dst_shape[0]),
let m = match dst_shape.len() {
3 => dst_shape[0] * dst_shape[1],
2 => dst_shape[0],
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
Expand All @@ -166,18 +166,23 @@ impl QMetalStorage {
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&self.buffer,
&dst,
)
.map_err(MetalError::from)?;
// In some cases it would be better to use the mm variant, though it has its drawbacks
// around memory alignemnt.
for batch_id in 0..m {
candle_metal_kernels::call_quantized_matmul_mv_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(1, 1, n, k),
storage.buffer(),
(layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
&self.buffer,
batch_id * n * DType::F32.size_in_bytes(),
&dst,
)
.map_err(MetalError::from)?;
}
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
Ok((dst_storage, dst_shape))
}
Expand Down
15 changes: 8 additions & 7 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1699,7 +1699,7 @@ pub enum GgmlDType {
}

#[allow(clippy::too_many_arguments)]
pub fn call_quantized_matmul_t(
pub fn call_quantized_matmul_mv_t(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
Expand All @@ -1708,7 +1708,8 @@ pub fn call_quantized_matmul_t(
lhs: &Buffer,
lhs_offset: usize,
rhs: &Buffer,
output: &Buffer,
dst_offset: usize,
dst: &Buffer,
) -> Result<(), MetalKernelError> {
// Everything is in reverse
let ne00 = k as i64;
Expand Down Expand Up @@ -1748,8 +1749,9 @@ pub fn call_quantized_matmul_t(
}
GgmlDType::Q2K => {
// Fixing a bug in Metal for GGML
let nth0 = 4;
let nth1 = 8;
// https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576
let nth0 = 2;
let nth1 = 32;
let align = 4;
(nth0, nth1, align)
}
Expand Down Expand Up @@ -1821,7 +1823,7 @@ pub fn call_quantized_matmul_t(
(
rhs,
(lhs, lhs_offset),
output,
(dst, dst_offset),
ne00,
ne01,
ne02,
Expand All @@ -1840,10 +1842,9 @@ pub fn call_quantized_matmul_t(
r3
)
);
encoder.set_threadgroup_memory_length(0, 8192);
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.use_resource(dst, metal::MTLResourceUsage::Write);

encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.end_encoding();
Expand Down

0 comments on commit dd78422

Please sign in to comment.