Skip to content

Commit

Permalink
Fix by transposing the rhs matrix.
Browse files (browse the repository at this point in the history)
  • Loading branch information
LaurentMazare committed Apr 14, 2024
1 parent b23eed8 commit c81ad77
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
26 changes: 13 additions & 13 deletions candle-core/src/quantized/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ fn mul_mat_via_q8_1(
if y.len() != y_rows * y_cols {
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
}
if x_cols != y_cols {
if x_cols != y_rows {
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
}
let k = x_cols;
Expand All @@ -245,7 +245,7 @@ fn mul_mat_via_q8_1(
let y_size_in_bytes =
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_rows, dev)?;
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;

let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_q4_0",
Expand All @@ -261,9 +261,9 @@ fn mul_mat_via_q8_1(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_rows).w()? };
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: ((x_rows as u32 + 127) / 128, (y_rows as u32 + 63) / 64, 1),
grid_dim: ((x_rows as u32 + 127) / 128, (y_cols as u32 + 63) / 64, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
Expand All @@ -274,9 +274,9 @@ fn mul_mat_via_q8_1(
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ k_padded as i32,
/* nrows_y */ y_rows as i32,
/* nrows_dst */ y_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
Expand Down Expand Up @@ -451,8 +451,8 @@ impl QCudaStorage {
self.dtype,
/* x_rows */ n,
/* x_cols */ k,
/* y_rows */ m,
/* y_cols */ k,
/* y_rows */ k,
/* y_cols */ m,
self.device(),
)?
};
Expand Down Expand Up @@ -548,8 +548,8 @@ mod test {
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ 4,
/* x_cols */ ncols,
/* y_rows */ 4,
/* y_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ 4,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
Expand All @@ -558,8 +558,8 @@ mod test {
assert_eq!(vs[0], 347604.0);
assert_eq!(vs[1], 888153.06);
// TODO: This is wrong.
assert_eq!(vs[4], 347604.0);
assert_eq!(vs[5], 888153.06);
assert_eq!(vs[4], 869780.7);
assert_eq!(vs[5], 2483145.0);
Ok(())
}
}
2 changes: 0 additions & 2 deletions candle-kernels/src/quantized.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
}


#define CUDA_USE_TENSOR_CORES

#define WARP_SIZE 32
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)

Expand Down

0 comments on commit c81ad77

Please sign in to comment.