From ebd7649edc22e5304765f97725ceb386be20dd90 Mon Sep 17 00:00:00 2001
From: Genna Wingert <1058083+wingertge@users.noreply.github.com>
Date: Wed, 11 Dec 2024 14:43:00 +0100
Subject: [PATCH] Migrate matmul autotune to macro and fix accelerated (#2584)

* Migrate matmul autotune to macro and fix accelerated by checking for CMMA availability first

* Set max anchor on batch
---
 .../burn-jit/src/kernel/matmul/tune/base.rs   | 207 +++++++-----------
 crates/burn-jit/src/kernel/matmul/tune/key.rs |  84 ++++---
 crates/burn-jit/src/kernel/matmul/tune/mod.rs |   2 +-
 crates/burn-jit/src/lib.rs                    |   2 -
 crates/burn-jit/src/tune.rs                   |  16 --
 5 files changed, 112 insertions(+), 199 deletions(-)
 delete mode 100644 crates/burn-jit/src/tune.rs

diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs
index 2b8050dc07..98951b8f89 100644
--- a/crates/burn-jit/src/kernel/matmul/tune/base.rs
+++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs
@@ -1,9 +1,10 @@
-use core::marker::PhantomData;
-
 use burn_tensor::{Element, ElementConversion};
 use cubecl::{
+    ir::{Elem, FloatKind},
     linalg::matmul::{kernels::tiling2d::Tiling2dConfig, Strategy},
-    tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTuner},
+    tune,
+    tune::{local_tuner, tune_with, LocalTuner},
+    Feature,
 };
 
 use crate::{
@@ -15,73 +16,45 @@ use crate::{
     JitRuntime, JitTuneId,
 };
 
-use super::key::MatmulAutotuneKey;
+use super::key::create_key;
 
-/// Set of matmul implementations available for autotune
-/// Autotune key is given by concatenating the closest upper power of 2 of m, k and n
-pub struct MatmulAutotuneOperationSet<R: JitRuntime, E: FloatElement> {
+#[tune(
+    operations(matmul_tiling2d, matmul_accelerated, matmul_simple),
+    create_key = create_key::<R, E>,
+    should_run = should_run
+)]
+fn matmul_ops<R: JitRuntime, E: FloatElement>(
     key: JitAutotuneKey,
     lhs: JitTensor<R>,
     rhs: JitTensor<R>,
     out: JitTensor<R>,
-    _e: PhantomData<E>,
-}
-impl<R: JitRuntime, E: FloatElement> MatmulAutotuneOperationSet<R, E> {
-    fn new(lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>) -> Self {
-        Self {
-            key: JitAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape, E::dtype())),
-            lhs,
-            rhs,
-            out,
-            _e: PhantomData,
-        }
-    }
-}
+) {
+    let random_bounds: (E, E) = ((-10.0).elem::<E>(), (10.0).elem::<E>());
+    let lhs = random_like_uniform(lhs, random_bounds.0, random_bounds.1);
+    let rhs = random_like_uniform(rhs, random_bounds.0, random_bounds.1);
 
-impl<R: JitRuntime, E: FloatElement> AutotuneOperationSet<JitAutotuneKey>
-    for MatmulAutotuneOperationSet<R, E>
-{
-    fn key(&self) -> JitAutotuneKey {
-        self.key.clone()
-    }
-
-    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
-        let random_bounds: (E, E) = ((-10.0).elem::<E>(), (10.0).elem::<E>());
-        let lhs = random_like_uniform(&self.lhs, random_bounds.0, random_bounds.1);
-        let rhs = random_like_uniform(&self.rhs, random_bounds.0, random_bounds.1);
+    let out = empty_device::<R, E>(out.client.clone(), out.device.clone(), out.shape.clone());
 
-        let out = empty_device::<R, E>(
-            self.out.client.clone(),
-            self.out.device.clone(),
-            self.out.shape.clone(),
-        );
-
-        vec![
-            Box::new(MatmulTiling2d::<R, E>::new(
-                lhs.clone(),
-                rhs.clone(),
-                out.clone(),
-            )),
-            Box::new(MatmulAccelerated::<R, E>::new(
-                lhs.clone(),
-                rhs.clone(),
-                out.clone(),
-            )),
-            Box::new(MatmulSimple::<R, E>::new(
-                lhs.clone(),
-                rhs.clone(),
-                out.clone(),
-            )),
-        ]
-    }
+    tune_with!(lhs, rhs, out)
+}
 
-    fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
-        match fastest_index {
-            0 => Box::new(MatmulTiling2d::<R, E>::new(self.lhs, self.rhs, self.out)),
-            1 => Box::new(MatmulAccelerated::<R, E>::new(self.lhs, self.rhs, self.out)),
-            2 => Box::new(MatmulSimple::<R, E>::new(self.lhs, self.rhs, self.out)),
-            _ => panic!("Fastest index is out of bound"),
index is out of bound"), - } +fn should_run( + op: &MatmulOps, + _key: &JitAutotuneKey, + index: usize, +) -> bool { + match index { + // Accelerated + // TODO: Add way to query actual requirements from cubecl + 1 => op.lhs.client.properties().feature_enabled(Feature::Cmma { + a: Elem::Float(FloatKind::F16), + b: Elem::Float(FloatKind::F16), + c: Elem::Float(FloatKind::F32), + m: 16, + k: 16, + n: 16, + }), + _ => true, } } @@ -100,82 +73,50 @@ pub fn matmul_autotune( TUNER.execute( &JitTuneId::new::(&lhs.device), &client, - Box::new(MatmulAutotuneOperationSet::::new( - lhs, - rhs, - output.clone(), - )), + Box::new(MatmulOps::::new(lhs, rhs, output.clone())), ); output } -macro_rules! matmul_tune_ops { - ($name:ident, $func:expr) => { - #[derive(new, Debug)] - pub(crate) struct $name { - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, - _e: PhantomData, - } - - impl AutotuneOperation for $name { - fn execute(self: Box) { - #[allow(clippy::redundant_closure_call)] - $func(self.lhs, self.rhs, self.out); - } - - fn clone(&self) -> Box { - Box::new(Self { - lhs: self.lhs.clone(), - rhs: self.rhs.clone(), - out: self.out.clone(), - _e: self._e, - }) - } - } - }; +fn matmul_accelerated( + lhs: JitTensor, + rhs: JitTensor, + out: JitTensor, +) { + cubecl::linalg::matmul::launch_ref::( + &Strategy::Accelerated, + &lhs.client, + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), + ); } -// Probably the fastest in the general case. -matmul_tune_ops!( - MatmulAccelerated, - |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { - cubecl::linalg::matmul::launch_ref::( - &Strategy::Accelerated, - &lhs.client, - &lhs.as_handle_ref(), - &rhs.as_handle_ref(), - &out.as_handle_ref(), - ); - } -); - -// Probably the fastest when tensor cores are not available. -matmul_tune_ops!( - MatmulTiling2d, - |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { - cubecl::linalg::matmul::launch_ref::( - &Strategy::Tiling2D(Tiling2dConfig::default()), - &lhs.client, - &lhs.as_handle_ref(), - &rhs.as_handle_ref(), - &out.as_handle_ref(), - ); - } -); +fn matmul_tiling2d( + lhs: JitTensor, + rhs: JitTensor, + out: JitTensor, +) { + cubecl::linalg::matmul::launch_ref::( + &Strategy::Tiling2D(Tiling2dConfig::default()), + &lhs.client, + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), + ); +} -// Probably the fastest for small matrices. 
-matmul_tune_ops!(
-    MatmulSimple,
-    |lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>| {
-        cubecl::linalg::matmul::launch_ref::<R, E>(
-            &Strategy::Simple,
-            &lhs.client,
-            &lhs.as_handle_ref(),
-            &rhs.as_handle_ref(),
-            &out.as_handle_ref(),
-        );
-    }
-);
+fn matmul_simple<R: JitRuntime, E: FloatElement>(
+    lhs: JitTensor<R>,
+    rhs: JitTensor<R>,
+    out: JitTensor<R>,
+) {
+    cubecl::linalg::matmul::launch_ref::<R, E>(
+        &Strategy::Simple,
+        &lhs.client,
+        &lhs.as_handle_ref(),
+        &rhs.as_handle_ref(),
+        &out.as_handle_ref(),
+    );
+}
diff --git a/crates/burn-jit/src/kernel/matmul/tune/key.rs b/crates/burn-jit/src/kernel/matmul/tune/key.rs
index 5e830d2a79..d25cce3023 100644
--- a/crates/burn-jit/src/kernel/matmul/tune/key.rs
+++ b/crates/burn-jit/src/kernel/matmul/tune/key.rs
@@ -1,42 +1,28 @@
-use crate::tune::anchor;
+use crate::{tensor::JitTensor, FloatElement, JitAutotuneKey, JitRuntime};
 use burn_tensor::{DType, Shape};
 use core::fmt::Debug;
+use cubecl::AutotuneKey;
 use serde::{Deserialize, Serialize};
-use std::{cmp::max, fmt::Display, hash::Hash};
+use std::{cmp::max, hash::Hash};
 
-#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
+#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
 /// Autotune key representative of matmul versions
 pub struct MatmulAutotuneKey {
     round: bool,     // True when all matmul dims are multiples of 64
     broadcast: bool, // True when there are differences in batch size
-    anchored_m: usize,
-    anchored_k: usize,
-    anchored_n: usize,
-    anchored_batch: usize,
+    #[autotune(anchor)]
+    m: usize,
+    #[autotune(anchor)]
+    k: usize,
+    #[autotune(anchor)]
+    n: usize,
+    #[autotune(anchor(max = 256))]
+    batch: usize,
     dtype: DType,
 }
 
-impl Display for MatmulAutotuneKey {
-    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        f.write_str(
-            format!(
-                "Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?} dtype:{:?}",
-                self.round,
-                self.broadcast,
-                self.anchored_m,
-                self.anchored_k,
-                self.anchored_n,
-                self.anchored_batch,
-                self.dtype
-            )
-            .as_str(),
-        )
-    }
-}
-
 impl MatmulAutotuneKey {
-    /// Create a matmul autotune key from the input shapes
-    pub fn new(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self {
+    fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self {
         let ndims = lhs_shape.num_dims();
         let m = lhs_shape.dims[ndims - 2];
         let k = lhs_shape.dims[ndims - 1];
@@ -57,18 +43,22 @@ impl MatmulAutotuneKey {
 
         let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0;
 
-        Self {
-            round,
-            broadcast,
-            anchored_m: anchor(m, None),
-            anchored_k: anchor(k, None),
-            anchored_n: anchor(n, None),
-            anchored_batch: anchor(batch_product, Some(256)),
-            dtype,
-        }
+        Self::new(round, broadcast, m, k, n, batch_product, dtype)
     }
 }
 
+pub(crate) fn create_key<R: JitRuntime, E: FloatElement>(
+    lhs: &JitTensor<R>,
+    rhs: &JitTensor<R>,
+    _out: &JitTensor<R>,
+) -> JitAutotuneKey {
+    JitAutotuneKey::Matmul(MatmulAutotuneKey::from_shape(
+        &lhs.shape,
+        &rhs.shape,
+        E::dtype(),
+    ))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -77,35 +67,35 @@
     fn matmul_autotune_key_all_same_and_round() {
         let lhs_shape: Shape = [4, 512, 512].into();
         let rhs_shape: Shape = [4, 512, 512].into();
-        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);
+        let key = MatmulAutotuneKey::from_shape(&lhs_shape, &rhs_shape, DType::F32);
 
         assert!(key.round);
         assert!(!key.broadcast);
-        assert!(key.anchored_m == 512);
-        assert!(key.anchored_k == 512);
-        assert!(key.anchored_n == 512);
+        assert_eq!(key.m, 512);
+        assert_eq!(key.k, 512);
+        assert_eq!(key.n, 512);
     }
 
     #[test]
     fn matmul_autotune_key_all_different() {
         let lhs_shape: Shape = [2, 3, 511, 512].into();
         let rhs_shape: Shape = [3, 2, 512, 513].into();
-        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);
+        let key = MatmulAutotuneKey::from_shape(&lhs_shape, &rhs_shape, DType::F32);
 
         assert!(!key.round);
        assert!(key.broadcast);
-        assert!(key.anchored_m == 512);
-        assert!(key.anchored_k == 512);
-        assert!(key.anchored_n == 1024);
-        assert!(key.anchored_batch == 8);
+        assert_eq!(key.m, 512);
+        assert_eq!(key.k, 512);
+        assert_eq!(key.n, 1024);
+        assert_eq!(key.batch, 8);
    }
 
     #[test]
     fn matmul_autotune_key_large_batch() {
         let lhs_shape: Shape = [128, 512, 511, 512].into();
         let rhs_shape: Shape = [200, 400, 512, 513].into();
-        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);
+        let key = MatmulAutotuneKey::from_shape(&lhs_shape, &rhs_shape, DType::F32);
 
-        assert!(key.anchored_batch == 256);
+        assert_eq!(key.batch, 256);
     }
 }
diff --git a/crates/burn-jit/src/kernel/matmul/tune/mod.rs b/crates/burn-jit/src/kernel/matmul/tune/mod.rs
index 1620051a40..c1f4943379 100644
--- a/crates/burn-jit/src/kernel/matmul/tune/mod.rs
+++ b/crates/burn-jit/src/kernel/matmul/tune/mod.rs
@@ -3,5 +3,5 @@ mod base;
 mod key;
 
 #[cfg(feature = "autotune")]
-pub use base::*;
+pub use base::matmul_autotune;
 pub use key::*;
diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs
index ba953ae0d0..acf69d9aec 100644
--- a/crates/burn-jit/src/lib.rs
+++ b/crates/burn-jit/src/lib.rs
@@ -14,8 +14,6 @@ pub mod kernel;
 /// Tensor module.
 pub mod tensor;
 
-pub(crate) mod tune;
-
 /// Elements for JIT backend
 pub mod element;
 
diff --git a/crates/burn-jit/src/tune.rs b/crates/burn-jit/src/tune.rs
deleted file mode 100644
index b289f0cc9c..0000000000
--- a/crates/burn-jit/src/tune.rs
+++ /dev/null
@@ -1,16 +0,0 @@
-//! Module with tune utilities.
-
-use std::cmp::min;
-
-/// Anchor a number to a power of 2.
-///
-/// Useful when creating autotune keys.
-pub fn anchor(x: usize, max: Option<usize>) -> usize {
-    let exp = f32::ceil(f32::log2(x as f32)) as u32;
-    let power_of_2 = 2_u32.pow(exp) as usize;
-    if let Some(max) = max {
-        min(power_of_2, max)
-    } else {
-        power_of_2
-    }
-}