diff --git a/crates/cubecl-core/src/frontend/container/registry/base.rs b/crates/cubecl-core/src/frontend/container/registry/base.rs index 63a798a79..889b9ae4f 100644 --- a/crates/cubecl-core/src/frontend/container/registry/base.rs +++ b/crates/cubecl-core/src/frontend/container/registry/base.rs @@ -29,7 +29,7 @@ impl From> for u32 { } } -impl Registry { +impl Registry { /// Create a new registry. pub fn new() -> Self { Self::default() @@ -51,7 +51,10 @@ impl Registry { let key = query.into(); let map = self.map.as_ref().borrow(); - map.get(&key).unwrap().clone() + match map.get(&key) { + Some(val) => val.clone(), + None => panic!("No value found for key {key:?}"), + } } /// Insert an item in the registry. @@ -88,12 +91,15 @@ impl Registry { } } -impl Registry { +impl Registry { /// Expand method of [Self::find]. pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V { let map = self.map.as_ref().borrow(); - map.get(&key).unwrap().clone() + match map.get(&key) { + Some(val) => val.clone(), + None => panic!("No value found for key {key:?}"), + } } /// Expand method of [Self::insert]. diff --git a/crates/cubecl-core/src/frontend/container/sequence/base.rs b/crates/cubecl-core/src/frontend/container/sequence/base.rs index fa41549e1..06fbf147b 100644 --- a/crates/cubecl-core/src/frontend/container/sequence/base.rs +++ b/crates/cubecl-core/src/frontend/container/sequence/base.rs @@ -13,7 +13,7 @@ use std::{cell::RefCell, rc::Rc}; /// All methods [push](Sequence::push), [index](Sequence::index) and /// [into_iter](Sequence::into_iter) are executed _during_ compilation and don't add any overhead /// on the generated kernel. -#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub struct Sequence { values: Vec, } @@ -30,6 +30,14 @@ impl Init for Sequence { } } +impl Sequence { + pub fn rev(&self) -> Self { + Self { + values: self.values.iter().rev().cloned().collect(), + } + } +} + impl Sequence { /// Create a new empty sequence. pub fn new() -> Self { @@ -170,6 +178,10 @@ impl CubeType for Sequence { } impl SequenceExpand { + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> u32 { + self.values.borrow().len() as u32 + } /// Expand method of [push](Sequence::push). pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) { self.values.borrow_mut().push(value); diff --git a/crates/cubecl-core/src/frontend/element/uint.rs b/crates/cubecl-core/src/frontend/element/uint.rs index 8243025ec..e716cd3d2 100644 --- a/crates/cubecl-core/src/frontend/element/uint.rs +++ b/crates/cubecl-core/src/frontend/element/uint.rs @@ -28,6 +28,12 @@ macro_rules! declare_uint { } } + impl Init for $primitive { + fn init(self, _scope: &mut Scope) -> Self { + self + } + } + impl IntoRuntime for $primitive { fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped { let expand: ExpandElementTyped = self.into(); diff --git a/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs b/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs index e3ba06ae6..221d7f62d 100644 --- a/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs +++ b/crates/cubecl-linalg/src/matmul/components/batch/one_to_one.rs @@ -108,6 +108,7 @@ impl, C: CubeDispatch> BatchMatmul( lhs, rhs,