
Feat/shared sum #2737

Merged 7 commits on Jan 24, 2025
Changes from 5 commits
28 changes: 14 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
1 change: 1 addition & 0 deletions backend-comparison/benches/reduce.rs
@@ -18,6 +18,7 @@ struct ReduceBenchmark<B: Backend> {
impl<B: Backend> ReduceBenchmark<B> {
pub fn new(instruction: Instruction, device: B::Device) -> Self {
let shape = Shape::new([4096, 512, 64]);
// let shape = Shape::new([128, 128, 64]);
let tensor = Tensor::random(shape.clone(), Distribution::Default, &device);
Self {
instruction,
84 changes: 73 additions & 11 deletions crates/burn-jit/src/kernel/reduce/base.rs
@@ -1,31 +1,92 @@
use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
use super::{autotune_reduce, autotune_sum};
use crate::{
element::JitElement,
ops::{from_data, numeric::empty_device},
tensor::JitTensor,
JitRuntime,
};
use burn_tensor::{Shape, TensorData};
pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum};
use cubecl::reduce::shared_sum;

use super::autotune_reduce;
/// Specialized reduce function to compute the sum of all elements of the `input` tensor and return
/// the value as a single-element tensor of shape `1 x 1 x 1 x ...` with the same rank as `input`.
///
/// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction.
///
/// Return an error if the `client` doesn't support atomic add for the type `E`.
pub fn sum<Run: JitRuntime, E: JitElement>(
tensor: JitTensor<Run>,
cube_count: SumStrategy,
) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
let client = tensor.client.clone();
let device = tensor.device.clone();
let shape_out: Shape = vec![1_usize; tensor.shape.num_dims()].into();

pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum};
match cube_count {
SumStrategy::OneShot(cube_count) => {
let output = shared_sum::<Run, E>(&client, tensor.as_handle_ref(), cube_count)?;
Ok(from_data::<Run, E>(
TensorData::new(vec![output], shape_out),
&device,
))
}
SumStrategy::Chained(strategy) => reduce::<Run, E, E, Sum>(tensor, strategy),
SumStrategy::Autotune => Ok(autotune_sum::<Run, E>(&client, tensor)),
}
}

/// Select a strategy to perform a sum.
pub enum SumStrategy {
/// Run a single kernel with many cubes working in parallel to sum all elements.
/// The provided value is the number of elements summed per unit (up to rounding).
OneShot(u32),
/// Use multiple chained kernels, reducing one axis at a time.
Chained(ReduceStrategy),
/// Use autotune to find the best cube count given the hardware and the input.
#[cfg(feature = "autotune")]
Autotune,
}

impl Default for SumStrategy {
fn default() -> Self {
#[cfg(feature = "autotune")]
return Self::Autotune;

#[cfg(not(feature = "autotune"))]
return Self::Static(4);
}
}
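
For illustration, here is a minimal sketch of a call site for the new entry point. The import paths and the `sum_all` / `sum_all_autotuned` helpers are assumptions for the example, not part of this diff; only `sum`, `SumStrategy`, and their signatures come from the change above.

use burn_jit::element::JitElement;
use burn_jit::kernel::reduce::{sum, SumStrategy};
use burn_jit::tensor::JitTensor;
use burn_jit::JitRuntime;

// Sum every element of `tensor` into a single-element tensor of the same rank.
fn sum_all<R: JitRuntime, E: JitElement>(
    tensor: JitTensor<R>,
) -> Result<JitTensor<R>, cubecl::reduce::ReduceError> {
    // One-shot path: a single kernel launch where each unit accumulates
    // roughly 16 elements; returns an error if the client has no atomic
    // add for the element type `E`.
    sum::<R, E>(tensor, SumStrategy::OneShot(16))
}

// With the `autotune` feature enabled, `Default::default()` selects
// `SumStrategy::Autotune`, which benchmarks the one-shot and chained paths,
// mirroring how `float_sum` and `int_sum` call it later in this diff.
fn sum_all_autotuned<R: JitRuntime, E: JitElement>(tensor: JitTensor<R>) -> JitTensor<R> {
    sum::<R, E>(tensor, Default::default()).unwrap()
}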

/// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
///
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.
/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`.
///
/// If there is no error, the output is a tensor with decreasing strides
/// where the size of the reduced dim is set to 1 while all other dims match the input.
pub fn reduce<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
mut input: JitTensor<Run>,
mut tensor: JitTensor<Run>,
strategy: ReduceStrategy,
) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
input.shape = input.shape.flatten();
input.strides = vec![1];
reduce_dim::<Run, In, Out, Rd>(input, 0, strategy)
// In practice, it looks like starting with the axis of smallest extent
// and reducing the axes in increasing order of size leads to the fastest calculation.
let sorted_axis = argsort(&tensor.shape.dims);
for axis in sorted_axis {
tensor = reduce_dim::<Run, In, Out, Rd>(tensor, axis, strategy)?;
}
Ok(tensor)
}

fn argsort(shape: &[usize]) -> Vec<usize> {
let mut indices = (0..shape.len()).collect::<Vec<_>>();
indices.sort_by_key(|&i| &shape[i]);
indices
}
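
To make the chained order concrete, here is a small self-contained check that duplicates the helper above so it runs on its own, using the benchmark shape from `backend-comparison/benches/reduce.rs`:

fn argsort(shape: &[usize]) -> Vec<usize> {
    let mut indices = (0..shape.len()).collect::<Vec<_>>();
    indices.sort_by_key(|&i| shape[i]);
    indices
}

fn main() {
    // Axes are reduced from smallest to largest extent:
    // axis 2 (64 elements), then axis 1 (512), then axis 0 (4096).
    let order = argsort(&[4096, 512, 64]);
    assert_eq!(order, vec![2, 1, 0]);
}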

/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
///
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.
/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`.
///
/// If there is no error, the output is a tensor with decreasing strides
/// where the size of the reduced dim is set to 1 while all other dims match the input.
@@ -58,7 +119,8 @@ pub fn reduce_dim<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::
),
#[cfg(feature = "autotune")]
ReduceStrategy::Autotune => {
autotune_reduce::<Run, In, Out, Rd>(&client, input, output.clone(), dim)
autotune_reduce::<Run, In, Out, Rd>(&client, input, output.clone(), dim);
Ok(())
}
};
result.map(|_| output)
94 changes: 89 additions & 5 deletions crates/burn-jit/src/kernel/reduce/tune.rs
@@ -12,7 +12,6 @@ use crate::{
kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor,
JitAutotuneKey, JitElement, JitRuntime, JitTuneId,
};
use reduce_ops::*;

/// Executes autotune on reduce operations.
pub fn autotune_reduce<
@@ -25,7 +24,9 @@ pub fn autotune_reduce<
input: JitTensor<Run>,
output: JitTensor<Run>,
dim: usize,
) -> Result<(), cubecl::reduce::ReduceError> {
) {
use reduce_ops::*;

static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();

let tunables = TunableSet::new(create_key::<Run>, reduce_input_gen::<Run, In, Out>)
@@ -40,12 +41,10 @@ pub fn autotune_reduce<
&tunables,
(input, output, dim),
);

Ok(())
}

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of redue versions
/// Autotune key representative of reduce versions
pub struct ReduceAutotuneKey {
dtype: burn_tensor::DType,
#[autotune(anchor)]
@@ -207,3 +206,88 @@ mod reduce_ops {
.map_err(|e| format!("{e}"))
}
}

/// Executes autotune on reduce operations.
pub fn autotune_sum<Run: JitRuntime, E: JitElement>(
client: &ComputeClient<Run::Server, Run::Channel>,
input: JitTensor<Run>,
) -> JitTensor<Run> {
use sum_ops::*;

static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();

let tunables = TunableSet::new(create_key_sum::<Run>, sum_input_gen::<Run, E>)
.with_tunable(sum_one_shot::<Run, E, 1>)
.with_tunable(sum_one_shot::<Run, E, 2>)
.with_tunable(sum_one_shot::<Run, E, 4>)
.with_tunable(sum_one_shot::<Run, E, 8>)
.with_tunable(sum_one_shot::<Run, E, 16>)
.with_tunable(sum_one_shot::<Run, E, 32>)
.with_tunable(sum_one_shot::<Run, E, 64>)
.with_tunable(sum_chained::<Run, E>);

TUNER.execute(
&JitTuneId::new::<Run>(&input.device),
client,
&tunables,
input,
)
}

pub(crate) fn create_key_sum<Run: JitRuntime>(input: &JitTensor<Run>) -> JitAutotuneKey {
JitAutotuneKey::Sum(SumAutotuneKey::generate(input))
}

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of sum versions
pub struct SumAutotuneKey {
dtype: burn_tensor::DType,
#[autotune(anchor)]
length: usize,
}

impl SumAutotuneKey {
pub(crate) fn generate<Run: JitRuntime>(input: &JitTensor<Run>) -> Self {
let dtype = input.dtype;
let length = input.shape.num_elements();
Self { dtype, length }
}
}
mod sum_ops {
#![allow(missing_docs)]

use burn_tensor::{Shape, TensorData};
use cubecl::reduce::instructions::Sum;

use crate::ops::from_data;

use super::*;

pub(crate) fn sum_input_gen<Run: JitRuntime, E: JitElement>(
_key: &JitAutotuneKey,
input: &JitTensor<Run>,
) -> JitTensor<Run> {
let random_bounds: (E, E) = ((-10.0_f32).elem::<E>(), (10.0_f32).elem::<E>());
random_like_uniform(input, random_bounds.0, random_bounds.1)
}

pub(crate) fn sum_one_shot<Run: JitRuntime, E: JitElement, const C: u32>(
input: JitTensor<Run>,
) -> Result<JitTensor<Run>, String> {
let device = input.device.clone();
let shape_out: Shape = vec![1_usize; input.shape.num_dims()].into();
cubecl::reduce::shared_sum::<Run, E>(&input.client, input.as_handle_ref(), C)
.map(|output| from_data::<Run, E>(TensorData::new(vec![output], shape_out), &device))
.map_err(|e| e.to_string())
}

pub(crate) fn sum_chained<Run: JitRuntime, E: JitElement>(
input: JitTensor<Run>,
) -> Result<JitTensor<Run>, String> {
crate::kernel::reduce::reduce::<Run, E, E, Sum>(
input,
crate::kernel::reduce::ReduceStrategy::Autotune,
Review comment (Member): Just to make sure I understand, calling `sum_chained` during the autotune of sum-full will trigger the autotune of sum-dim, which may trigger several autotunes during the chain if the various dims are not similar?

Reply (Contributor, author): Exactly!

)
.map_err(|e| e.to_string())
}
}
2 changes: 1 addition & 1 deletion crates/burn-jit/src/ops/float_ops.rs
@@ -355,7 +355,7 @@ where
execute_with_dtype!(
float(tensor.dtype),
E,
reduce::reduce::<R, E, E, reduce::Sum>(tensor, Default::default()).unwrap()
reduce::sum::<R, E>(tensor, Default::default()).unwrap()
)
}

2 changes: 1 addition & 1 deletion crates/burn-jit/src/ops/int_ops.rs
@@ -193,7 +193,7 @@ where
}

fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
reduce::reduce::<R, I, I, reduce::Sum>(tensor, Default::default()).unwrap()
reduce::sum::<R, I>(tensor, Default::default()).unwrap()
}

fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {