diff --git a/mistralrs-quant/src/utils/ops.rs b/mistralrs-quant/src/utils/ops.rs index 32ae422b8..fbd70852a 100644 --- a/mistralrs-quant/src/utils/ops.rs +++ b/mistralrs-quant/src/utils/ops.rs @@ -401,7 +401,40 @@ mod tests { #[cfg(not(feature = "cuda"))] #[test] - fn test_bitpack() { + fn test_bitpack_8bit() { + use crate::HqqBits; + use candle_core::{Device, Tensor}; + let bits = HqqBits::Eight; + let device = &Device::Cpu; + let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap(); + let c = bits.bitpack_type()(wq.clone()) + .unwrap() + .to_vec2::() + .unwrap(); + assert_eq!(c, [[1, 2], [3, 4], [255, 0]]); + } + + #[cfg(all(feature = "cuda"))] + #[test] + fn test_bitpack_8bit() { + use crate::HqqBits; + use candle_core::DType; + use candle_core::{Device, Tensor}; + let bits = HqqBits::Eight; + let device = Device::new_cuda(0).unwrap(); + let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap(); + let c = bits.bitpack_type()(wq.clone()) + .unwrap() + .to_dtype(DType::U8) + .unwrap() + .to_vec2::() + .unwrap(); + assert_eq!(c, [[1, 2], [3, 4], [255, 0]]); + } + + #[cfg(not(feature = "cuda"))] + #[test] + fn test_bitpack_4bit() { use crate::HqqBits; use candle_core::{Device, Tensor}; let bits = HqqBits::Four; @@ -416,7 +449,7 @@ mod tests { #[cfg(all(feature = "cuda"))] #[test] - fn test_bitpack() { + fn test_bitpack_4bit() { use crate::HqqBits; use candle_core::{Device, Tensor}; let bits = HqqBits::Four;