From 04c5e74daac2adbfda4f5293746bc37ab6be56ea Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 28 Jan 2025 09:30:00 -0500 Subject: [PATCH] Fix types under autotune flag (#2750) --- crates/burn-core/Cargo.toml | 2 +- crates/burn-jit/src/fusion/matmul/optimization.rs | 2 +- crates/burn-jit/src/kernel/reduce/base.rs | 4 +++- crates/burn-jit/src/kernel/reduce/tune.rs | 2 ++ 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index 5d926cd0b3..423dc784d8 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -88,7 +88,7 @@ fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"] ## Backend features accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"] -autotune = ["burn-wgpu?/autotune"] +autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"] blas-netlib = ["burn-ndarray?/blas-netlib"] metal = ["burn-candle?/metal"] openblas = ["burn-ndarray?/blas-openblas"] diff --git a/crates/burn-jit/src/fusion/matmul/optimization.rs b/crates/burn-jit/src/fusion/matmul/optimization.rs index 9a020df62c..804628613d 100644 --- a/crates/burn-jit/src/fusion/matmul/optimization.rs +++ b/crates/burn-jit/src/fusion/matmul/optimization.rs @@ -87,7 +87,7 @@ impl MatmulOptimization { fused_matmul_autotune::(self, context); #[cfg(not(feature = "autotune"))] - if self.execute_fused::(context).is_err() { + if self.execute_standard_fused::(context).is_err() { self.execute_fallback::(context); } } diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index 9ec677ee93..ccfcc3ef9e 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "autotune")] use super::{autotune_reduce, autotune_sum}; use crate::{ element::JitElement, @@ -31,6 +32,7 @@ pub fn sum( )) } SumStrategy::Chained(strategy) => reduce::(tensor, strategy), + #[cfg(feature = "autotune")] SumStrategy::Autotune => Ok(autotune_sum::(&client, tensor)), } } @@ -53,7 +55,7 @@ impl Default for SumStrategy { return Self::Autotune; #[cfg(not(feature = "autotune"))] - return Self::Static(4); + return Self::OneShot(4); } } diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs index c397af1a04..cd5cd61157 100644 --- a/crates/burn-jit/src/kernel/reduce/tune.rs +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -208,6 +208,7 @@ mod reduce_ops { } /// Executes autotune on reduce operations. +#[cfg(feature = "autotune")] pub fn autotune_sum( client: &ComputeClient, input: JitTensor, @@ -280,6 +281,7 @@ mod sum_ops { .map_err(|e| e.to_string()) } + #[cfg(feature = "autotune")] pub(crate) fn sum_chained( input: JitTensor, ) -> Result, String> {