Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Feb 4, 2025
1 parent 4ce7993 commit 12de765
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 28 deletions.
15 changes: 0 additions & 15 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" }
# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" }
# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
# cubecl = { version = "0.4.0", default-features = false }
# cubecl-common = { version = "0.4.0", default-features = false }
Expand Down
3 changes: 3 additions & 0 deletions crates/burn-jit/src/fusion/on_write/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ impl FuseOnWriteBuilder {
pub fn output_unhandled(&mut self, tensor: &TensorDescription) -> Arg {
if self.current_output_shape.is_empty() {
self.current_output_shape = tensor.shape.clone();
} else if self.current_output_shape.iter().sum::<usize>() < tensor.shape.iter().sum() {
// The largest shape wins (shapes are compared by the sum of their dimensions).
self.current_output_shape = tensor.shape.clone();
}

self.builder.builder.output_unhandled(tensor)
Expand Down
17 changes: 12 additions & 5 deletions crates/burn-jit/src/kernel/reduce/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@ pub fn sum<Run: JitRuntime, E: JitElement>(

match cube_count {
SumStrategy::OneShot(cube_count) => {
let output = shared_sum::<Run, E>(&client, tensor.as_handle_ref(), cube_count)?;
Ok(from_data::<Run, E>(
TensorData::new(vec![output], vec![1]),
&device,
))
let handle = client.empty(E::size().unwrap());
let output =
JitTensor::new_contiguous(client.clone(), device, [1].into(), handle, E::dtype());

shared_sum::<Run, E>(
&client,
tensor.as_handle_ref(),
output.as_handle_ref(),
cube_count,
)?;

Ok(output)
}
SumStrategy::Chained(strategy) => reduce::<Run, E, E, Sum>(tensor, strategy),
#[cfg(feature = "autotune")]
Expand Down
16 changes: 13 additions & 3 deletions crates/burn-jit/src/kernel/reduce/tune.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,20 @@ mod sum_ops {
pub(crate) fn sum_one_shot<Run: JitRuntime, E: JitElement, const C: u32>(
input: JitTensor<Run>,
) -> Result<JitTensor<Run>, String> {
let client = input.client.clone();
let device = input.device.clone();
cubecl::reduce::shared_sum::<Run, E>(&input.client, input.as_handle_ref(), C)
.map(|output| from_data::<Run, E>(TensorData::new(vec![output], vec![1]), &device))
.map_err(|e| e.to_string())
let handle = client.empty(E::size().unwrap());
let output = JitTensor::new_contiguous(client, device, [1].into(), handle, E::dtype());

cubecl::reduce::shared_sum::<Run, E>(
&input.client,
input.as_handle_ref(),
output.as_handle_ref(),
C,
)
.map_err(|e| e.to_string())?;

Ok(output)
}

#[cfg(feature = "autotune")]
Expand Down
2 changes: 1 addition & 1 deletion examples/text-classification/examples/ag-news-train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ mod wgpu {
#[cfg(feature = "vulkan")]
mod vulkan {
use crate::{launch, ElemType};
use burn::backend::{Vulkan, Autodiff};
use burn::backend::{Autodiff, Vulkan};

pub fn run() {
launch::<Autodiff<Vulkan<ElemType, i32>>>(vec![Default::default()]);
Expand Down

0 comments on commit 12de765

Please sign in to comment.