Skip to content

Commit

Permalink
Fix types under autotune flag (#2750)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Jan 28, 2025
1 parent 2d9e9b9 commit 04c5e74
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion crates/burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"]

## Backend features
accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"]
autotune = ["burn-wgpu?/autotune"]
autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"]
blas-netlib = ["burn-ndarray?/blas-netlib"]
metal = ["burn-candle?/metal"]
openblas = ["burn-ndarray?/blas-openblas"]
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/matmul/optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl<R: JitRuntime> MatmulOptimization<R> {
fused_matmul_autotune::<R, BT>(self, context);

#[cfg(not(feature = "autotune"))]
if self.execute_fused::<BT>(context).is_err() {
if self.execute_standard_fused::<BT>(context).is_err() {
self.execute_fallback::<BT>(context);
}
}
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-jit/src/kernel/reduce/base.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[cfg(feature = "autotune")]
use super::{autotune_reduce, autotune_sum};
use crate::{
element::JitElement,
Expand Down Expand Up @@ -31,6 +32,7 @@ pub fn sum<Run: JitRuntime, E: JitElement>(
))
}
SumStrategy::Chained(strategy) => reduce::<Run, E, E, Sum>(tensor, strategy),
#[cfg(feature = "autotune")]
SumStrategy::Autotune => Ok(autotune_sum::<Run, E>(&client, tensor)),
}
}
Expand All @@ -53,7 +55,7 @@ impl Default for SumStrategy {
return Self::Autotune;

#[cfg(not(feature = "autotune"))]
return Self::Static(4);
return Self::OneShot(4);
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/burn-jit/src/kernel/reduce/tune.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ mod reduce_ops {
}

/// Executes autotune on reduce operations.
#[cfg(feature = "autotune")]
pub fn autotune_sum<Run: JitRuntime, E: JitElement>(
client: &ComputeClient<Run::Server, Run::Channel>,
input: JitTensor<Run>,
Expand Down Expand Up @@ -280,6 +281,7 @@ mod sum_ops {
.map_err(|e| e.to_string())
}

#[cfg(feature = "autotune")]
pub(crate) fn sum_chained<Run: JitRuntime, E: JitElement>(
input: JitTensor<Run>,
) -> Result<JitTensor<Run>, String> {
Expand Down

0 comments on commit 04c5e74

Please sign in to comment.