Skip to content

Commit 34b303e

Browse files
authored
Refactor quantization tensor data representation (#2479)
* Remove quantization strategy from QFloat to use scheme instead (qparams unknown at compile-time). Instead, the qparams are stored in the TensorData bytes so we can pack/unpack them based on the scheme.
* Change quantization tensor data representation to pack quantized data type into u32.
* Fix clippy.
* Remove comment.
* Add alloc vec import.
* Remove print.
* Rename into_bytes.
1 parent 94db460 commit 34b303e

File tree

19 files changed

+485
-260
lines changed

19 files changed

+485
-260
lines changed

crates/burn-fusion/src/ops/qtensor.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{marker::PhantomData, ops::Range};
22

33
use burn_tensor::{
44
ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
5-
quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy},
5+
quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType},
66
repr::{
77
DequantizeOperationDescription, FloatOperationDescription, HandleContainer,
88
OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription,
@@ -21,14 +21,14 @@ use crate::{
2121
impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
2222
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
2323
match data.dtype {
24-
DType::QFloat(strategy) => {
24+
DType::QFloat(scheme) => {
2525
let client = get_client::<B>(device);
2626
let tensor = B::q_from_data(data, device);
2727
let shape = B::q_shape(&tensor);
2828

2929
let handles = B::quantized_tensor_handle(tensor);
30-
let qparams = match strategy {
31-
QuantizationStrategy::PerTensorAffineInt8(_) => {
30+
let qparams = match scheme {
31+
QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => {
3232
let offset = if let Some(offset) = handles.offset {
3333
offset
3434
} else {
@@ -49,7 +49,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
4949
)),
5050
}
5151
}
52-
QuantizationStrategy::PerTensorSymmetricInt8(_) => {
52+
QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
5353
assert!(
5454
handles.offset.is_none(),
5555
"Offset should not be provided for symmetric quantization."
@@ -74,7 +74,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
7474
QFusionTensor {
7575
qtensor,
7676
qparams,
77-
scheme: strategy.scheme(),
77+
scheme,
7878
}
7979
}
8080
_ => panic!(
@@ -142,7 +142,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
142142
scale: qparams.scale.clone().into_description(),
143143
offset: qparams.offset.clone().map(|x| x.into_description()),
144144
},
145-
scheme: scheme.clone(),
145+
scheme: *scheme,
146146
out: out.to_description_out(),
147147
};
148148

@@ -157,7 +157,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
157157

158158
QFusionTensor {
159159
qtensor: out,
160-
scheme: scheme.clone(),
160+
scheme: *scheme,
161161
qparams: qparams.into(),
162162
}
163163
}

crates/burn-fusion/src/stream/context.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ impl RelativeOpsScalar<f32> for FloatOperationDescription {
544544
.as_ref()
545545
.map(|x| x.to_relative(converter)),
546546
},
547-
scheme: desc.scheme.clone(),
547+
scheme: desc.scheme,
548548
out: desc.out.to_relative(converter),
549549
})
550550
}
@@ -561,7 +561,7 @@ impl RelativeOpsScalar<f32> for FloatOperationDescription {
561561
.as_ref()
562562
.map(|x| x.to_relative(converter)),
563563
},
564-
scheme: desc.qtensor.scheme.clone(),
564+
scheme: desc.qtensor.scheme,
565565
},
566566
out: desc.out.to_relative(converter),
567567
})

crates/burn-fusion/src/tensor.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ impl<R: FusionRuntime> Clone for QFusionTensor<R> {
190190
fn clone(&self) -> Self {
191191
Self {
192192
qtensor: self.qtensor.clone(),
193-
scheme: self.scheme.clone(),
193+
scheme: self.scheme,
194194
qparams: self.qparams.clone(),
195195
}
196196
}

crates/burn-jit/src/kernel/quantization/quantize.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ where
214214

215215
QJitTensor {
216216
qtensor,
217-
scheme: scheme.clone(),
217+
scheme: *scheme,
218218
qparams,
219219
}
220220
}

crates/burn-jit/src/ops/qtensor.rs

+17-56
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
use std::ops::Range;
22

3-
use alloc::vec::Vec;
43
use burn_tensor::{
54
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
65
quantization::{
7-
QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme,
8-
QuantizationStrategy, QuantizationType,
6+
QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType,
97
},
10-
DType, Device, ElementConversion, Shape, TensorData,
8+
DType, Device, Shape, TensorData,
119
};
1210

1311
use crate::{
@@ -17,28 +15,14 @@ use crate::{
1715
};
1816
use cubecl::CubeElement;
1917

20-
fn pack_i8s_to_u32s(data: &TensorData) -> Vec<u32> {
21-
// Shift and combine groups of four 8-bit values into a u32.
22-
// Same as doing this:
23-
// let result = (a_u8 & 0xFF) << 24 | (b_u8 & 0xFF) << 16 | (c_u8 & 0xFF) << 8 | (d_u8 & 0xFF);
24-
data.as_bytes()
25-
.chunks(4)
26-
.map(|x| {
27-
x.iter().enumerate().fold(0u32, |acc, (i, x)| {
28-
acc | (*x as i8 as u32 & 0xFF) << ((3 - i) * 8)
29-
})
30-
})
31-
.collect()
32-
}
33-
3418
/// Create a quantized tensor with packed values (u32).
3519
fn packed_tensor<R: JitRuntime, S: Into<Shape>>(
36-
data: Vec<u32>,
20+
data: &[u8],
3721
shape: S,
3822
device: &R::Device,
3923
) -> JitTensor<R, u32> {
4024
let client = R::client(device);
41-
let buffer = client.create(u32::as_bytes(&data));
25+
let buffer = client.create(data);
4226

4327
JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer)
4428
}
@@ -51,27 +35,21 @@ where
5135
{
5236
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
5337
match data.dtype {
54-
DType::QFloat(strategy) => match strategy {
55-
QuantizationStrategy::PerTensorAffineInt8(q) => {
38+
DType::QFloat(scheme) => match scheme {
39+
QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
40+
| QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
5641
// Convert quantized values to packed u32s
42+
let qparams = data.get_q_params().unwrap();
5743
QJitTensor {
58-
qtensor: packed_tensor(pack_i8s_to_u32s(&data), data.shape, device),
59-
scheme: strategy.scheme(),
44+
qtensor: packed_tensor(data.values_as_bytes(), data.shape.clone(), device),
45+
scheme,
6046
qparams: JitQuantizationParameters::new(
61-
q.scale.elem(),
62-
Some(q.offset.elem()),
47+
qparams.scale,
48+
qparams.offset,
6349
device,
6450
),
6551
}
6652
}
67-
QuantizationStrategy::PerTensorSymmetricInt8(q) => {
68-
// Convert quantized values to packed u32s
69-
QJitTensor {
70-
qtensor: packed_tensor(pack_i8s_to_u32s(&data), data.shape, device),
71-
scheme: strategy.scheme(),
72-
qparams: JitQuantizationParameters::new(q.scale.elem(), None, device),
73-
}
74-
}
7553
},
7654
_ => panic!(
7755
"Invalid dtype (expected DType::QFloat, got {:?})",
@@ -119,35 +97,18 @@ where
11997

12098
async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
12199
let strategy = tensor.strategy();
122-
let numel = tensor.qtensor.shape.num_elements();
123100
let qtensor = kernel::into_contiguous(tensor.qtensor);
124101

125102
let bytes = qtensor.client.read_async(qtensor.handle.binding()).await;
126103

127-
// Convert packed bytes to quantized dtype (TensorData can be used with other backends,
128-
// which don't have the prior knowledge of this packed representation)
104+
// TensorData keeps quantized values packed into 32-bit unsigned integers so we can
105+
// keep the current representation, just cast the bytes as u32.
129106
match &tensor.scheme {
130107
QuantizationScheme::PerTensorAffine(dtype)
131108
| QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
132-
QuantizationType::QInt8 => TensorData::quantized(
133-
u32::from_bytes(&bytes)
134-
.iter()
135-
.enumerate()
136-
.flat_map(|(i, packed)| {
137-
// A single u32 could contain less than four 8-bit values...
138-
let n = core::cmp::min(4, numel - i * 4);
139-
// Extract each 8-bit segment from u32 and cast back to i8
140-
// Same as doing this (when 4 values are fully packed):
141-
// let a = ((packed >> 24) & 0xFF) as i8;
142-
// let b = ((packed >> 16) & 0xFF) as i8;
143-
// let c = ((packed >> 8) & 0xFF) as i8;
144-
// let d = (packed & 0xFF) as i8;
145-
(0..n).map(move |i| (packed >> ((3 - i) * 8) & 0xFF) as i8)
146-
})
147-
.collect(),
148-
qtensor.shape,
149-
strategy,
150-
),
109+
QuantizationType::QInt8 => {
110+
TensorData::quantized(u32::from_bytes(&bytes).to_vec(), qtensor.shape, strategy)
111+
}
151112
},
152113
}
153114
}

crates/burn-jit/src/tensor/qtensor.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> Clone for QJitTensor<R, F, I
6161
fn clone(&self) -> Self {
6262
Self {
6363
qtensor: self.qtensor.clone(),
64-
scheme: self.scheme.clone(),
64+
scheme: self.scheme,
6565
qparams: self.qparams.clone(),
6666
}
6767
}

0 commit comments

Comments (0)