diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 8e4884b28..6318d673c 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -449,6 +449,23 @@ impl QCudaStorage { Ok(()) } + pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { + // Run the quantization on cpu. + let src_len = src.as_slice::()?.len(); + let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float(src.as_slice::()?)?; + } else { + unreachable!() + } + + let data = qcpu_storage.data()?; + let data = self.device.htod_sync_copy(data.as_ref()).w()?; + self.data = data; + Ok(()) + } + pub fn storage_size_in_bytes(&self) -> usize { self.data.len() } diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs index ca7b81208..69daad3cc 100644 --- a/candle-core/src/quantized/dummy_cuda.rs +++ b/candle-core/src/quantized/dummy_cuda.rs @@ -32,6 +32,10 @@ impl QCudaStorage { Err(Error::NotCompiledWithCudaSupport) } + pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn storage_size_in_bytes(&self) -> usize { 0 } diff --git a/candle-core/src/quantized/dummy_metal.rs b/candle-core/src/quantized/dummy_metal.rs index 520d0ed49..fc51214c1 100644 --- a/candle-core/src/quantized/dummy_metal.rs +++ b/candle-core/src/quantized/dummy_metal.rs @@ -28,6 +28,10 @@ impl QMetalStorage { Err(Error::NotCompiledWithMetalSupport) } + pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn storage_size_in_bytes(&self) -> usize { 0 } diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index f7f5b68ac..5f61749b3 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -126,6 +126,23 @@ impl QMetalStorage { Ok(()) } + pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { + // Quantization only happens on CPU for now. + let elem_count = src.as_slice::()?.len(); + let src = crate::Storage::Cpu(src); + let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float(src.as_slice::()?)?; + } else { + unreachable!() + } + + let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?; + self.buffer = buffer; + Ok(()) + } + pub fn storage_size_in_bytes(&self) -> usize { self.buffer.length() as usize } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index d852d5041..ff00f3638 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -101,7 +101,19 @@ impl QStorage { } (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?, - _ => crate::bail!("Invalid dequantize storage locations do not match"), + _ => crate::bail!("Invalid quantize storage locations do not match"), + } + Ok(()) + } + + fn quantize_onto(&mut self, src: &Storage) -> Result<()> { + match (self, src) { + (QStorage::Cpu(storage), Storage::Cpu(src)) => { + storage.from_float(src.as_slice::()?)?; + } + (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?, + (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?, + _ => crate::bail!("Invalid quantize source storage locations: not on cpu"), } Ok(()) } @@ -341,6 +353,34 @@ impl QTensor { }) } + /// Quantize `src` (currently on the CPU) to a QTensor on `dev` + pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result { + if !src.device().is_cpu() { + crate::bail!( + "`quantize_onto` expects a `src` to be on the cpu, got {:?}.", + src.device() + ) + } + let shape = src.shape(); + let block_size = dtype.block_size(); + check_shape(shape, block_size)?; + let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; + let elem_count = shape.elem_count(); + if elem_count % block_size != 0 { + crate::bail!( + "tensor size ({shape:?}) is not divisible by block size {}", + block_size + ) + } + // storage is on the `dev`, src is on `cpu` + let mut storage = dev.qzeros(elem_count, dtype)?; + storage.quantize_onto(&src.storage())?; + Ok(Self { + storage, + shape: shape.clone(), + }) + } + pub fn dtype(&self) -> GgmlDType { self.storage.dtype() }