1
1
use std:: ops:: Range ;
2
2
3
- use alloc:: vec:: Vec ;
4
3
use burn_tensor:: {
5
4
ops:: { FloatTensor , IntTensor , QTensorOps , QuantizedTensor } ,
6
5
quantization:: {
7
- QTensorPrimitive , QuantizationParametersPrimitive , QuantizationScheme ,
8
- QuantizationStrategy , QuantizationType ,
6
+ QTensorPrimitive , QuantizationParametersPrimitive , QuantizationScheme , QuantizationType ,
9
7
} ,
10
- DType , Device , ElementConversion , Shape , TensorData ,
8
+ DType , Device , Shape , TensorData ,
11
9
} ;
12
10
13
11
use crate :: {
@@ -17,28 +15,14 @@ use crate::{
17
15
} ;
18
16
use cubecl:: CubeElement ;
19
17
20
- fn pack_i8s_to_u32s ( data : & TensorData ) -> Vec < u32 > {
21
- // Shift and combine groups of four 8-bit values into a u32.
22
- // Same as doing this:
23
- // let result = (a_u8 & 0xFF) << 24 | (b_u8 & 0xFF) << 16 | (c_u8 & 0xFF) << 8 | (d_u8 & 0xFF);
24
- data. as_bytes ( )
25
- . chunks ( 4 )
26
- . map ( |x| {
27
- x. iter ( ) . enumerate ( ) . fold ( 0u32 , |acc, ( i, x) | {
28
- acc | ( * x as i8 as u32 & 0xFF ) << ( ( 3 - i) * 8 )
29
- } )
30
- } )
31
- . collect ( )
32
- }
33
-
34
18
/// Create a quantized tensor with packed values (u32).
35
19
fn packed_tensor < R : JitRuntime , S : Into < Shape > > (
36
- data : Vec < u32 > ,
20
+ data : & [ u8 ] ,
37
21
shape : S ,
38
22
device : & R :: Device ,
39
23
) -> JitTensor < R , u32 > {
40
24
let client = R :: client ( device) ;
41
- let buffer = client. create ( u32 :: as_bytes ( & data) ) ;
25
+ let buffer = client. create ( data) ;
42
26
43
27
JitTensor :: new_contiguous ( client, device. clone ( ) , shape. into ( ) , buffer)
44
28
}
@@ -51,27 +35,21 @@ where
51
35
{
52
36
fn q_from_data ( data : TensorData , device : & Device < Self > ) -> QuantizedTensor < Self > {
53
37
match data. dtype {
54
- DType :: QFloat ( strategy) => match strategy {
55
- QuantizationStrategy :: PerTensorAffineInt8 ( q) => {
38
+ DType :: QFloat ( scheme) => match scheme {
39
+ QuantizationScheme :: PerTensorAffine ( QuantizationType :: QInt8 )
40
+ | QuantizationScheme :: PerTensorSymmetric ( QuantizationType :: QInt8 ) => {
56
41
// Convert quantized values to packed u32s
42
+ let qparams = data. get_q_params ( ) . unwrap ( ) ;
57
43
QJitTensor {
58
- qtensor : packed_tensor ( pack_i8s_to_u32s ( & data) , data. shape , device) ,
59
- scheme : strategy . scheme ( ) ,
44
+ qtensor : packed_tensor ( data. values_as_bytes ( ) , data. shape . clone ( ) , device) ,
45
+ scheme,
60
46
qparams : JitQuantizationParameters :: new (
61
- q . scale . elem ( ) ,
62
- Some ( q . offset . elem ( ) ) ,
47
+ qparams . scale ,
48
+ qparams . offset ,
63
49
device,
64
50
) ,
65
51
}
66
52
}
67
- QuantizationStrategy :: PerTensorSymmetricInt8 ( q) => {
68
- // Convert quantized values to packed u32s
69
- QJitTensor {
70
- qtensor : packed_tensor ( pack_i8s_to_u32s ( & data) , data. shape , device) ,
71
- scheme : strategy. scheme ( ) ,
72
- qparams : JitQuantizationParameters :: new ( q. scale . elem ( ) , None , device) ,
73
- }
74
- }
75
53
} ,
76
54
_ => panic ! (
77
55
"Invalid dtype (expected DType::QFloat, got {:?})" ,
@@ -119,35 +97,18 @@ where
119
97
120
98
async fn q_into_data ( tensor : QuantizedTensor < Self > ) -> TensorData {
121
99
let strategy = tensor. strategy ( ) ;
122
- let numel = tensor. qtensor . shape . num_elements ( ) ;
123
100
let qtensor = kernel:: into_contiguous ( tensor. qtensor ) ;
124
101
125
102
let bytes = qtensor. client . read_async ( qtensor. handle . binding ( ) ) . await ;
126
103
127
- // Convert packed bytes to quantized dtype (TensorData can be used with other backends,
128
- // which don't have the prior knowledge of this packed representation)
104
+ // TensorData keeps quantized values packed into 32-bit unsigned integers so we can
105
+ // keep the current representation, just cast the bytes as u32.
129
106
match & tensor. scheme {
130
107
QuantizationScheme :: PerTensorAffine ( dtype)
131
108
| QuantizationScheme :: PerTensorSymmetric ( dtype) => match dtype {
132
- QuantizationType :: QInt8 => TensorData :: quantized (
133
- u32:: from_bytes ( & bytes)
134
- . iter ( )
135
- . enumerate ( )
136
- . flat_map ( |( i, packed) | {
137
- // A single u32 could contain less than four 8-bit values...
138
- let n = core:: cmp:: min ( 4 , numel - i * 4 ) ;
139
- // Extract each 8-bit segment from u32 and cast back to i8
140
- // Same as doing this (when 4 values are fully packed):
141
- // let a = ((packed >> 24) & 0xFF) as i8;
142
- // let b = ((packed >> 16) & 0xFF) as i8;
143
- // let c = ((packed >> 8) & 0xFF) as i8;
144
- // let d = (packed & 0xFF) as i8;
145
- ( 0 ..n) . map ( move |i| ( packed >> ( ( 3 - i) * 8 ) & 0xFF ) as i8 )
146
- } )
147
- . collect ( ) ,
148
- qtensor. shape ,
149
- strategy,
150
- ) ,
109
+ QuantizationType :: QInt8 => {
110
+ TensorData :: quantized ( u32:: from_bytes ( & bytes) . to_vec ( ) , qtensor. shape , strategy)
111
+ }
151
112
} ,
152
113
}
153
114
}
0 commit comments