Feat/shared sum (#2737)
* bump cubecl version

* bump cubecl version

* import new specialized sum reduction from cubecl

* commit missing autotune key

* improve chained reduction

* fix reduce shape issue

* fix typos and dead code
maxtremblay authored Jan 24, 2025
1 parent e586b17 commit 7ddb5af
Showing 7 changed files with 185 additions and 35 deletions.
28 changes: 14 additions & 14 deletions Cargo.lock


4 changes: 2 additions & 2 deletions Cargo.toml
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
86 changes: 75 additions & 11 deletions crates/burn-jit/src/kernel/reduce/base.rs
@@ -1,31 +1,94 @@
use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
use super::{autotune_reduce, autotune_sum};
use crate::{
element::JitElement,
ops::{from_data, numeric::empty_device},
tensor::JitTensor,
JitRuntime,
};
use burn_tensor::{Shape, TensorData};
pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum};
use cubecl::reduce::shared_sum;

use super::autotune_reduce;
/// Specialized reduce function to compute the sum of all elements of the `input` tensor and return
/// the value in a single-element tensor of shape `[1]`.
///
/// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction.
///
/// Return an error if the `client` doesn't support atomic add for the type `E`.
pub fn sum<Run: JitRuntime, E: JitElement>(
tensor: JitTensor<Run>,
cube_count: SumStrategy,
) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
let client = tensor.client.clone();
let device = tensor.device.clone();

pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum};
match cube_count {
SumStrategy::OneShot(cube_count) => {
let output = shared_sum::<Run, E>(&client, tensor.as_handle_ref(), cube_count)?;
Ok(from_data::<Run, E>(
TensorData::new(vec![output], vec![1]),
&device,
))
}
SumStrategy::Chained(strategy) => reduce::<Run, E, E, Sum>(tensor, strategy),
SumStrategy::Autotune => Ok(autotune_sum::<Run, E>(&client, tensor)),
}
}

/// Select a strategy to perform a sum.
pub enum SumStrategy {
/// Run a single kernel with many cubes working in parallel to sum all elements.
/// The provided value is the number of elements summed per unit (up to rounding).
OneShot(u32),
/// Use a chain of reduction kernels with the given [ReduceStrategy].
Chained(ReduceStrategy),
/// Use autotune to find the best cube count given the hardware and the input.
#[cfg(feature = "autotune")]
Autotune,
}

impl Default for SumStrategy {
fn default() -> Self {
#[cfg(feature = "autotune")]
return Self::Autotune;

#[cfg(not(feature = "autotune"))]
return Self::OneShot(4);
}
}

/// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
///
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
///
/// If there is no error, the output is a single-element tensor of shape `[1]`
/// containing the result of reducing all elements of `input`.
pub fn reduce<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
mut input: JitTensor<Run>,
mut tensor: JitTensor<Run>,
strategy: ReduceStrategy,
) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
input.shape = input.shape.flatten();
input.strides = vec![1];
reduce_dim::<Run, In, Out, Rd>(input, 0, strategy)
// In practice, it looks like starting with the axis of smallest extent
// and reducing the axes in increasing order of extent leads to the fastest calculation.
let sorted_axis = argsort(&tensor.shape.dims);
for axis in sorted_axis {
tensor = reduce_dim::<Run, In, Out, Rd>(tensor, axis, strategy)?;
}
// reshape to scalar tensor
tensor.shape = Shape::new([1]);
tensor.strides = vec![1];
Ok(tensor)
}

fn argsort(shape: &[usize]) -> Vec<usize> {
let mut indices = (0..shape.len()).collect::<Vec<_>>();
indices.sort_by_key(|&i| &shape[i]);
indices
}
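// Illustration (not part of the diff): for a tensor of shape [32, 4, 1024],
// argsort(&[32, 4, 1024]) returns [1, 0, 2], so the chained reduce collapses
// the length-4 axis first, then the length-32 axis, then the length-1024 axis.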

/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
///
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
/// Also return an error if the `axis` is out of range for the `input` rank or if the shape of `output` is invalid.
/// The shape of `output` must be the same as `input` except with a value of 1 for the given `axis`.
///
/// If there is no error, the output is a tensor with decreasing strides
/// where the shape of the reduced dim is set to 1 and all other dims match the input.
@@ -58,7 +121,8 @@ pub fn reduce_dim<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::
),
#[cfg(feature = "autotune")]
ReduceStrategy::Autotune => {
autotune_reduce::<Run, In, Out, Rd>(&client, input, output.clone(), dim)
autotune_reduce::<Run, In, Out, Rd>(&client, input, output.clone(), dim);
Ok(())
}
};
result.map(|_| output)
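For orientation, here is a minimal sketch of how the new entry point might be called. This is not part of the diff; it assumes `sum`, `SumStrategy`, and `ReduceStrategy` are in scope, that a `JitTensor` was built elsewhere, and that the `Autotune` variant is available (it is gated behind the `autotune` feature).

fn sum_examples<Run: JitRuntime, E: JitElement>(
    tensor: JitTensor<Run>,
) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
    // One kernel launch; each unit sums roughly 16 elements and the cubes
    // fold their partial results into a single output with atomic adds.
    let _one_shot = sum::<Run, E>(tensor.clone(), SumStrategy::OneShot(16))?;
    // A chain of regular reduce kernels, one axis at a time.
    let _chained = sum::<Run, E>(tensor.clone(), SumStrategy::Chained(ReduceStrategy::default()))?;
    // Let the tuner benchmark the candidates and cache the winner.
    sum::<Run, E>(tensor, SumStrategy::Autotune)
}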
93 changes: 88 additions & 5 deletions crates/burn-jit/src/kernel/reduce/tune.rs
@@ -12,7 +12,6 @@ use crate::{
kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor,
JitAutotuneKey, JitElement, JitRuntime, JitTuneId,
};
use reduce_ops::*;

/// Executes autotune on reduce operations.
pub fn autotune_reduce<
@@ -25,7 +24,9 @@ pub fn autotune_reduce<
input: JitTensor<Run>,
output: JitTensor<Run>,
dim: usize,
) -> Result<(), cubecl::reduce::ReduceError> {
) {
use reduce_ops::*;

static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();

let tunables = TunableSet::new(create_key::<Run>, reduce_input_gen::<Run, In, Out>)
@@ -40,12 +41,10 @@ pub fn autotune_reduce<
&tunables,
(input, output, dim),
);

Ok(())
}

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of redue versions
/// Autotune key representative of reduce versions
pub struct ReduceAutotuneKey {
dtype: burn_tensor::DType,
#[autotune(anchor)]
@@ -207,3 +206,87 @@ mod reduce_ops {
.map_err(|e| format!("{e}"))
}
}

/// Executes autotune on sum operations.
pub fn autotune_sum<Run: JitRuntime, E: JitElement>(
client: &ComputeClient<Run::Server, Run::Channel>,
input: JitTensor<Run>,
) -> JitTensor<Run> {
use sum_ops::*;

static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();

let tunables = TunableSet::new(create_key_sum::<Run>, sum_input_gen::<Run, E>)
.with_tunable(sum_one_shot::<Run, E, 1>)
.with_tunable(sum_one_shot::<Run, E, 2>)
.with_tunable(sum_one_shot::<Run, E, 4>)
.with_tunable(sum_one_shot::<Run, E, 8>)
.with_tunable(sum_one_shot::<Run, E, 16>)
.with_tunable(sum_one_shot::<Run, E, 32>)
.with_tunable(sum_one_shot::<Run, E, 64>)
.with_tunable(sum_chained::<Run, E>);

TUNER.execute(
&JitTuneId::new::<Run>(&input.device),
client,
&tunables,
input,
)
}
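Conceptually, the tuner's job can be pictured as follows. This is a simplified, self-contained sketch of the idea (benchmark each candidate once per key, cache the winner), not the actual cubecl `LocalTuner` API:

use std::collections::HashMap;
use std::time::{Duration, Instant};

// Benchmark every candidate the first time a key is seen, remember the
// index of the fastest, and reuse that choice for later calls.
fn pick_fastest<K, T>(
    cache: &mut HashMap<K, usize>,
    key: K,
    candidates: &[&dyn Fn() -> T],
) -> T
where
    K: std::hash::Hash + Eq,
{
    let index = *cache.entry(key).or_insert_with(|| {
        let mut best = (0, Duration::MAX);
        for (i, run) in candidates.iter().enumerate() {
            let start = Instant::now();
            run();
            let elapsed = start.elapsed();
            if elapsed < best.1 {
                best = (i, elapsed);
            }
        }
        best.0
    });
    candidates[index]()
}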

pub(crate) fn create_key_sum<Run: JitRuntime>(input: &JitTensor<Run>) -> JitAutotuneKey {
JitAutotuneKey::Sum(SumAutotuneKey::generate(input))
}

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of sum versions
pub struct SumAutotuneKey {
dtype: burn_tensor::DType,
#[autotune(anchor)]
length: usize,
}

impl SumAutotuneKey {
pub(crate) fn generate<Run: JitRuntime>(input: &JitTensor<Run>) -> Self {
let dtype = input.dtype;
let length = input.shape.num_elements();
Self { dtype, length }
}
}
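The `#[autotune(anchor)]` attribute on `length` is what keeps the key space small. Our reading of its behavior (an assumption; check the cubecl macro for the exact rule) is that it buckets the length so that similarly sized inputs reuse one cached tuning result:

// Assumed anchoring rule: round the length up to the next power of two,
// so e.g. lengths 513..=1024 all map to the same autotune key.
fn anchor(length: usize) -> usize {
    length.next_power_of_two()
}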
mod sum_ops {
#![allow(missing_docs)]

use burn_tensor::TensorData;
use cubecl::reduce::instructions::Sum;

use crate::ops::from_data;

use super::*;

pub(crate) fn sum_input_gen<Run: JitRuntime, E: JitElement>(
_key: &JitAutotuneKey,
input: &JitTensor<Run>,
) -> JitTensor<Run> {
let random_bounds: (E, E) = ((-10.0_f32).elem::<E>(), (10.0_f32).elem::<E>());
random_like_uniform(input, random_bounds.0, random_bounds.1)
}

pub(crate) fn sum_one_shot<Run: JitRuntime, E: JitElement, const C: u32>(
input: JitTensor<Run>,
) -> Result<JitTensor<Run>, String> {
let device = input.device.clone();
cubecl::reduce::shared_sum::<Run, E>(&input.client, input.as_handle_ref(), C)
.map(|output| from_data::<Run, E>(TensorData::new(vec![output], vec![1]), &device))
.map_err(|e| e.to_string())
}

pub(crate) fn sum_chained<Run: JitRuntime, E: JitElement>(
input: JitTensor<Run>,
) -> Result<JitTensor<Run>, String> {
crate::kernel::reduce::reduce::<Run, E, E, Sum>(
input,
crate::kernel::reduce::ReduceStrategy::Autotune,
)
.map_err(|e| e.to_string())
}
}
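To see what the one-shot candidates above are tuning over, here is a rough CPU analogy of the `shared_sum` approach (each GPU cube reduces a chunk locally, then atomically adds its partial result to a single accumulator). This is a self-contained sketch, not the cubecl implementation:

use std::sync::atomic::{AtomicI64, Ordering};
use std::thread;

// CPU analogy of the one-shot shared sum: split the input into chunks
// ("cubes"), reduce each chunk locally, then atomically fold the partial
// sums into one shared accumulator.
fn shared_sum_analogy(data: &[i64], workers: usize) -> i64 {
    let total = AtomicI64::new(0);
    let chunk = data.len().div_ceil(workers.max(1)).max(1);
    thread::scope(|s| {
        for part in data.chunks(chunk) {
            let total = &total;
            s.spawn(move || {
                let partial: i64 = part.iter().sum();
                total.fetch_add(partial, Ordering::Relaxed);
            });
        }
    });
    total.load(Ordering::Relaxed)
}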
2 changes: 1 addition & 1 deletion crates/burn-jit/src/ops/float_ops.rs
@@ -355,7 +355,7 @@ where
execute_with_dtype!(
float(tensor.dtype),
E,
reduce::reduce::<R, E, E, reduce::Sum>(tensor, Default::default()).unwrap()
reduce::sum::<R, E>(tensor, Default::default()).unwrap()
)
}

2 changes: 1 addition & 1 deletion crates/burn-jit/src/ops/int_ops.rs
@@ -198,7 +198,7 @@
}

fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
reduce::reduce::<R, I, I, reduce::Sum>(tensor, Default::default()).unwrap()
reduce::sum::<R, I>(tensor, Default::default()).unwrap()
}

fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
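The user-facing API is unchanged by these two call-site swaps: a call like the one below (a sketch assuming any burn JIT backend, e.g. `Wgpu`) now routes through the specialized `reduce::sum` path with the autotuned strategy.

use burn::tensor::{backend::Backend, Tensor};

// `tensor.sum()` reduces all elements to a single-element 1-D tensor; on JIT
// backends it is now dispatched to `reduce::sum` instead of the generic reduce.
fn total<B: Backend>(tensor: Tensor<B, 2>) -> Tensor<B, 1> {
    tensor.sum()
}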