Skip to content

Commit

Permalink
Feat/fusing (#463)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Feb 3, 2025
1 parent ff94be8 commit 276b8db
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 5 deletions.
14 changes: 10 additions & 4 deletions crates/cubecl-core/src/frontend/container/registry/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl From<ExpandElementTyped<u32>> for u32 {
}
}

impl<K: PartialOrd + Ord, V: CubeType + Clone> Registry<K, V> {
impl<K: PartialOrd + Ord + core::fmt::Debug, V: CubeType + Clone> Registry<K, V> {
/// Create a new registry.
pub fn new() -> Self {
Self::default()
Expand All @@ -51,7 +51,10 @@ impl<K: PartialOrd + Ord, V: CubeType + Clone> Registry<K, V> {
let key = query.into();
let map = self.map.as_ref().borrow();

map.get(&key).unwrap().clone()
match map.get(&key) {
Some(val) => val.clone(),
None => panic!("No value found for key {key:?}"),
}
}

/// Insert an item in the registry.
Expand Down Expand Up @@ -88,12 +91,15 @@ impl<K: PartialOrd + Ord, V: CubeType + Clone> Registry<K, V> {
}
}

impl<K: PartialOrd + Ord, V: Clone> Registry<K, V> {
impl<K: PartialOrd + Ord + core::fmt::Debug, V: Clone> Registry<K, V> {
/// Expand method of [Self::find].
pub fn __expand_find_method(&self, _scope: &mut Scope, key: K) -> V {
let map = self.map.as_ref().borrow();

map.get(&key).unwrap().clone()
match map.get(&key) {
Some(val) => val.clone(),
None => panic!("No value found for key {key:?}"),
}
}

/// Expand method of [Self::insert].
Expand Down
14 changes: 13 additions & 1 deletion crates/cubecl-core/src/frontend/container/sequence/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::{cell::RefCell, rc::Rc};
/// All methods [push](Sequence::push), [index](Sequence::index) and
/// [into_iter](Sequence::into_iter) are executed _during_ compilation and don't add any overhead
/// on the generated kernel.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct Sequence<T: CubeType> {
values: Vec<T>,
}
Expand All @@ -30,6 +30,14 @@ impl<T: CubeType> Init for Sequence<T> {
}
}

impl<T: CubeType + Clone> Sequence<T> {
pub fn rev(&self) -> Self {
Self {
values: self.values.iter().rev().cloned().collect(),
}
}
}

impl<T: CubeType> Sequence<T> {
/// Create a new empty sequence.
pub fn new() -> Self {
Expand Down Expand Up @@ -170,6 +178,10 @@ impl<T: CubeType> CubeType for Sequence<T> {
}

impl<T: CubeType> SequenceExpand<T> {
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> u32 {
self.values.borrow().len() as u32
}
/// Expand method of [push](Sequence::push).
pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
self.values.borrow_mut().push(value);
Expand Down
6 changes: 6 additions & 0 deletions crates/cubecl-core/src/frontend/element/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ macro_rules! declare_uint {
}
}

impl Init for $primitive {
fn init(self, _scope: &mut Scope) -> Self {
self
}
}

impl IntoRuntime for $primitive {
fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
let expand: ExpandElementTyped<Self> = self.into();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ impl<MP: MatmulPrecision, GMM: GlobalMatmul<MP>, C: CubeDispatch> BatchMatmul<MP
let k_range = (0, lhs.shape(rank - 1));

let gmm_config = config.to_gmm_config();

gmm_execute::<MP, GMM>(
lhs,
rhs,
Expand Down

0 comments on commit 276b8db

Please sign in to comment.