@@ -4,7 +4,7 @@ use core::{
4
4
} ;
5
5
6
6
use alloc:: vec:: Vec ;
7
- use burn_common:: { iter_par , run_par} ;
7
+ use burn_common:: { iter_slice_par , run_par} ;
8
8
use num_traits:: { Float , PrimInt } ;
9
9
use serde:: { Deserialize , Serialize } ;
10
10
@@ -35,7 +35,7 @@ impl QuantizationStrategy {
35
35
36
36
/// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision
37
37
/// data type `Q` and vice-versa.
38
- pub trait Quantization < E : Float , Q : PrimInt > {
38
+ pub trait Quantization < E : Float + Send + Sync , Q : PrimInt + Send + Sync > {
39
39
/// Create a new quantization scheme for an input range `[alpha, beta]`.
40
40
fn new ( alpha : E , beta : E ) -> Self ;
41
41
/// Convert the values to a lower precision data type.
@@ -48,7 +48,7 @@ pub trait Quantization<E: Float, Q: PrimInt> {
48
48
///
49
49
/// Note that the accumulation type `A` should have a bigger range than quantized type `Q`.
50
50
#[ derive( Debug , Clone , Copy , Serialize , Deserialize ) ]
51
- pub struct AffineQuantization < E : Float , Q : PrimInt , A : PrimInt > {
51
+ pub struct AffineQuantization < E : Float + Send + Sync , Q : PrimInt + Send + Sync , A : PrimInt > {
52
52
/// The scaling factor.
53
53
pub scale : E ,
54
54
/// The zero-point offset.
@@ -66,7 +66,7 @@ fn valid_scale<E: Float>(mut scale: E) -> E {
66
66
scale
67
67
}
68
68
69
- impl < E : Float , Q : PrimInt , A : PrimInt > AffineQuantization < E , Q , A > {
69
+ impl < E : Float + Send + Sync , Q : PrimInt + Send + Sync , A : PrimInt > AffineQuantization < E , Q , A > {
70
70
/// Initialize an affine quantization scheme with the given parameters.
71
71
pub fn init ( scale : E , offset : Q ) -> Self {
72
72
Self {
@@ -77,7 +77,9 @@ impl<E: Float, Q: PrimInt, A: PrimInt> AffineQuantization<E, Q, A> {
77
77
}
78
78
}
79
79
80
- impl < E : Float , Q : PrimInt , A : PrimInt > Quantization < E , Q > for AffineQuantization < E , Q , A > {
80
+ impl < E : Float + Send + Sync , Q : PrimInt + Send + Sync , A : PrimInt + Send + Sync > Quantization < E , Q >
81
+ for AffineQuantization < E , Q , A >
82
+ {
81
83
fn new ( alpha : E , beta : E ) -> Self {
82
84
// Q range `[a, b]`
83
85
let a = E :: from ( Q :: min_value ( ) ) . unwrap ( ) ;
@@ -107,7 +109,7 @@ impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization
107
109
// x_q = clamp(round(x / scale + offset), a, b)
108
110
let z = E :: from ( self . offset ) . unwrap ( ) ;
109
111
run_par ! ( || {
110
- iter_par !( values. iter ( ) )
112
+ iter_slice_par !( values)
111
113
. map( |x| Q :: from( x. div( self . scale) . add( z) . round( ) . clamp( a, b) ) . unwrap( ) )
112
114
. collect( )
113
115
} )
@@ -116,7 +118,7 @@ impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization
116
118
fn dequantize ( & self , values : & [ Q ] ) -> Vec < E > {
117
119
// x = scale * (x_q - offset)
118
120
run_par ! ( || {
119
- iter_par !( values. iter ( ) )
121
+ iter_slice_par !( values)
120
122
. map( |x_q| {
121
123
self . scale
122
124
* ( E :: from(
@@ -133,14 +135,14 @@ impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization
133
135
134
136
/// Symmetric quantization scheme.
135
137
#[ derive( Debug , Clone , Copy , Serialize , Deserialize ) ]
136
- pub struct SymmetricQuantization < E : Float , Q : PrimInt > {
138
+ pub struct SymmetricQuantization < E : Float + Send + Sync , Q : PrimInt + Send + Sync > {
137
139
/// The scaling factor.
138
140
pub scale : E ,
139
141
/// The quantized type.
140
142
_q : PhantomData < Q > ,
141
143
}
142
144
143
- impl < E : Float , Q : PrimInt > SymmetricQuantization < E , Q > {
145
+ impl < E : Float + Send + Sync , Q : PrimInt + Send + Sync > SymmetricQuantization < E , Q > {
144
146
/// Initialize a symmetric quantization scheme with the given parameters.
145
147
pub fn init ( scale : E ) -> Self {
146
148
Self {
@@ -150,7 +152,9 @@ impl<E: Float, Q: PrimInt> SymmetricQuantization<E, Q> {
150
152
}
151
153
}
152
154
153
- impl < E : Float , Q : PrimInt > Quantization < E , Q > for SymmetricQuantization < E , Q > {
155
+ impl < E : Float + Send + Sync , Q : PrimInt + Send + Sync > Quantization < E , Q >
156
+ for SymmetricQuantization < E , Q >
157
+ {
154
158
fn new ( alpha : E , beta : E ) -> Self {
155
159
assert ! (
156
160
!Q :: min_value( ) . is_zero( ) ,
@@ -214,7 +218,9 @@ fn canonicalize_signed_zero<T: Float>(x: T) -> T {
214
218
x + T :: zero ( )
215
219
}
216
220
217
- impl < E : Float , Q : PrimInt + Hash , A : PrimInt > Hash for AffineQuantization < E , Q , A > {
221
+ impl < E : Float + Send + Sync , Q : PrimInt + Hash + Send + Sync , A : PrimInt > Hash
222
+ for AffineQuantization < E , Q , A >
223
+ {
218
224
fn hash < H : Hasher > ( & self , state : & mut H ) {
219
225
// Hash raw bits.
220
226
let bits = raw_double_bits ( & canonicalize_signed_zero ( self . scale ) ) ;
@@ -223,29 +229,34 @@ impl<E: Float, Q: PrimInt + Hash, A: PrimInt> Hash for AffineQuantization<E, Q,
223
229
}
224
230
}
225
231
226
- impl < E : Float , Q : PrimInt , A : PrimInt > PartialEq for AffineQuantization < E , Q , A > {
232
+ impl < E : Float + Send + Sync , Q : PrimInt + Send + Sync , A : PrimInt > PartialEq
233
+ for AffineQuantization < E , Q , A >
234
+ {
227
235
fn eq ( & self , other : & Self ) -> bool {
228
236
self . scale == other. scale && self . offset == other. offset
229
237
}
230
238
}
231
239
232
- impl < E : Float , Q : PrimInt , A : PrimInt > Eq for AffineQuantization < E , Q , A > { }
240
+ impl < E : Float + Send + Sync , Q : PrimInt + Send + Sync , A : PrimInt > Eq
241
+ for AffineQuantization < E , Q , A >
242
+ {
243
+ }
233
244
234
- impl < E : Float , Q : PrimInt > Hash for SymmetricQuantization < E , Q > {
245
+ impl < E : Float + Send + Sync , Q : PrimInt + Send + Sync > Hash for SymmetricQuantization < E , Q > {
235
246
fn hash < H : Hasher > ( & self , state : & mut H ) {
236
247
// Hash raw bits.
237
248
let bits = raw_double_bits ( & canonicalize_signed_zero ( self . scale ) ) ;
238
249
bits. hash ( state) ;
239
250
}
240
251
}
241
252
242
- impl < E : Float , Q : PrimInt > PartialEq for SymmetricQuantization < E , Q > {
253
+ impl < E : Float + Send + Sync , Q : PrimInt + Send + Sync > PartialEq for SymmetricQuantization < E , Q > {
243
254
fn eq ( & self , other : & Self ) -> bool {
244
255
self . scale == other. scale
245
256
}
246
257
}
247
258
248
- impl < E : Float , Q : PrimInt > Eq for SymmetricQuantization < E , Q > { }
259
+ impl < E : Float + Send + Sync , Q : PrimInt + Send + Sync > Eq for SymmetricQuantization < E , Q > { }
249
260
250
261
#[ cfg( test) ]
251
262
mod tests {
0 commit comments