Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize quantization process with QTensor::quantize_onto #2408

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions candle-core/src/quantized/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,23 @@ impl QCudaStorage {
Ok(())
}

/// Quantizes CPU-resident `f32` values into this storage's dtype and replaces
/// the device buffer in `self.data` with the quantized bytes.
///
/// The quantization itself runs on the CPU; only the resulting bytes are
/// copied to the device.
///
/// # Errors
/// Propagates failures from reading `src` as `f32`, allocating the temporary
/// CPU quantized storage, quantizing, fetching the quantized bytes, and the
/// host-to-device copy.
pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> {
    // Read the source as f32 once; it provides both the length and the values.
    let values = src.as_slice::<f32>()?;
    let mut host_quantized = crate::Device::Cpu.qzeros(values.len(), self.dtype)?;

    match &mut host_quantized {
        QStorage::Cpu(cpu) => cpu.from_float(values)?,
        // The storage was just requested from `Device::Cpu`, so any other
        // variant is treated as impossible.
        _ => unreachable!(),
    }

    // Upload the quantized bytes and swap them in as the new device buffer.
    let host_bytes = host_quantized.data()?;
    self.data = self.device.htod_sync_copy(host_bytes.as_ref()).w()?;
    Ok(())
}

/// Returns the length reported by the device buffer held in `self.data`.
// NOTE(review): presumably this length is a byte count, as the method name
// suggests — confirm the buffer's element type.
pub fn storage_size_in_bytes(&self) -> usize {
self.data.len()
}
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/quantized/dummy_cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ impl QCudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}

/// Placeholder for the CUDA `quantize_onto`: ignores `_src` and always
/// returns `Error::NotCompiledWithCudaSupport`.
pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}

/// Always reports a size of `0` for this placeholder CUDA storage.
pub fn storage_size_in_bytes(&self) -> usize {
0
}
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/quantized/dummy_metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ impl QMetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}

pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}

/// Always reports a size of `0` for this placeholder Metal storage.
pub fn storage_size_in_bytes(&self) -> usize {
0
}
Expand Down
17 changes: 17 additions & 0 deletions candle-core/src/quantized/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,23 @@ impl QMetalStorage {
Ok(())
}

/// Quantizes CPU-resident `f32` values into this storage's dtype and replaces
/// `self.buffer` with a new Metal buffer holding the quantized bytes.
///
/// # Errors
/// Propagates failures from reading `src` as `f32`, allocating the temporary
/// CPU quantized storage, quantizing, fetching the quantized bytes, and
/// creating the Metal buffer.
pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> {
    // Quantization only happens on CPU for now.
    // Read the slice straight from the `CpuStorage`; wrapping the reference in
    // `Storage::Cpu` is neither needed nor type-correct (the variant owns its
    // `CpuStorage`).
    let values = src.as_slice::<f32>()?;
    let mut qcpu_storage = crate::Device::Cpu.qzeros(values.len(), self.dtype)?;

    if let QStorage::Cpu(storage) = &mut qcpu_storage {
        storage.from_float(values)?;
    } else {
        // The storage was just requested from `Device::Cpu`, so any other
        // variant is treated as impossible.
        unreachable!()
    }

    // Copy the quantized bytes into a fresh Metal buffer and swap it in.
    let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
    self.buffer = buffer;
    Ok(())
}

/// Returns the value of `self.buffer.length()` converted to `usize`.
// NOTE(review): presumably `length()` is the Metal buffer's size in bytes, as
// the method name suggests — confirm against the buffer API.
pub fn storage_size_in_bytes(&self) -> usize {
self.buffer.length() as usize
}
Expand Down
42 changes: 41 additions & 1 deletion candle-core/src/quantized/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,19 @@ impl QStorage {
}
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"),
_ => crate::bail!("Invalid quantize storage locations do not match"),
}
Ok(())
}

/// Quantizes `src`, which must be CPU storage, into `self`, whichever device
/// `self` lives on.
///
/// CPU destinations quantize in place from the `f32` slice; Metal and CUDA
/// destinations delegate to their own `quantize_onto`.
///
/// # Errors
/// Fails when `src` is not CPU storage, or when the underlying quantization
/// step fails.
fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
    match (self, src) {
        (QStorage::Cpu(dst), Storage::Cpu(host)) => {
            let values = host.as_slice::<f32>()?;
            dst.from_float(values)?;
            Ok(())
        }
        (QStorage::Metal(dst), Storage::Cpu(host)) => dst.quantize_onto(host),
        (QStorage::Cuda(dst), Storage::Cpu(host)) => dst.quantize_onto(host),
        _ => crate::bail!("Invalid quantize source storage locations: not on cpu"),
    }
}
Expand Down Expand Up @@ -341,6 +353,34 @@ impl QTensor {
})
}

/// Builds a quantized tensor on `dev` from a tensor that lives on the CPU.
///
/// `src` is converted to `f32` and flattened, a zeroed quantized storage of
/// type `dtype` is allocated on `dev`, and the flattened values are quantized
/// into it. The result keeps the original shape of `src`.
///
/// # Errors
/// Fails when `src` is not on the CPU, when `check_shape` rejects the shape
/// for the dtype's block size, when the element count is not a multiple of
/// the block size, or when conversion, allocation or quantization fails.
pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
    let src_device = src.device();
    if !src_device.is_cpu() {
        crate::bail!(
            "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
            src_device
        )
    }

    let shape = src.shape();
    let block_size = dtype.block_size();
    let elem_count = shape.elem_count();
    check_shape(shape, block_size)?;

    let flat_f32 = src.to_dtype(crate::DType::F32)?.flatten_all()?;
    if elem_count % block_size != 0 {
        crate::bail!(
            "tensor size ({shape:?}) is not divisible by block size {}",
            block_size
        )
    }

    // The destination storage lives on `dev`; the source values stay on the CPU.
    let mut storage = dev.qzeros(elem_count, dtype)?;
    storage.quantize_onto(&flat_f32.storage())?;

    let shape = shape.clone();
    Ok(Self { storage, shape })
}

pub fn dtype(&self) -> GgmlDType {
self.storage.dtype()
}
Expand Down
Loading