Skip to content

Commit

Permalink
fix reduce shape issue
Browse files (browse the repository at this point in the history)
  • Loading branch information
maxtremblay committed Jan 24, 2025
1 parent c8e1655 commit 16882c4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
6 changes: 4 additions & 2 deletions crates/burn-jit/src/kernel/reduce/base.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,13 +21,12 @@ pub fn sum<Run: JitRuntime, E: JitElement>(
) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
let client = tensor.client.clone();
let device = tensor.device.clone();
let shape_out: Shape = vec![1_usize; tensor.shape.num_dims()].into();

match cube_count {
SumStrategy::OneShot(cube_count) => {
let output = shared_sum::<Run, E>(&client, tensor.as_handle_ref(), cube_count)?;
Ok(from_data::<Run, E>(
TensorData::new(vec![output], shape_out),
TensorData::new(vec![output], vec![1]),
&device,
))
}
Expand Down Expand Up @@ -74,6 +73,9 @@ pub fn reduce<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::redu
for axis in sorted_axis {
tensor = reduce_dim::<Run, In, Out, Rd>(tensor, axis, strategy)?;
}
// reshape to scalar tensor
tensor.shape = Shape::new([1]);
tensor.strides = vec![1];
Ok(tensor)
}

Expand Down
5 changes: 2 additions & 3 deletions crates/burn-jit/src/kernel/reduce/tune.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -256,7 +256,7 @@ impl SumAutotuneKey {
mod sum_ops {
#![allow(missing_docs)]

use burn_tensor::{Shape, TensorData};
use burn_tensor::TensorData;
use cubecl::reduce::instructions::Sum;

use crate::ops::from_data;
Expand All @@ -275,9 +275,8 @@ mod sum_ops {
input: JitTensor<Run>,
) -> Result<JitTensor<Run>, String> {
let device = input.device.clone();
let shape_out: Shape = vec![1_usize; input.shape.num_dims()].into();
cubecl::reduce::shared_sum::<Run, E>(&input.client, input.as_handle_ref(), C)
.map(|output| from_data::<Run, E>(TensorData::new(vec![output], shape_out), &device))
.map(|output| from_data::<Run, E>(TensorData::new(vec![output], vec![1]), &device))
.map_err(|e| e.to_string())
}

Expand Down

0 comments on commit 16882c4

Please sign in to comment.