From 8a1da89de7d8ff801523683887e647bf617d7ff0 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sat, 25 Jan 2025 10:58:51 -0500 Subject: [PATCH 1/5] Fusion WIP --- .../src/frontend/container/registry/base.rs | 14 ++++++++++---- .../src/frontend/container/sequence/base.rs | 14 +++++++++++++- crates/cubecl-core/src/frontend/element/uint.rs | 6 ++++++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/crates/cubecl-core/src/frontend/container/registry/base.rs b/crates/cubecl-core/src/frontend/container/registry/base.rs index b7d3b125b..7b63d1270 100644 --- a/crates/cubecl-core/src/frontend/container/registry/base.rs +++ b/crates/cubecl-core/src/frontend/container/registry/base.rs @@ -27,7 +27,7 @@ impl From> for u32 { } } -impl Registry { +impl Registry { /// Create a new registry. pub fn new() -> Self { Self::default() @@ -49,7 +49,10 @@ impl Registry { let key = query.into(); let map = self.map.as_ref().borrow(); - map.get(&key).unwrap().clone() + match map.get(&key) { + Some(val) => val.clone(), + None => panic!("No value found for key {key:?}"), + } } /// Insert an item in the registry. @@ -86,12 +89,15 @@ impl Registry { } } -impl Registry { +impl Registry { /// Expand method of [Self::find]. pub fn __expand_find_method(&self, _context: &mut CubeContext, key: K) -> V { let map = self.map.as_ref().borrow(); - map.get(&key).unwrap().clone() + match map.get(&key) { + Some(val) => val.clone(), + None => panic!("No value found for key {key:?}"), + } } /// Expand method of [Self::insert]. diff --git a/crates/cubecl-core/src/frontend/container/sequence/base.rs b/crates/cubecl-core/src/frontend/container/sequence/base.rs index c81832c5d..e43c011ee 100644 --- a/crates/cubecl-core/src/frontend/container/sequence/base.rs +++ b/crates/cubecl-core/src/frontend/container/sequence/base.rs @@ -14,7 +14,7 @@ use std::{cell::RefCell, rc::Rc}; /// All methods [push](Sequence::push), [index](Sequence::index) and /// [into_iter](Sequence::into_iter) are executed _during_ compilation and don't add any overhead /// on the generated kernel. -#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub struct Sequence { values: Vec, } @@ -31,6 +31,14 @@ impl Init for Sequence { } } +impl Sequence { + pub fn rev(&self) -> Self { + Self { + values: self.values.iter().rev().map(|a| a.clone()).collect(), + } + } +} + impl Sequence { /// Create a new empty sequence. pub fn new() -> Self { @@ -179,6 +187,10 @@ impl CubeType for Sequence { } impl SequenceExpand { + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> u32 { + self.values.borrow().len() as u32 + } /// Expand method of [push](Sequence::push). pub fn __expand_push_method(&mut self, _context: &mut CubeContext, value: T::ExpandType) { self.values.borrow_mut().push(value); diff --git a/crates/cubecl-core/src/frontend/element/uint.rs b/crates/cubecl-core/src/frontend/element/uint.rs index 9ce9f7e07..3d99fb19c 100644 --- a/crates/cubecl-core/src/frontend/element/uint.rs +++ b/crates/cubecl-core/src/frontend/element/uint.rs @@ -31,6 +31,12 @@ macro_rules! declare_uint { } } + impl Init for $primitive { + fn init(self, _context: &mut CubeContext) -> Self { + self + } + } + impl IntoRuntime for $primitive { fn __expand_runtime_method( self, From 54b69f2d947dc6a04594196ada37c5547d55abe6 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 28 Jan 2025 15:51:00 -0500 Subject: [PATCH 2/5] Better name --- crates/cubecl-linalg/src/tensor/contiguous.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/cubecl-linalg/src/tensor/contiguous.rs b/crates/cubecl-linalg/src/tensor/contiguous.rs index 264799799..5a6183f2e 100644 --- a/crates/cubecl-linalg/src/tensor/contiguous.rs +++ b/crates/cubecl-linalg/src/tensor/contiguous.rs @@ -17,8 +17,8 @@ pub fn index_offset_with_layout( #[unroll(unroll)] for i in dim_start..dim_end { - let ogwl = offset_ref / layout.stride(i); - offset += ogwl % tensor.shape(i) * tensor.stride(i); + let coordinate_broadcasted = (offset_ref / layout.stride(i)) % tensor.shape(i); + offset += coordinate_broadcasted * tensor.stride(i); } offset / tensor.line_size() From 1094d3d03a0a7cf10250a5b6b0b74ab5e8d4202b Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 09:37:11 -0500 Subject: [PATCH 3/5] WiP --- .../cubecl-linalg/src/matmul/components/batch/one_to_one.rs | 1 + crates/cubecl-linalg/src/tensor/contiguous.rs | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs b/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs index e3ba06ae6..221d7f62d 100644 --- a/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs +++ b/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs @@ -108,6 +108,7 @@ impl, C: CubeDispatch> BatchMatmul( lhs, rhs, diff --git a/crates/cubecl-linalg/src/tensor/contiguous.rs b/crates/cubecl-linalg/src/tensor/contiguous.rs index 5a6183f2e..264799799 100644 --- a/crates/cubecl-linalg/src/tensor/contiguous.rs +++ b/crates/cubecl-linalg/src/tensor/contiguous.rs @@ -17,8 +17,8 @@ pub fn index_offset_with_layout( #[unroll(unroll)] for i in dim_start..dim_end { - let coordinate_broadcasted = (offset_ref / layout.stride(i)) % tensor.shape(i); - offset += coordinate_broadcasted * tensor.stride(i); + let ogwl = offset_ref / layout.stride(i); + offset += ogwl % tensor.shape(i) * tensor.stride(i); } offset / tensor.line_size() From 95e526fd71bf8401d112773ddec7f047ec68f50f Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 09:42:58 -0500 Subject: [PATCH 4/5] Fix macro --- crates/cubecl-core/src/frontend/element/uint.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/cubecl-core/src/frontend/element/uint.rs b/crates/cubecl-core/src/frontend/element/uint.rs index 872263bda..e716cd3d2 100644 --- a/crates/cubecl-core/src/frontend/element/uint.rs +++ b/crates/cubecl-core/src/frontend/element/uint.rs @@ -29,7 +29,7 @@ macro_rules! declare_uint { } impl Init for $primitive { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _scope: &mut Scope) -> Self { self } } From 6c2507fea511a4b123528d23f98f625042b894a7 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 11:33:03 -0500 Subject: [PATCH 5/5] Clippy --- crates/cubecl-core/src/frontend/container/sequence/base.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/cubecl-core/src/frontend/container/sequence/base.rs b/crates/cubecl-core/src/frontend/container/sequence/base.rs index c87f5784b..06fbf147b 100644 --- a/crates/cubecl-core/src/frontend/container/sequence/base.rs +++ b/crates/cubecl-core/src/frontend/container/sequence/base.rs @@ -33,7 +33,7 @@ impl Init for Sequence { impl Sequence { pub fn rev(&self) -> Self { Self { - values: self.values.iter().rev().map(|a| a.clone()).collect(), + values: self.values.iter().rev().cloned().collect(), } } }