
Commit 95593fc

Fix parallel macro + CI (#2678)
* Fix rayon issues
* Fix typos
* Fix for_each no std
* Fix clippy

1 parent: da8de56

7 files changed: +119 −60 lines

backend-comparison/src/burnbenchapp/auth/base.rs (+1 −1)

```diff
@@ -133,7 +133,7 @@ fn verify_tokens(tokens: &Tokens) -> bool {
         )
         .header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
         .send();
-    response.map_or(false, |resp| resp.status().is_success())
+    response.is_ok_and(|resp| resp.status().is_success())
 }
 
 fn refresh_tokens(tokens: &Tokens) -> Option<Tokens> {
```
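A small clippy cleanup: `Result::is_ok_and` (stable since Rust 1.70) replaces the `map_or(false, ..)` pattern. A minimal standalone sketch of the equivalence (toy values, not burnbench code):

```rust
fn main() {
    let ok: Result<u32, ()> = Ok(4);
    let err: Result<u32, ()> = Err(());

    // `is_ok_and` returns false on `Err`, exactly like `map_or(false, ...)`,
    // but states the intent directly.
    assert_eq!(ok.map_or(false, |v| v > 2), ok.is_ok_and(|v| v > 2));
    assert!(!err.is_ok_and(|v| v > 2));
}
```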

crates/burn-common/src/lib.rs (+3)

```diff
@@ -11,6 +11,9 @@ pub mod id;
 
 pub use cubecl_common::*;
 
+#[cfg(feature = "rayon")]
+pub use rayon;
+
 extern crate alloc;
 
 /// Network utilities.
```
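The re-export is what lets the reworked macros below expand to `$crate::rayon` paths, so downstream crates never need a direct rayon dependency. A minimal sketch of what this enables (hypothetical caller crate, with burn-common's `rayon` feature enabled):

```rust
// Reach rayon through burn-common's re-export rather than a direct dependency.
use burn_common::rayon::prelude::*;

fn parallel_sum(values: &[u64]) -> u64 {
    // `par_iter` comes from rayon's prelude via the re-export.
    values.par_iter().sum()
}

fn main() {
    assert_eq!(parallel_sum(&[1, 2, 3, 4]), 10);
}
```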

crates/burn-common/src/parallel.rs (+57 −18)

```diff
@@ -1,51 +1,90 @@
 /// Macro for running a function in parallel.
+#[cfg(feature = "rayon")]
 #[macro_export(local_inner_macros)]
 macro_rules! run_par {
     (
         $func:expr
     ) => {{
-        #[cfg(feature = "rayon")]
-        use rayon::prelude::*;
+        use $crate::rayon::prelude::*;
 
-        #[cfg(feature = "rayon")]
         #[allow(clippy::redundant_closure_call)]
-        let output = rayon::scope(|_| $func());
+        $crate::rayon::scope(|_| $func())
+    }};
+}
 
-        #[cfg(not(feature = "rayon"))]
-        let output = $func();
+/// Macro for running a function in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! run_par {
+    (
+        $func:expr
+    ) => {{
+        $func()
+    }};
+}
 
-        output
+/// Macro for iterating in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_par {
+    (
+        $iter:expr
+    ) => {{
+        $iter
     }};
 }
 
 /// Macro for iterating in parallel.
+#[cfg(feature = "rayon")]
 #[macro_export(local_inner_macros)]
 macro_rules! iter_par {
     (
         $iter:expr
     ) => {{
-        #[cfg(feature = "rayon")]
-        let output = $iter.into_par_iter();
+        $iter.into_par_iter()
+    }};
+}
 
-        #[cfg(not(feature = "rayon"))]
-        let output = $iter;
+/// Macro for iterating in parallel.
+#[cfg(feature = "rayon")]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_slice_par {
+    (
+        $slice:expr
+    ) => {{
+        $slice.into_par_iter()
+    }};
+}
 
-        output
+/// Macro for iterating in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_slice_par {
+    (
+        $slice:expr
+    ) => {{
+        $slice.iter()
     }};
 }
 
 /// Macro for iterating over a range in parallel.
+#[cfg(feature = "rayon")]
 #[macro_export(local_inner_macros)]
 macro_rules! iter_range_par {
     (
         $start:expr, $end:expr
     ) => {{
-        #[cfg(feature = "rayon")]
-        let output = ($start..$end).into_par_iter();
-
-        #[cfg(not(feature = "rayon"))]
-        let output = ($start..$end);
+        ($start..$end).into_par_iter()
+    }};
+}
 
-        output
+/// Macro for iterating over a range in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_range_par {
+    (
+        $start:expr, $end:expr
+    ) => {{
+        ($start..$end)
+    }};
+}
```
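With the `cfg` split, each macro now has one definition per feature state and expands to a plain expression instead of cfg-gated `let output` statements. The same call site compiles either way. A minimal usage sketch (`sum_of_squares` is a hypothetical downstream function, not part of this commit):

```rust
use burn_common::{iter_range_par, run_par};

// Expands to a rayon-parallel sum when burn-common's `rayon` feature is on,
// and to an ordinary sequential iterator chain otherwise.
fn sum_of_squares(n: usize) -> usize {
    run_par!(|| iter_range_par!(0, n).map(|i| i * i).sum::<usize>())
}

fn main() {
    assert_eq!(sum_of_squares(4), 0 + 1 + 4 + 9);
}
```

Note the new `iter_slice_par!` macro: rayon's `into_par_iter()` on a `&[T]` and the serial `iter()` both yield `&T` items, so one macro covers slices without the caller writing `.iter()` in one mode only.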

crates/burn-jit/src/fusion/on_write/ir.rs (+1 −1)

```diff
@@ -154,7 +154,7 @@ impl<R: Runtime> GlobalArgsLaunch<'_, R> {
         }
     }
 
-    /// Resolve the [argument](Arg) to a [tensor arguemnt](TensorArg).
+    /// Resolve the [argument](Arg) to a [tensor argument](TensorArg).
     ///
     /// # Panics
     ///
```

crates/burn-ndarray/src/ops/deform_conv.rs (+29 −23)

```diff
@@ -6,7 +6,7 @@ use burn_tensor::{
 use core::ops::AddAssign;
 use ndarray::{
     s, Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim,
-    Ix4,
+    Ix4, Zip,
 };
 #[cfg(not(feature = "std"))]
 use num_traits::Float;
@@ -593,31 +593,37 @@ pub mod backward {
             AtomicF32::new(0.0)
         });
 
+        let compute_for_each = |(in_channel, kernel_y, kernel_x, batch, out_y, out_x), col: &F| {
+            let group = in_channel / channels_per_offset_group;
+            let offset = offset.slice(s![batch, .., out_y, out_x]);
+            let offset = offset
+                .to_shape((offs_groups, kernel_h, kernel_w, 2))
+                .unwrap();
+            let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
+            let offset = [offset[0], offset[1]];
+            let mask = mask
+                .as_ref()
+                .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32());
+            let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
+                - F::from_elem(args.padding[0])
+                + offset[0];
+            let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
+                - F::from_elem(args.padding[1])
+                + offset[1];
+            let grad_in = grad_in.slice(s![batch, in_channel, .., ..]);
+            deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in);
+        };
+
+        // `for_each` expects a 2-tuple argument with `.into_par_iter()`, but 2 separate arguments otherwise
+        #[cfg(feature = "std")]
         run_par!(|| {
-            iter_par!(columns.indexed_iter()).for_each(
-                |((in_channel, kernel_y, kernel_x, batch, out_y, out_x), col)| {
-                    let group = in_channel / channels_per_offset_group;
-                    let offset = offset.slice(s![batch, .., out_y, out_x]);
-                    let offset = offset
-                        .to_shape((offs_groups, kernel_h, kernel_w, 2))
-                        .unwrap();
-                    let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
-                    let offset = [offset[0], offset[1]];
-                    let mask = mask
-                        .as_ref()
-                        .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32());
-                    let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
-                        - F::from_elem(args.padding[0])
-                        + offset[0];
-                    let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
-                        - F::from_elem(args.padding[1])
-                        + offset[1];
-                    let grad_in = grad_in.slice(s![batch, in_channel, .., ..]);
-                    deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in);
-                },
-            )
+            iter_par!(Zip::indexed(columns))
+                .for_each(|(args0, args1)| compute_for_each(args0, args1))
         });
 
+        #[cfg(not(feature = "std"))]
+        run_par!(|| { iter_par!(Zip::indexed(columns).for_each(compute_for_each)) });
+
         let grad_in: Array1<F> = grad_in
             .into_iter()
             .map(|it| F::from_elem(it.into_inner()))
```
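The inline comment is the crux of the no-std fix: `ndarray::Zip`'s serial `for_each` passes the index and the element as two separate closure arguments, while the rayon parallel iterator over the same `Zip` yields a single `(index, element)` tuple, hence the `|(args0, args1)|` adapter in the `std` branch. A minimal sketch of the two calling conventions on a made-up 2-D array (not the 6-D `columns` used above):

```rust
use ndarray::{Array2, Zip};

fn main() {
    let a = Array2::<f32>::zeros((2, 3));

    // Serial ndarray: index and element arrive as two closure arguments.
    Zip::indexed(&a).for_each(|(row, col), value| {
        let _ = (row, col, value);
    });

    // With ndarray's `rayon` feature, the same Zip becomes a parallel
    // iterator whose items are `(index, element)` 2-tuples instead.
    #[cfg(feature = "rayon")]
    {
        use ndarray::parallel::prelude::*;
        Zip::indexed(&a)
            .into_par_iter()
            .for_each(|((row, col), value)| {
                let _ = (row, col, value);
            });
    }
}
```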

crates/burn-tensor/src/tensor/quantization/bytes.rs (+1 −1)

```diff
@@ -100,7 +100,7 @@ impl QuantizedBytes {
 
     /// Splits the quantized values of the tensor from the quantization parameters.
     ///
-    /// Returns the packed values and a newly allocated vector containining the quantization parameters.
+    /// Returns the packed values and a newly allocated vector containing the quantization parameters.
     fn split_values_off(self) -> (Vec<u32>, Vec<u32>) {
         // The bytes can be created either from packed u32 or existing bytes with the same representation.
         let mut values = match self.bytes.align() {
```

crates/burn-tensor/src/tensor/quantization/strategy.rs (+27 −16)

```diff
@@ -4,7 +4,7 @@ use core::{
 };
 
 use alloc::vec::Vec;
-use burn_common::{iter_par, run_par};
+use burn_common::{iter_slice_par, run_par};
 use num_traits::{Float, PrimInt};
 use serde::{Deserialize, Serialize};
 
@@ -35,7 +35,7 @@ impl QuantizationStrategy {
 
 /// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision
 /// data type `Q` and vice-versa.
-pub trait Quantization<E: Float, Q: PrimInt> {
+pub trait Quantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync> {
     /// Create a new quantization scheme for an input range `[alpha, beta]`.
     fn new(alpha: E, beta: E) -> Self;
     /// Convert the values to a lower precision data type.
@@ -48,7 +48,7 @@ pub trait Quantization<E: Float, Q: PrimInt> {
 ///
 /// Note that the accumulation type `A` should have a bigger range than quantized type `Q`.
 #[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-pub struct AffineQuantization<E: Float, Q: PrimInt, A: PrimInt> {
+pub struct AffineQuantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> {
     /// The scaling factor.
     pub scale: E,
     /// The zero-point offset.
@@ -66,7 +66,7 @@ fn valid_scale<E: Float>(mut scale: E) -> E {
     scale
 }
 
-impl<E: Float, Q: PrimInt, A: PrimInt> AffineQuantization<E, Q, A> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> AffineQuantization<E, Q, A> {
     /// Initialize an affine quantization scheme with the given parameters.
     pub fn init(scale: E, offset: Q) -> Self {
         Self {
@@ -77,7 +77,9 @@ impl<E: Float, Q: PrimInt, A: PrimInt> AffineQuantization<E, Q, A> {
     }
 }
 
-impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization<E, Q, A> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt + Send + Sync> Quantization<E, Q>
+    for AffineQuantization<E, Q, A>
+{
     fn new(alpha: E, beta: E) -> Self {
         // Q range `[a, b]`
         let a = E::from(Q::min_value()).unwrap();
@@ -107,7 +109,7 @@
         // x_q = clamp(round(x / scale + offset), a, b)
         let z = E::from(self.offset).unwrap();
         run_par!(|| {
-            iter_par!(values.iter())
+            iter_slice_par!(values)
                 .map(|x| Q::from(x.div(self.scale).add(z).round().clamp(a, b)).unwrap())
                 .collect()
         })
@@ -116,7 +118,7 @@
     fn dequantize(&self, values: &[Q]) -> Vec<E> {
         // x = scale * (x_q - offset)
         run_par!(|| {
-            iter_par!(values.iter())
+            iter_slice_par!(values)
                 .map(|x_q| {
                     self.scale
                         * (E::from(
@@ -133,14 +135,14 @@
 
 /// Symmetric quantization scheme.
 #[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-pub struct SymmetricQuantization<E: Float, Q: PrimInt> {
+pub struct SymmetricQuantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync> {
     /// The scaling factor.
     pub scale: E,
     /// The quantized type.
     _q: PhantomData<Q>,
 }
 
-impl<E: Float, Q: PrimInt> SymmetricQuantization<E, Q> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> SymmetricQuantization<E, Q> {
     /// Initialize a symmetric quantization scheme with the given parameters.
     pub fn init(scale: E) -> Self {
         Self {
@@ -150,7 +152,9 @@ impl<E: Float, Q: PrimInt> SymmetricQuantization<E, Q> {
     }
 }
 
-impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Quantization<E, Q>
+    for SymmetricQuantization<E, Q>
+{
     fn new(alpha: E, beta: E) -> Self {
         assert!(
             !Q::min_value().is_zero(),
@@ -214,7 +218,9 @@ fn canonicalize_signed_zero<T: Float>(x: T) -> T {
     x + T::zero()
 }
 
-impl<E: Float, Q: PrimInt + Hash, A: PrimInt> Hash for AffineQuantization<E, Q, A> {
+impl<E: Float + Send + Sync, Q: PrimInt + Hash + Send + Sync, A: PrimInt> Hash
+    for AffineQuantization<E, Q, A>
+{
     fn hash<H: Hasher>(&self, state: &mut H) {
         // Hash raw bits.
         let bits = raw_double_bits(&canonicalize_signed_zero(self.scale));
@@ -223,29 +229,34 @@ impl<E: Float, Q: PrimInt + Hash, A: PrimInt> Hash for AffineQuantization<E, Q,
     }
 }
 
-impl<E: Float, Q: PrimInt, A: PrimInt> PartialEq for AffineQuantization<E, Q, A> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> PartialEq
+    for AffineQuantization<E, Q, A>
+{
     fn eq(&self, other: &Self) -> bool {
         self.scale == other.scale && self.offset == other.offset
     }
 }
 
-impl<E: Float, Q: PrimInt, A: PrimInt> Eq for AffineQuantization<E, Q, A> {}
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> Eq
+    for AffineQuantization<E, Q, A>
+{
+}
 
-impl<E: Float, Q: PrimInt> Hash for SymmetricQuantization<E, Q> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Hash for SymmetricQuantization<E, Q> {
     fn hash<H: Hasher>(&self, state: &mut H) {
         // Hash raw bits.
         let bits = raw_double_bits(&canonicalize_signed_zero(self.scale));
         bits.hash(state);
     }
 }
 
-impl<E: Float, Q: PrimInt> PartialEq for SymmetricQuantization<E, Q> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> PartialEq for SymmetricQuantization<E, Q> {
     fn eq(&self, other: &Self) -> bool {
         self.scale == other.scale
     }
 }
 
-impl<E: Float, Q: PrimInt> Eq for SymmetricQuantization<E, Q> {}
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Eq for SymmetricQuantization<E, Q> {}
 
 #[cfg(test)]
 mod tests {
```
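The added `Send + Sync` bounds are what let `iter_slice_par!(values)` hand the `&[E]` and `&[Q]` slices to rayon, whose parallel iterators require the element type to be `Sync` (and `Send` for owned items). A minimal usage sketch of the affine scheme (made-up values; assumes the types are re-exported under `burn_tensor::quantization`):

```rust
use burn_tensor::quantization::{AffineQuantization, Quantization};

fn main() {
    // Map the observed input range [-1.8, 0.5] onto i8, accumulating in i32.
    let scheme = AffineQuantization::<f32, i8, i32>::new(-1.8, 0.5);

    let q = scheme.quantize(&[-1.8, -1.0, 0.0, 0.5]);
    let x = scheme.dequantize(&q);

    // Round-tripping is lossy, but each value lands within one scale step.
    for (orig, deq) in [-1.8f32, -1.0, 0.0, 0.5].iter().zip(&x) {
        assert!((orig - deq).abs() <= scheme.scale);
    }
}
```

With the `rayon` feature enabled, `quantize` and `dequantize` above run over the slice in parallel; without it, the same code compiles down to plain sequential iterators.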
