Skip to content

Commit

Permalink
improve chained reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
maxtremblay committed Jan 23, 2025
1 parent 4221820 commit c8e1655
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion backend-comparison/benches/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl<B: Backend> Benchmark for ReduceBenchmark<B> {
self.tensor.clone().sum_dim(axis);
}
Instruction::Sum => {
self.tensor.clone().sum_dim(2).sum_dim(1).sum_dim(0);
self.tensor.clone().sum();
}
}
}
Expand Down
11 changes: 10 additions & 1 deletion crates/burn-jit/src/kernel/reduce/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,21 @@ pub fn reduce<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::redu
mut tensor: JitTensor<Run>,
strategy: ReduceStrategy,
) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
for axis in 0..tensor.shape.num_dims() {
// In practice, it looks like starting by the axis with the smallest shape
// and going in increasing order lead to the fastest calculation.
let sorted_axis = argsort(&tensor.shape.dims);
for axis in sorted_axis {
tensor = reduce_dim::<Run, In, Out, Rd>(tensor, axis, strategy)?;
}
Ok(tensor)
}

fn argsort(shape: &[usize]) -> Vec<usize> {
let mut indices = (0..shape.len()).collect::<Vec<_>>();
indices.sort_by_key(|&i| &shape[i]);
indices
}

/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
///
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
Expand Down

0 comments on commit c8e1655

Please sign in to comment.