diff --git a/crates/cubecl-common/src/benchmark.rs b/crates/cubecl-common/src/benchmark.rs index f15ba285d..90aa64cd5 100644 --- a/crates/cubecl-common/src/benchmark.rs +++ b/crates/cubecl-common/src/benchmark.rs @@ -22,8 +22,8 @@ pub enum TimingMethod { impl Display for TimingMethod { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { - TimingMethod::Full => f.write_str("full"), - TimingMethod::DeviceOnly => f.write_str("device_only"), + Self::Full => f.write_str("full"), + Self::DeviceOnly => f.write_str("device_only"), } } } diff --git a/crates/cubecl-core/src/codegen/integrator.rs b/crates/cubecl-core/src/codegen/integrator.rs index add23f624..f90979979 100644 --- a/crates/cubecl-core/src/codegen/integrator.rs +++ b/crates/cubecl-core/src/codegen/integrator.rs @@ -220,11 +220,8 @@ impl InputInfo { #[allow(dead_code)] pub fn item(&self) -> Item { match self { - InputInfo::Array { - item, - visibility: _, - } => *item, - InputInfo::Scalar { elem, size: _ } => Item::new(*elem), + Self::Array { item, .. } => *item, + Self::Scalar { elem, .. } => Item::new(*elem), } } } @@ -234,18 +231,9 @@ impl OutputInfo { #[allow(dead_code)] pub fn item(&self) -> Item { match self { - OutputInfo::ArrayWrite { - item, - local: _, - position: _, - } => *item, - OutputInfo::InputArrayWrite { - item, - input: _, - local: _, - position: _, - } => *item, - OutputInfo::Array { item } => *item, + Self::ArrayWrite { item, .. } + | Self::InputArrayWrite { item, .. } + | Self::Array { item } => *item, } } } @@ -278,18 +266,9 @@ impl OutputInfo { #[allow(dead_code)] pub fn elem_size(&self) -> usize { let elem = match self { - OutputInfo::ArrayWrite { - item, - local: _, - position: _, - } => bool_elem(item.elem()), - OutputInfo::InputArrayWrite { - item, - input: _, - local: _, - position: _, - } => bool_elem(item.elem()), - OutputInfo::Array { item } => bool_elem(item.elem()), + Self::ArrayWrite { item, .. } + | Self::InputArrayWrite { item, .. } + | Self::Array { item } => bool_elem(item.elem()), }; ::elem_size(elem) } @@ -464,10 +443,7 @@ impl KernelIntegrator { let (item, local, position) = match output { OutputInfo::ArrayWrite { item, local, position } => (item, local, position), OutputInfo::InputArrayWrite { - item: _, - input, - local: _, - position: _, + input, .. } => { assert_eq!( *input, mapping.pos_input as u16, @@ -475,7 +451,7 @@ impl KernelIntegrator { ); return; } - OutputInfo::Array { item: _ } => panic!("Can't register an inplace operation for an array that isn't using a defined writing strategy."), + OutputInfo::Array { .. } => panic!("Can't register an inplace operation for an array that isn't using a defined writing strategy."), }; let item = match self.input_bindings.get_mut(mapping.pos_input) { diff --git a/crates/cubecl-core/src/compute/builder.rs b/crates/cubecl-core/src/compute/builder.rs index ed69d4f6e..a7f365215 100644 --- a/crates/cubecl-core/src/compute/builder.rs +++ b/crates/cubecl-core/src/compute/builder.rs @@ -23,7 +23,7 @@ impl KernelBuilder { pub fn scalar(&mut self, elem: Elem) -> ExpandElement { let index = match self.indices.get_mut(&elem) { Some(index) => match self.inputs.get_mut(*index).unwrap() { - InputInfo::Scalar { elem: _, size } => { + InputInfo::Scalar { size, .. } => { *size += 1; *size as u16 - 1 } diff --git a/crates/cubecl-core/src/compute/launcher.rs b/crates/cubecl-core/src/compute/launcher.rs index 6d5783a08..75ae69ee0 100644 --- a/crates/cubecl-core/src/compute/launcher.rs +++ b/crates/cubecl-core/src/compute/launcher.rs @@ -179,8 +179,8 @@ pub enum ScalarState { impl TensorState { /// Push a new tensor to the state. pub fn push(&mut self, tensor: &TensorHandleRef<'_, R>) { - if let TensorState::Empty = self { - *self = TensorState::Some { + if let Self::Empty = self { + *self = Self::Some { bindings: Vec::with_capacity(1), metadata: Vec::new(), lengths: Vec::new(), @@ -189,12 +189,12 @@ impl TensorState { }; let (bindings, metadata, lengths) = match self { - TensorState::Empty => panic!("Should be init"), - TensorState::Some { + Self::Empty => panic!("Should be init"), + Self::Some { bindings, metadata, lengths, - runtime: _, + .. } => (bindings, metadata, lengths), }; @@ -277,7 +277,7 @@ impl TensorState { bindings, mut metadata, lengths, - runtime: _, + .. } = self { if R::require_array_lengths() { @@ -296,8 +296,8 @@ impl ScalarState { /// Add a new scalar value to the state. pub fn push(&mut self, val: T) { match self { - ScalarState::Empty => *self = Self::Some(vec![val]), - ScalarState::Some(values) => values.push(val), + Self::Empty => *self = Self::Some(vec![val]), + Self::Some(values) => values.push(val), } } @@ -307,8 +307,8 @@ impl ScalarState { bindings: &mut Vec, ) { match self { - ScalarState::Empty => (), - ScalarState::Some(values) => { + Self::Empty => (), + Self::Some(values) => { let handle = client.create(bytemuck::cast_slice(values)); bindings.push(handle.binding()); } diff --git a/crates/cubecl-core/src/frontend/branch.rs b/crates/cubecl-core/src/frontend/branch.rs index 3480e4206..e5561e7ea 100644 --- a/crates/cubecl-core/src/frontend/branch.rs +++ b/crates/cubecl-core/src/frontend/branch.rs @@ -42,7 +42,7 @@ pub struct RangeExpand { impl RangeExpand { pub fn new(start: ExpandElementTyped, end: ExpandElementTyped, inclusive: bool) -> Self { - RangeExpand { + Self { start, end, inclusive, diff --git a/crates/cubecl-core/src/frontend/container/array/base.rs b/crates/cubecl-core/src/frontend/container/array/base.rs index 493a526e2..e4f0e2c2c 100644 --- a/crates/cubecl-core/src/frontend/container/array/base.rs +++ b/crates/cubecl-core/src/frontend/container/array/base.rs @@ -28,12 +28,12 @@ mod new { /// Create a new array of the given length. #[allow(unused_variables)] pub fn new(length: L) -> Self { - Array { _val: PhantomData } + Self { _val: PhantomData } } /// Create an array from data. pub fn from_data(_data: impl IntoIterator) -> Self { - Array { _val: PhantomData } + Self { _val: PhantomData } } /// Expand function of [new](Array::new). diff --git a/crates/cubecl-core/src/frontend/container/array/launch.rs b/crates/cubecl-core/src/frontend/container/array/launch.rs index 862801886..bcf3fdb40 100644 --- a/crates/cubecl-core/src/frontend/container/array/launch.rs +++ b/crates/cubecl-core/src/frontend/container/array/launch.rs @@ -65,11 +65,7 @@ pub enum ArrayArg<'a, R: Runtime> { impl<'a, R: Runtime> ArgSettings for ArrayArg<'a, R> { fn register(&self, launcher: &mut KernelLauncher) { - if let ArrayArg::Handle { - handle, - vectorization_factor: _, - } = self - { + if let Self::Handle { handle, .. } = self { launcher.register_array(handle) } } @@ -129,8 +125,8 @@ impl LaunchArg for Array { fn compilation_arg(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg { match runtime_arg { ArrayArg::Handle { - handle: _, vectorization_factor, + .. } => ArrayCompilationArg { inplace: None, vectorisation: Vectorization::Some(NonZero::new(*vectorization_factor).unwrap()), diff --git a/crates/cubecl-core/src/frontend/container/shared_memory.rs b/crates/cubecl-core/src/frontend/container/shared_memory.rs index b172b07c0..632a073a7 100644 --- a/crates/cubecl-core/src/frontend/container/shared_memory.rs +++ b/crates/cubecl-core/src/frontend/container/shared_memory.rs @@ -32,7 +32,7 @@ impl CubeType for SharedMemory { impl SharedMemory { pub fn new(_size: S) -> Self { - SharedMemory { _val: PhantomData } + Self { _val: PhantomData } } pub fn new_lined(_size: S, _vectorization_factor: u32) -> SharedMemory> { diff --git a/crates/cubecl-core/src/frontend/container/tensor/launch.rs b/crates/cubecl-core/src/frontend/container/tensor/launch.rs index 5b73f6038..ee4561b63 100644 --- a/crates/cubecl-core/src/frontend/container/tensor/launch.rs +++ b/crates/cubecl-core/src/frontend/container/tensor/launch.rs @@ -82,8 +82,8 @@ impl LaunchArg for Tensor { fn compilation_arg(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg { match runtime_arg { TensorArg::Handle { - handle: _, vectorization_factor, + .. } => TensorCompilationArg { inplace: None, vectorisation: Vectorization::Some(NonZero::new(*vectorization_factor).unwrap()), @@ -127,11 +127,7 @@ impl<'a, R: Runtime> TensorArg<'a, R> { impl<'a, R: Runtime> ArgSettings for TensorArg<'a, R> { fn register(&self, launcher: &mut KernelLauncher) { - if let TensorArg::Handle { - handle, - vectorization_factor: _, - } = self - { + if let Self::Handle { handle, .. } = self { launcher.register_tensor(handle) } } diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index 90ca99427..c5c5cb74c 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -321,14 +321,14 @@ impl ExpandElement { /// If the element can be mutated inplace, potentially reusing the register. pub fn can_mut(&self) -> bool { match self { - ExpandElement::Managed(var) => { + Self::Managed(var) => { if let Variable::Local { .. } = var.as_ref() { Rc::strong_count(var) <= 2 } else { false } } - ExpandElement::Plain(_) => false, + Self::Plain(_) => false, } } @@ -343,8 +343,8 @@ impl core::ops::Deref for ExpandElement { fn deref(&self) -> &Self::Target { match self { - ExpandElement::Managed(var) => var.as_ref(), - ExpandElement::Plain(var) => var, + Self::Managed(var) => var.as_ref(), + Self::Plain(var) => var, } } } diff --git a/crates/cubecl-core/src/frontend/element/vectorized.rs b/crates/cubecl-core/src/frontend/element/vectorized.rs index 545820a9e..608f0c508 100644 --- a/crates/cubecl-core/src/frontend/element/vectorized.rs +++ b/crates/cubecl-core/src/frontend/element/vectorized.rs @@ -59,8 +59,8 @@ impl Vectorized for &mut Tensor { impl Vectorized for ExpandElement { fn vectorization_factor(&self) -> u32 { let var = match self { - ExpandElement::Managed(var) => var, - ExpandElement::Plain(var) => var, + Self::Managed(var) => var, + Self::Plain(var) => var, }; var.item().vectorization.map(|it| it.get()).unwrap_or(1) as u32 diff --git a/crates/cubecl-core/src/ir/branch.rs b/crates/cubecl-core/src/ir/branch.rs index 7242befb2..12732b586 100644 --- a/crates/cubecl-core/src/ir/branch.rs +++ b/crates/cubecl-core/src/ir/branch.rs @@ -28,14 +28,14 @@ pub enum Branch { impl Display for Branch { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Branch::If(if_) => write!(f, "if({})", if_.cond), - Branch::IfElse(if_else) => write!(f, "if({})", if_else.cond), - Branch::Select(select) => write!( + Self::If(if_) => write!(f, "if({})", if_.cond), + Self::IfElse(if_else) => write!(f, "if({})", if_else.cond), + Self::Select(select) => write!( f, "{} = select({}, {}, {})", select.out, select.cond, select.then, select.or_else ), - Branch::Switch(switch) => write!( + Self::Switch(switch) => write!( f, "switch({}) {:?}", switch.value, @@ -45,7 +45,7 @@ impl Display for Branch { .map(|case| format!("{}", case.0)) .collect::>(), ), - Branch::RangeLoop(range_loop) => write!( + Self::RangeLoop(range_loop) => write!( f, "for({} in {}{}{})", range_loop.i, @@ -53,9 +53,9 @@ impl Display for Branch { if range_loop.inclusive { "..=" } else { ".." }, range_loop.end ), - Branch::Loop(_) => write!(f, "loop{{}}"), - Branch::Return => write!(f, "return"), - Branch::Break => write!(f, "break"), + Self::Loop(_) => write!(f, "loop{{}}"), + Self::Return => write!(f, "return"), + Self::Break => write!(f, "break"), } } } diff --git a/crates/cubecl-core/src/ir/cmma.rs b/crates/cubecl-core/src/ir/cmma.rs index d856d4be3..d4bc24de3 100644 --- a/crates/cubecl-core/src/ir/cmma.rs +++ b/crates/cubecl-core/src/ir/cmma.rs @@ -65,8 +65,8 @@ pub enum CoopMma { impl Display for CoopMma { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - CoopMma::Fill { mat, value } => write!(f, "matrix_fill({}, {})", mat, value), - CoopMma::Load { + Self::Fill { mat, value } => write!(f, "matrix_fill({}, {})", mat, value), + Self::Load { mat, value, stride, @@ -81,7 +81,7 @@ impl Display for CoopMma { mat, value, stride ) } - CoopMma::Execute { + Self::Execute { mat_a, mat_b, mat_c, @@ -91,7 +91,7 @@ impl Display for CoopMma { "{} = execute_cmma({}, {}, {})", mat_d, mat_a, mat_b, mat_c ), - CoopMma::Store { + Self::Store { output, mat, stride, diff --git a/crates/cubecl-core/src/ir/kernel.rs b/crates/cubecl-core/src/ir/kernel.rs index e995c915c..0baa1ca7d 100644 --- a/crates/cubecl-core/src/ir/kernel.rs +++ b/crates/cubecl-core/src/ir/kernel.rs @@ -61,12 +61,12 @@ impl Elem { /// The output will have the same type as the element. pub fn constant_from_f64(&self, val: f64) -> Variable { Variable::ConstantScalar(match self { - Elem::Float(kind) => ConstantScalarValue::Float(val, *kind), - Elem::Int(kind) => ConstantScalarValue::Int(val as i64, *kind), - Elem::UInt => ConstantScalarValue::UInt(val as u64), - Elem::Bool => ConstantScalarValue::Bool(val > 0.0), - Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind), - Elem::AtomicUInt => ConstantScalarValue::UInt(val as u64), + Self::Float(kind) => ConstantScalarValue::Float(val, *kind), + Self::Int(kind) => ConstantScalarValue::Int(val as i64, *kind), + Self::UInt => ConstantScalarValue::UInt(val as u64), + Self::Bool => ConstantScalarValue::Bool(val > 0.0), + Self::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind), + Self::AtomicUInt => ConstantScalarValue::UInt(val as u64), }) } /// Create a constant scalar from a signed integer. @@ -74,12 +74,12 @@ impl Elem { /// The output will have the same type as the element. pub fn constant_from_i64(&self, val: i64) -> Variable { Variable::ConstantScalar(match self { - Elem::Float(kind) => ConstantScalarValue::Float(val as f64, *kind), - Elem::Int(kind) => ConstantScalarValue::Int(val, *kind), - Elem::UInt => ConstantScalarValue::UInt(val as u64), - Elem::Bool => ConstantScalarValue::Bool(val > 0), - Elem::AtomicInt(kind) => ConstantScalarValue::Int(val, *kind), - Elem::AtomicUInt => ConstantScalarValue::UInt(val as u64), + Self::Float(kind) => ConstantScalarValue::Float(val as f64, *kind), + Self::Int(kind) => ConstantScalarValue::Int(val, *kind), + Self::UInt => ConstantScalarValue::UInt(val as u64), + Self::Bool => ConstantScalarValue::Bool(val > 0), + Self::AtomicInt(kind) => ConstantScalarValue::Int(val, *kind), + Self::AtomicUInt => ConstantScalarValue::UInt(val as u64), }) } /// Create a constant scalar from a unsigned integer. @@ -87,12 +87,12 @@ impl Elem { /// The output will have the same type as the element. pub fn constant_from_u64(&self, val: u64) -> Variable { Variable::ConstantScalar(match self { - Elem::Float(kind) => ConstantScalarValue::Float(val as f64, *kind), - Elem::Int(kind) => ConstantScalarValue::Int(val as i64, *kind), - Elem::UInt => ConstantScalarValue::UInt(val), - Elem::Bool => ConstantScalarValue::Bool(val > 0), - Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind), - Elem::AtomicUInt => ConstantScalarValue::UInt(val), + Self::Float(kind) => ConstantScalarValue::Float(val as f64, *kind), + Self::Int(kind) => ConstantScalarValue::Int(val as i64, *kind), + Self::UInt => ConstantScalarValue::UInt(val), + Self::Bool => ConstantScalarValue::Bool(val > 0), + Self::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind), + Self::AtomicUInt => ConstantScalarValue::UInt(val), }) } /// Create a constant scalar from a boolean. @@ -100,12 +100,12 @@ impl Elem { /// The output will have the same type as the element. pub fn constant_from_bool(&self, val: bool) -> Variable { Variable::ConstantScalar(match self { - Elem::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind), - Elem::Int(kind) => ConstantScalarValue::Int(val as i64, *kind), - Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind), - Elem::UInt => ConstantScalarValue::UInt(val as u64), - Elem::AtomicUInt => ConstantScalarValue::UInt(val as u64), - Elem::Bool => ConstantScalarValue::Bool(val), + Self::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind), + Self::Int(kind) => ConstantScalarValue::Int(val as i64, *kind), + Self::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind), + Self::UInt => ConstantScalarValue::UInt(val as u64), + Self::AtomicUInt => ConstantScalarValue::UInt(val as u64), + Self::Bool => ConstantScalarValue::Bool(val), }) } @@ -126,34 +126,34 @@ impl Elem { /// Get the size in bytes. pub fn size(&self) -> usize { match self { - Elem::Float(kind) => match kind { + Self::Float(kind) => match kind { FloatKind::F16 => core::mem::size_of::(), FloatKind::BF16 => core::mem::size_of::(), FloatKind::F32 => core::mem::size_of::(), FloatKind::F64 => core::mem::size_of::(), }, - Elem::Int(kind) => match kind { + Self::Int(kind) => match kind { IntKind::I32 => core::mem::size_of::(), IntKind::I64 => core::mem::size_of::(), }, - Elem::AtomicInt(kind) => match kind { + Self::AtomicInt(kind) => match kind { IntKind::I32 => core::mem::size_of::(), IntKind::I64 => core::mem::size_of::(), }, - Elem::UInt => core::mem::size_of::(), - Elem::AtomicUInt => core::mem::size_of::(), - Elem::Bool => core::mem::size_of::(), + Self::UInt => core::mem::size_of::(), + Self::AtomicUInt => core::mem::size_of::(), + Self::Bool => core::mem::size_of::(), } } pub fn is_atomic(&self) -> bool { - matches!(self, Elem::AtomicInt(_) | Elem::AtomicUInt) + matches!(self, Self::AtomicInt(_) | Self::AtomicUInt) } pub fn is_int(&self) -> bool { matches!( self, - Elem::Int(_) | Elem::AtomicInt(_) | Elem::UInt | Elem::AtomicUInt + Self::Int(_) | Self::AtomicInt(_) | Self::UInt | Self::AtomicUInt ) } } diff --git a/crates/cubecl-core/src/ir/operation.rs b/crates/cubecl-core/src/ir/operation.rs index ac4450891..bb05228bc 100644 --- a/crates/cubecl-core/src/ir/operation.rs +++ b/crates/cubecl-core/src/ir/operation.rs @@ -25,12 +25,12 @@ pub enum Operation { impl Display for Operation { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Operation::Operator(operator) => write!(f, "{operator}"), - Operation::Metadata(metadata) => write!(f, "{metadata}"), - Operation::Branch(branch) => write!(f, "{branch}"), - Operation::Synchronization(synchronization) => write!(f, "{synchronization}"), - Operation::Subcube(subcube) => write!(f, "{subcube}"), - Operation::CoopMma(coop_mma) => write!(f, "{coop_mma}"), + Self::Operator(operator) => write!(f, "{operator}"), + Self::Metadata(metadata) => write!(f, "{metadata}"), + Self::Branch(branch) => write!(f, "{branch}"), + Self::Synchronization(synchronization) => write!(f, "{synchronization}"), + Self::Subcube(subcube) => write!(f, "{subcube}"), + Self::CoopMma(coop_mma) => write!(f, "{coop_mma}"), } } } @@ -38,13 +38,13 @@ impl Display for Operation { impl Operation { pub fn out(&self) -> Option { match self { - Operation::Operator(operator) => operator.out(), - Operation::Metadata(metadata) => metadata.out(), - Operation::Branch(Branch::Select(op)) => Some(op.out), - Operation::Branch(_) => None, - Operation::Synchronization(_) => None, - Operation::Subcube(subcube) => subcube.out(), - Operation::CoopMma(_) => None, + Self::Operator(operator) => operator.out(), + Self::Metadata(metadata) => metadata.out(), + Self::Branch(Branch::Select(op)) => Some(op.out), + Self::Branch(_) => None, + Self::Synchronization(_) => None, + Self::Subcube(subcube) => subcube.out(), + Self::CoopMma(_) => None, } } } @@ -120,165 +120,165 @@ pub enum Operator { impl Operator { pub fn out(&self) -> Option { - let val = match self { - Operator::Add(binary_operator) - | Operator::Sub(binary_operator) - | Operator::Mul(binary_operator) - | Operator::Div(binary_operator) - | Operator::Powf(binary_operator) - | Operator::Equal(binary_operator) - | Operator::NotEqual(binary_operator) - | Operator::Lower(binary_operator) - | Operator::Greater(binary_operator) - | Operator::LowerEqual(binary_operator) - | Operator::GreaterEqual(binary_operator) - | Operator::Modulo(binary_operator) - | Operator::Index(binary_operator) - | Operator::UncheckedIndex(binary_operator) - | Operator::IndexAssign(binary_operator) - | Operator::UncheckedIndexAssign(binary_operator) - | Operator::Max(binary_operator) - | Operator::Min(binary_operator) - | Operator::BitwiseAnd(binary_operator) - | Operator::BitwiseOr(binary_operator) - | Operator::BitwiseXor(binary_operator) - | Operator::ShiftLeft(binary_operator) - | Operator::ShiftRight(binary_operator) - | Operator::Remainder(binary_operator) - | Operator::And(binary_operator) - | Operator::Or(binary_operator) - | Operator::AtomicSwap(binary_operator) - | Operator::AtomicAdd(binary_operator) - | Operator::AtomicSub(binary_operator) - | Operator::AtomicMax(binary_operator) - | Operator::AtomicMin(binary_operator) - | Operator::AtomicAnd(binary_operator) - | Operator::AtomicOr(binary_operator) - | Operator::AtomicXor(binary_operator) - | Operator::Dot(binary_operator) => binary_operator.out, + match self { + Self::Add(binary_operator) + | Self::Sub(binary_operator) + | Self::Mul(binary_operator) + | Self::Div(binary_operator) + | Self::Powf(binary_operator) + | Self::Equal(binary_operator) + | Self::NotEqual(binary_operator) + | Self::Lower(binary_operator) + | Self::Greater(binary_operator) + | Self::LowerEqual(binary_operator) + | Self::GreaterEqual(binary_operator) + | Self::Modulo(binary_operator) + | Self::Index(binary_operator) + | Self::UncheckedIndex(binary_operator) + | Self::IndexAssign(binary_operator) + | Self::UncheckedIndexAssign(binary_operator) + | Self::Max(binary_operator) + | Self::Min(binary_operator) + | Self::BitwiseAnd(binary_operator) + | Self::BitwiseOr(binary_operator) + | Self::BitwiseXor(binary_operator) + | Self::ShiftLeft(binary_operator) + | Self::ShiftRight(binary_operator) + | Self::Remainder(binary_operator) + | Self::And(binary_operator) + | Self::Or(binary_operator) + | Self::AtomicSwap(binary_operator) + | Self::AtomicAdd(binary_operator) + | Self::AtomicSub(binary_operator) + | Self::AtomicMax(binary_operator) + | Self::AtomicMin(binary_operator) + | Self::AtomicAnd(binary_operator) + | Self::AtomicOr(binary_operator) + | Self::AtomicXor(binary_operator) + | Self::Dot(binary_operator) => binary_operator.out, - Operator::Abs(unary_operator) - | Operator::Exp(unary_operator) - | Operator::Log(unary_operator) - | Operator::Log1p(unary_operator) - | Operator::Cos(unary_operator) - | Operator::Sin(unary_operator) - | Operator::Tanh(unary_operator) - | Operator::Sqrt(unary_operator) - | Operator::Round(unary_operator) - | Operator::Floor(unary_operator) - | Operator::Ceil(unary_operator) - | Operator::Erf(unary_operator) - | Operator::Recip(unary_operator) - | Operator::Assign(unary_operator) - | Operator::Not(unary_operator) - | Operator::Neg(unary_operator) - | Operator::Bitcast(unary_operator) - | Operator::AtomicLoad(unary_operator) - | Operator::AtomicStore(unary_operator) - | Operator::Magnitude(unary_operator) - | Operator::Normalize(unary_operator) => unary_operator.out, + Self::Abs(unary_operator) + | Self::Exp(unary_operator) + | Self::Log(unary_operator) + | Self::Log1p(unary_operator) + | Self::Cos(unary_operator) + | Self::Sin(unary_operator) + | Self::Tanh(unary_operator) + | Self::Sqrt(unary_operator) + | Self::Round(unary_operator) + | Self::Floor(unary_operator) + | Self::Ceil(unary_operator) + | Self::Erf(unary_operator) + | Self::Recip(unary_operator) + | Self::Assign(unary_operator) + | Self::Not(unary_operator) + | Self::Neg(unary_operator) + | Self::Bitcast(unary_operator) + | Self::AtomicLoad(unary_operator) + | Self::AtomicStore(unary_operator) + | Self::Magnitude(unary_operator) + | Self::Normalize(unary_operator) => unary_operator.out, - Operator::Clamp(clamp_operator) => clamp_operator.out, - Operator::Copy(copy_operator) => copy_operator.out, - Operator::CopyBulk(copy_bulk_operator) => copy_bulk_operator.out, - Operator::Slice(slice_operator) => slice_operator.out, - Operator::InitLine(line_init_operator) => line_init_operator.out, - Operator::AtomicCompareAndSwap(op) => op.out, - Operator::Fma(fma_operator) => fma_operator.out, - }; - Some(val) + Self::Clamp(clamp_operator) => clamp_operator.out, + Self::Copy(copy_operator) => copy_operator.out, + Self::CopyBulk(copy_bulk_operator) => copy_bulk_operator.out, + Self::Slice(slice_operator) => slice_operator.out, + Self::InitLine(line_init_operator) => line_init_operator.out, + Self::AtomicCompareAndSwap(op) => op.out, + Self::Fma(fma_operator) => fma_operator.out, + } + .into() } } impl Display for Operator { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Operator::Add(op) => write!(f, "{} = {} + {}", op.out, op.lhs, op.rhs), - Operator::Fma(op) => write!(f, "{} = {} * {} + {}", op.out, op.a, op.b, op.c), - Operator::Sub(op) => write!(f, "{} = {} - {}", op.out, op.lhs, op.rhs), - Operator::Mul(op) => write!(f, "{} = {} * {}", op.out, op.lhs, op.rhs), - Operator::Div(op) => write!(f, "{} = {} / {}", op.out, op.lhs, op.rhs), - Operator::Abs(op) => write!(f, "{} = {}.abs()", op.out, op.input), - Operator::Exp(op) => write!(f, "{} = {}.exp()", op.out, op.input), - Operator::Log(op) => write!(f, "{} = {}.log()", op.out, op.input), - Operator::Log1p(op) => write!(f, "{} = {}.log_1p()", op.out, op.input), - Operator::Cos(op) => write!(f, "{} = {}.cos()", op.out, op.input), - Operator::Sin(op) => write!(f, "{} = {}.sin()", op.out, op.input), - Operator::Tanh(op) => write!(f, "{} = {}.tanh()", op.out, op.input), - Operator::Powf(op) => write!(f, "{} = {}.pow({})", op.out, op.lhs, op.rhs), - Operator::Sqrt(op) => write!(f, "{} = {}.sqrt()", op.out, op.input), - Operator::Round(op) => write!(f, "{} = {}.round()", op.out, op.input), - Operator::Floor(op) => write!(f, "{} = {}.floor()", op.out, op.input), - Operator::Ceil(op) => write!(f, "{} = {}.ceil()", op.out, op.input), - Operator::Erf(op) => write!(f, "{} = {}.erf()", op.out, op.input), - Operator::Recip(op) => write!(f, "{} = {}.recip()", op.out, op.input), - Operator::Equal(op) => write!(f, "{} = {} == {}", op.out, op.lhs, op.rhs), - Operator::NotEqual(op) => write!(f, "{} = {} != {}", op.out, op.lhs, op.rhs), - Operator::Lower(op) => write!(f, "{} = {} < {}", op.out, op.lhs, op.rhs), - Operator::Clamp(op) => write!( + Self::Add(op) => write!(f, "{} = {} + {}", op.out, op.lhs, op.rhs), + Self::Fma(op) => write!(f, "{} = {} * {} + {}", op.out, op.a, op.b, op.c), + Self::Sub(op) => write!(f, "{} = {} - {}", op.out, op.lhs, op.rhs), + Self::Mul(op) => write!(f, "{} = {} * {}", op.out, op.lhs, op.rhs), + Self::Div(op) => write!(f, "{} = {} / {}", op.out, op.lhs, op.rhs), + Self::Abs(op) => write!(f, "{} = {}.abs()", op.out, op.input), + Self::Exp(op) => write!(f, "{} = {}.exp()", op.out, op.input), + Self::Log(op) => write!(f, "{} = {}.log()", op.out, op.input), + Self::Log1p(op) => write!(f, "{} = {}.log_1p()", op.out, op.input), + Self::Cos(op) => write!(f, "{} = {}.cos()", op.out, op.input), + Self::Sin(op) => write!(f, "{} = {}.sin()", op.out, op.input), + Self::Tanh(op) => write!(f, "{} = {}.tanh()", op.out, op.input), + Self::Powf(op) => write!(f, "{} = {}.pow({})", op.out, op.lhs, op.rhs), + Self::Sqrt(op) => write!(f, "{} = {}.sqrt()", op.out, op.input), + Self::Round(op) => write!(f, "{} = {}.round()", op.out, op.input), + Self::Floor(op) => write!(f, "{} = {}.floor()", op.out, op.input), + Self::Ceil(op) => write!(f, "{} = {}.ceil()", op.out, op.input), + Self::Erf(op) => write!(f, "{} = {}.erf()", op.out, op.input), + Self::Recip(op) => write!(f, "{} = {}.recip()", op.out, op.input), + Self::Equal(op) => write!(f, "{} = {} == {}", op.out, op.lhs, op.rhs), + Self::NotEqual(op) => write!(f, "{} = {} != {}", op.out, op.lhs, op.rhs), + Self::Lower(op) => write!(f, "{} = {} < {}", op.out, op.lhs, op.rhs), + Self::Clamp(op) => write!( f, "{} = {}.clamp({}, {})", op.out, op.input, op.min_value, op.max_value ), - Operator::Greater(op) => write!(f, "{} = {} > {}", op.out, op.lhs, op.rhs), - Operator::LowerEqual(op) => write!(f, "{} = {} <= {}", op.out, op.lhs, op.rhs), - Operator::GreaterEqual(op) => write!(f, "{} = {} >= {}", op.out, op.lhs, op.rhs), - Operator::Assign(op) => write!(f, "{} = {}", op.out, op.input), - Operator::Modulo(op) => write!(f, "{} = {} % {}", op.out, op.lhs, op.rhs), - Operator::Index(op) => write!(f, "{} = {}[{}]", op.out, op.lhs, op.rhs), - Operator::Copy(op) => write!( + Self::Greater(op) => write!(f, "{} = {} > {}", op.out, op.lhs, op.rhs), + Self::LowerEqual(op) => write!(f, "{} = {} <= {}", op.out, op.lhs, op.rhs), + Self::GreaterEqual(op) => write!(f, "{} = {} >= {}", op.out, op.lhs, op.rhs), + Self::Assign(op) => write!(f, "{} = {}", op.out, op.input), + Self::Modulo(op) => write!(f, "{} = {} % {}", op.out, op.lhs, op.rhs), + Self::Index(op) => write!(f, "{} = {}[{}]", op.out, op.lhs, op.rhs), + Self::Copy(op) => write!( f, "{}[{}] = {}[{}]", op.out, op.out_index, op.input, op.in_index ), - Operator::CopyBulk(op) => write!( + Self::CopyBulk(op) => write!( f, "memcpy({}[{}], {}[{}], {})", op.out, op.input, op.in_index, op.out_index, op.len ), - Operator::Slice(op) => write!(f, "{} = {}[{}..{}]", op.out, op.input, op.start, op.end), - Operator::UncheckedIndex(op) => { + Self::Slice(op) => write!(f, "{} = {}[{}..{}]", op.out, op.input, op.start, op.end), + Self::UncheckedIndex(op) => { write!(f, "{} = unchecked {}[{}]", op.out, op.lhs, op.rhs) } - Operator::IndexAssign(op) => write!(f, "{}[{}] = {}", op.out, op.lhs, op.rhs), - Operator::UncheckedIndexAssign(op) => { + Self::IndexAssign(op) => write!(f, "{}[{}] = {}", op.out, op.lhs, op.rhs), + Self::UncheckedIndexAssign(op) => { write!(f, "unchecked {}[{}] = {}", op.out, op.lhs, op.rhs) } - Operator::And(op) => write!(f, "{} = {} && {}", op.out, op.lhs, op.rhs), - Operator::Or(op) => write!(f, "{} = {} || {}", op.out, op.lhs, op.rhs), - Operator::Not(op) => write!(f, "{} = !{}", op.out, op.input), - Operator::Neg(op) => write!(f, "{} = -{}", op.out, op.input), - Operator::Max(op) => write!(f, "{} = {}.max({})", op.out, op.lhs, op.rhs), - Operator::Min(op) => write!(f, "{} = {}.min({})", op.out, op.lhs, op.rhs), - Operator::BitwiseAnd(op) => write!(f, "{} = {} & {}", op.out, op.lhs, op.rhs), - Operator::BitwiseOr(op) => write!(f, "{} = {} | {}", op.out, op.lhs, op.rhs), - Operator::BitwiseXor(op) => write!(f, "{} = {} ^ {}", op.out, op.lhs, op.rhs), - Operator::ShiftLeft(op) => write!(f, "{} = {} << {}", op.out, op.lhs, op.rhs), - Operator::ShiftRight(op) => write!(f, "{} = {} >> {}", op.out, op.lhs, op.rhs), - Operator::Remainder(op) => write!(f, "{} = {} rem {}", op.out, op.lhs, op.rhs), - Operator::Bitcast(op) => write!(f, "{} = bitcast({})", op.out, op.input), - Operator::AtomicLoad(op) => write!(f, "{} = atomic_load({})", op.out, op.input), - Operator::AtomicStore(op) => write!(f, "atomic_store({}, {})", op.out, op.input), - Operator::AtomicSwap(op) => { + Self::And(op) => write!(f, "{} = {} && {}", op.out, op.lhs, op.rhs), + Self::Or(op) => write!(f, "{} = {} || {}", op.out, op.lhs, op.rhs), + Self::Not(op) => write!(f, "{} = !{}", op.out, op.input), + Self::Neg(op) => write!(f, "{} = -{}", op.out, op.input), + Self::Max(op) => write!(f, "{} = {}.max({})", op.out, op.lhs, op.rhs), + Self::Min(op) => write!(f, "{} = {}.min({})", op.out, op.lhs, op.rhs), + Self::BitwiseAnd(op) => write!(f, "{} = {} & {}", op.out, op.lhs, op.rhs), + Self::BitwiseOr(op) => write!(f, "{} = {} | {}", op.out, op.lhs, op.rhs), + Self::BitwiseXor(op) => write!(f, "{} = {} ^ {}", op.out, op.lhs, op.rhs), + Self::ShiftLeft(op) => write!(f, "{} = {} << {}", op.out, op.lhs, op.rhs), + Self::ShiftRight(op) => write!(f, "{} = {} >> {}", op.out, op.lhs, op.rhs), + Self::Remainder(op) => write!(f, "{} = {} rem {}", op.out, op.lhs, op.rhs), + Self::Bitcast(op) => write!(f, "{} = bitcast({})", op.out, op.input), + Self::AtomicLoad(op) => write!(f, "{} = atomic_load({})", op.out, op.input), + Self::AtomicStore(op) => write!(f, "atomic_store({}, {})", op.out, op.input), + Self::AtomicSwap(op) => { write!(f, "{} = atomic_swap({}, {})", op.out, op.lhs, op.rhs) } - Operator::AtomicAdd(op) => write!(f, "{} = atomic_add({}, {})", op.out, op.lhs, op.rhs), - Operator::AtomicSub(op) => write!(f, "{} = atomic_sub({}, {})", op.out, op.lhs, op.rhs), - Operator::AtomicMax(op) => write!(f, "{} = atomic_max({}, {})", op.out, op.lhs, op.rhs), - Operator::AtomicMin(op) => write!(f, "{} = atomic_min({}, {})", op.out, op.lhs, op.rhs), - Operator::AtomicAnd(op) => write!(f, "{} = atomic_and({}, {})", op.out, op.lhs, op.rhs), - Operator::AtomicOr(op) => write!(f, "{} = atomic_or({}, {})", op.out, op.lhs, op.rhs), - Operator::AtomicXor(op) => write!(f, "{} = atomic_xor({}, {})", op.out, op.lhs, op.rhs), - Operator::AtomicCompareAndSwap(op) => write!( + Self::AtomicAdd(op) => write!(f, "{} = atomic_add({}, {})", op.out, op.lhs, op.rhs), + Self::AtomicSub(op) => write!(f, "{} = atomic_sub({}, {})", op.out, op.lhs, op.rhs), + Self::AtomicMax(op) => write!(f, "{} = atomic_max({}, {})", op.out, op.lhs, op.rhs), + Self::AtomicMin(op) => write!(f, "{} = atomic_min({}, {})", op.out, op.lhs, op.rhs), + Self::AtomicAnd(op) => write!(f, "{} = atomic_and({}, {})", op.out, op.lhs, op.rhs), + Self::AtomicOr(op) => write!(f, "{} = atomic_or({}, {})", op.out, op.lhs, op.rhs), + Self::AtomicXor(op) => write!(f, "{} = atomic_xor({}, {})", op.out, op.lhs, op.rhs), + Self::AtomicCompareAndSwap(op) => write!( f, "{} = compare_and_swap({}, {}, {})", op.out, op.input, op.cmp, op.val ), - Operator::Magnitude(op) => write!(f, "{} = {}.length()", op.out, op.input), - Operator::Normalize(op) => write!(f, "{} = {}.normalize()", op.out, op.input), - Operator::Dot(op) => write!(f, "{} = {}.dot({})", op.out, op.lhs, op.rhs), - Operator::InitLine(init) => { + Self::Magnitude(op) => write!(f, "{} = {}.length()", op.out, op.input), + Self::Normalize(op) => write!(f, "{} = {}.normalize()", op.out, op.input), + Self::Dot(op) => write!(f, "{} = {}.dot({})", op.out, op.lhs, op.rhs), + Self::InitLine(init) => { let inits = init .inputs .iter() @@ -314,21 +314,19 @@ pub enum Metadata { impl Metadata { pub fn out(&self) -> Option { - let val = match self { - Metadata::Stride { out, .. } => *out, - Metadata::Shape { out, .. } => *out, - Metadata::Length { out, .. } => *out, - }; - Some(val) + match self { + Self::Stride { out, .. } | Self::Shape { out, .. } | Self::Length { out, .. } => *out, + } + .into() } } impl Display for Metadata { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Metadata::Stride { dim, var, out } => write!(f, "{} = {}.strides[{}]", out, var, dim), - Metadata::Shape { dim, var, out } => write!(f, "{} = {}.shape[{}]", out, var, dim), - Metadata::Length { var, out } => write!(f, "{} = {}.len()", out, var), + Self::Stride { dim, var, out } => write!(f, "{} = {}.strides[{}]", out, var, dim), + Self::Shape { dim, var, out } => write!(f, "{} = {}.shape[{}]", out, var, dim), + Self::Length { var, out } => write!(f, "{} = {}.len()", out, var), } } } @@ -432,7 +430,7 @@ pub struct FmaOperator { impl From for Operation { fn from(val: Operator) -> Self { - Operation::Operator(val) + Self::Operator(val) } } @@ -450,6 +448,6 @@ impl From for Operation { impl From for Operation { fn from(val: Metadata) -> Self { - Operation::Metadata(val) + Self::Metadata(val) } } diff --git a/crates/cubecl-core/src/ir/subcube.rs b/crates/cubecl-core/src/ir/subcube.rs index 8e8e2f0ca..51a79cb8f 100644 --- a/crates/cubecl-core/src/ir/subcube.rs +++ b/crates/cubecl-core/src/ir/subcube.rs @@ -21,33 +21,33 @@ pub enum Subcube { impl Subcube { pub fn out(&self) -> Option { - let val = match self { - Subcube::Elect(init_operator) => init_operator.out, - Subcube::Broadcast(binary_operator) => binary_operator.out, - Subcube::All(unary_operator) - | Subcube::Any(unary_operator) - | Subcube::Sum(unary_operator) - | Subcube::Prod(unary_operator) - | Subcube::Min(unary_operator) - | Subcube::Max(unary_operator) => unary_operator.out, - }; - Some(val) + match self { + Self::Elect(init_operator) => init_operator.out, + Self::Broadcast(binary_operator) => binary_operator.out, + Self::All(unary_operator) + | Self::Any(unary_operator) + | Self::Sum(unary_operator) + | Self::Prod(unary_operator) + | Self::Min(unary_operator) + | Self::Max(unary_operator) => unary_operator.out, + } + .into() } } impl Display for Subcube { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Subcube::Elect(op) => writeln!(f, "{} = subcube_elect()", op.out), - Subcube::All(op) => writeln!(f, "{} = subcube_all({})", op.out, op.input), - Subcube::Any(op) => writeln!(f, "{} = subcube_any({})", op.out, op.input), - Subcube::Broadcast(op) => { + Self::Elect(op) => writeln!(f, "{} = subcube_elect()", op.out), + Self::All(op) => writeln!(f, "{} = subcube_all({})", op.out, op.input), + Self::Any(op) => writeln!(f, "{} = subcube_any({})", op.out, op.input), + Self::Broadcast(op) => { writeln!(f, "{} = subcube_broadcast({}, {})", op.out, op.lhs, op.rhs) } - Subcube::Sum(op) => writeln!(f, "{} = subcube_sum({})", op.out, op.input), - Subcube::Prod(op) => writeln!(f, "{} = subcube_product({})", op.out, op.input), - Subcube::Min(op) => writeln!(f, "{} = subcube_min({})", op.out, op.input), - Subcube::Max(op) => writeln!(f, "{} = subcube_max({})", op.out, op.input), + Self::Sum(op) => writeln!(f, "{} = subcube_sum({})", op.out, op.input), + Self::Prod(op) => writeln!(f, "{} = subcube_product({})", op.out, op.input), + Self::Min(op) => writeln!(f, "{} = subcube_min({})", op.out, op.input), + Self::Max(op) => writeln!(f, "{} = subcube_max({})", op.out, op.input), } } } diff --git a/crates/cubecl-core/src/ir/synchronization.rs b/crates/cubecl-core/src/ir/synchronization.rs index eb0004747..698aa8d1f 100644 --- a/crates/cubecl-core/src/ir/synchronization.rs +++ b/crates/cubecl-core/src/ir/synchronization.rs @@ -14,8 +14,8 @@ pub enum Synchronization { impl Display for Synchronization { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Synchronization::SyncUnits => write!(f, "sync_units()"), - Synchronization::SyncStorage => write!(f, "sync_storage()"), + Self::SyncUnits => write!(f, "sync_units()"), + Self::SyncStorage => write!(f, "sync_storage()"), } } } diff --git a/crates/cubecl-core/src/ir/variable.rs b/crates/cubecl-core/src/ir/variable.rs index b3ca6c22d..4705d1820 100644 --- a/crates/cubecl-core/src/ir/variable.rs +++ b/crates/cubecl-core/src/ir/variable.rs @@ -91,40 +91,40 @@ impl Variable { /// safe to inline/merge pub fn is_immutable(&self) -> bool { match self { - Variable::GlobalOutputArray { .. } => false, - Variable::Local { .. } => false, - Variable::SharedMemory { .. } => false, - Variable::Matrix { .. } => false, - Variable::Slice { .. } => false, - Variable::LocalArray { .. } => false, - Variable::GlobalInputArray { .. } => false, - Variable::GlobalScalar { .. } => true, - Variable::Versioned { .. } => true, - Variable::LocalBinding { .. } => true, - Variable::ConstantScalar(_) => true, - Variable::ConstantArray { .. } => true, - Variable::Rank => true, - Variable::UnitPos => true, - Variable::UnitPosX => true, - Variable::UnitPosY => true, - Variable::UnitPosZ => true, - Variable::CubePos => true, - Variable::CubePosX => true, - Variable::CubePosY => true, - Variable::CubePosZ => true, - Variable::CubeDim => true, - Variable::CubeDimX => true, - Variable::CubeDimY => true, - Variable::CubeDimZ => true, - Variable::CubeCount => true, - Variable::CubeCountX => true, - Variable::CubeCountY => true, - Variable::CubeCountZ => true, - Variable::SubcubeDim => true, - Variable::AbsolutePos => true, - Variable::AbsolutePosX => true, - Variable::AbsolutePosY => true, - Variable::AbsolutePosZ => true, + Self::GlobalOutputArray { .. } + | Self::Local { .. } + | Self::SharedMemory { .. } + | Self::Matrix { .. } + | Self::Slice { .. } + | Self::LocalArray { .. } + | Self::GlobalInputArray { .. } => false, + Self::GlobalScalar { .. } + | Self::Versioned { .. } + | Self::LocalBinding { .. } + | Self::ConstantScalar(_) + | Self::ConstantArray { .. } + | Self::Rank + | Self::UnitPos + | Self::UnitPosX + | Self::UnitPosY + | Self::UnitPosZ + | Self::CubePos + | Self::CubePosX + | Self::CubePosY + | Self::CubePosZ + | Self::CubeDim + | Self::CubeDimX + | Self::CubeDimY + | Self::CubeDimZ + | Self::CubeCount + | Self::CubeCountX + | Self::CubeCountY + | Self::CubeCountZ + | Self::SubcubeDim + | Self::AbsolutePos + | Self::AbsolutePosX + | Self::AbsolutePosY + | Self::AbsolutePosZ => true, } } @@ -133,13 +133,13 @@ impl Variable { pub fn is_array(&self) -> bool { matches!( self, - Variable::GlobalInputArray { .. } - | Variable::GlobalOutputArray { .. } - | Variable::ConstantArray { .. } - | Variable::SharedMemory { .. } - | Variable::LocalArray { .. } - | Variable::Matrix { .. } - | Variable::Slice { .. } + Self::GlobalInputArray { .. } + | Self::GlobalOutputArray { .. } + | Self::ConstantArray { .. } + | Self::SharedMemory { .. } + | Self::LocalArray { .. } + | Self::Matrix { .. } + | Self::Slice { .. } ) } diff --git a/crates/cubecl-cpp/src/shared/element.rs b/crates/cubecl-cpp/src/shared/element.rs index e1e867926..41b526ddb 100644 --- a/crates/cubecl-cpp/src/shared/element.rs +++ b/crates/cubecl-cpp/src/shared/element.rs @@ -96,44 +96,41 @@ impl Component for Variable { fn item(&self) -> Item { match self { - Variable::GlobalInputArray(_, e) => *e, - Variable::GlobalOutputArray(_, e) => *e, - Variable::SharedMemory(_, e, _) => *e, - Variable::ConstantArray(_, e, _) => *e, - Variable::Local { item, .. } => *item, - Variable::ConstLocal { item, .. } => *item, - Variable::Slice { item, .. } => *item, - Variable::ConstantScalar(_, e) => Item::scalar(*e), - Variable::GlobalScalar(_, e, _) => Item::scalar(*e), - Variable::IdxGlobal => Item::scalar(Elem::U32), - Variable::ThreadIdxGlobal => Item::scalar(Elem::U32), - Variable::ThreadIdxX => Item::scalar(Elem::U32), - Variable::ThreadIdxY => Item::scalar(Elem::U32), - Variable::ThreadIdxZ => Item::scalar(Elem::U32), - Variable::Rank => Item::scalar(Elem::U32), - Variable::BlockIdxX => Item::scalar(Elem::U32), - Variable::BlockIdxY => Item::scalar(Elem::U32), - Variable::BlockIdxZ => Item::scalar(Elem::U32), - Variable::AbsoluteIdxX => Item::scalar(Elem::U32), - Variable::AbsoluteIdxY => Item::scalar(Elem::U32), - Variable::AbsoluteIdxZ => Item::scalar(Elem::U32), - Variable::BlockDimX => Item::scalar(Elem::U32), - Variable::BlockDimY => Item::scalar(Elem::U32), - Variable::BlockDimZ => Item::scalar(Elem::U32), - Variable::GridDimX => Item::scalar(Elem::U32), - Variable::GridDimY => Item::scalar(Elem::U32), - Variable::GridDimZ => Item::scalar(Elem::U32), - Variable::LocalArray(_, e, _, _) => *e, - Variable::WarpSize => Item::scalar(Elem::U32), - Variable::WmmaFragment { - id: _, - frag, - depth: _, - } => Item::scalar(frag.elem), - Variable::BlockIdxGlobal => Item::scalar(Elem::U32), - Variable::BlockDimGlobal => Item::scalar(Elem::U32), - Variable::GridDimGlobal => Item::scalar(Elem::U32), - Variable::Tmp { item, .. } => *item, + Self::GlobalInputArray(_, e) + | Self::GlobalOutputArray(_, e) + | Self::SharedMemory(_, e, _) + | Self::ConstantArray(_, e, _) => *e, + Self::Local { item, .. } | Self::ConstLocal { item, .. } | Self::Slice { item, .. } => { + *item + } + Self::ConstantScalar(_, e) | Self::GlobalScalar(_, e, _) => Item::scalar(*e), + Self::IdxGlobal + | Self::ThreadIdxGlobal + | Self::ThreadIdxX + | Self::ThreadIdxY + | Self::ThreadIdxZ + | Self::Rank + | Self::BlockIdxX + | Self::BlockIdxY + | Self::BlockIdxZ + | Self::AbsoluteIdxX + | Self::AbsoluteIdxY + | Self::AbsoluteIdxZ + | Self::BlockDimX + | Self::BlockDimY + | Self::BlockDimZ + | Self::GridDimX + | Self::GridDimY + | Self::BlockIdxGlobal + | Self::BlockDimGlobal + | Self::GridDimGlobal + | Self::WarpSize + | Self::GridDimZ => Item::scalar(Elem::U32), + Self::LocalArray(_, e, _, _) => *e, + + Self::WmmaFragment { frag, .. } => Item::scalar(frag.elem), + + Self::Tmp { item, .. } => *item, } } @@ -202,19 +199,19 @@ pub enum Variable { impl Display for Variable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")), - Variable::Local { id, depth, .. } => f.write_fmt(format_args!("l_{depth}_{id}")), - Variable::ConstLocal { id, depth, .. } => f.write_fmt(format_args!("ssa_{depth}_{id}")), - Variable::Slice { id, item: _, depth } => { + Self::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")), + Self::Local { id, depth, .. } => f.write_fmt(format_args!("l_{depth}_{id}")), + Self::ConstLocal { id, depth, .. } => f.write_fmt(format_args!("ssa_{depth}_{id}")), + Self::Slice { id, depth, .. } => { write!(f, "slice_{depth}_{id}") } - Variable::GlobalOutputArray(number, _) => write!(f, "output_{number}"), - Variable::GlobalScalar(number, _, elem) => { + Self::GlobalOutputArray(number, _) => write!(f, "output_{number}"), + Self::GlobalScalar(number, _, elem) => { write!(f, "scalars_{elem}[{number}]") } // We do the conversion in Rust and then render the number to avoid overflow or other // precision related problems. - Variable::ConstantScalar(number, elem) => match number { + Self::ConstantScalar(number, elem) => match number { ConstantScalarValue::Int(val, kind) => match kind { gpu::IntKind::I32 => write!(f, "{elem}({})", *val as i32), gpu::IntKind::I64 => write!(f, "{elem}({})", *val), @@ -234,40 +231,38 @@ impl Display for Variable { } ConstantScalarValue::Bool(val) => write!(f, "{}", val), }, - Variable::SharedMemory(number, _, _) => { + Self::SharedMemory(number, _, _) => { write!(f, "shared_memory_{number}") } - Variable::ConstantArray(number, _, _) => f.write_fmt(format_args!("arrays_{number}")), - Variable::ThreadIdxGlobal => f.write_str("threadIdxGlobal"), - Variable::ThreadIdxX => f.write_str("threadIdx.x"), - Variable::ThreadIdxY => f.write_str("threadIdx.y"), - Variable::ThreadIdxZ => f.write_str("threadIdx.z"), - Variable::Rank => f.write_str("rank"), - Variable::BlockIdxGlobal => f.write_str("blockIdxGlobal"), - Variable::BlockIdxX => f.write_str("blockIdx.x"), - Variable::BlockIdxY => f.write_str("blockIdx.y"), - Variable::BlockIdxZ => f.write_str("blockIdx.z"), - Variable::BlockDimGlobal => f.write_str("blockDimGlobal"), - Variable::BlockDimX => f.write_str("blockDim.x"), - Variable::BlockDimY => f.write_str("blockDim.y"), - Variable::BlockDimZ => f.write_str("blockDim.z"), - Variable::IdxGlobal => f.write_str("idxGlobal"), - Variable::GridDimX => f.write_str("gridDim.x"), - Variable::GridDimY => f.write_str("gridDim.y"), - Variable::GridDimZ => f.write_str("gridDim.z"), - Variable::AbsoluteIdxX => f.write_str("absoluteIdx.x"), - Variable::AbsoluteIdxY => f.write_str("absoluteIdx.y"), - Variable::AbsoluteIdxZ => f.write_str("absoluteIdx.z"), - Variable::LocalArray(id, _item, depth, _size) => { + Self::ConstantArray(number, _, _) => f.write_fmt(format_args!("arrays_{number}")), + Self::ThreadIdxGlobal => f.write_str("threadIdxGlobal"), + Self::ThreadIdxX => f.write_str("threadIdx.x"), + Self::ThreadIdxY => f.write_str("threadIdx.y"), + Self::ThreadIdxZ => f.write_str("threadIdx.z"), + Self::Rank => f.write_str("rank"), + Self::BlockIdxGlobal => f.write_str("blockIdxGlobal"), + Self::BlockIdxX => f.write_str("blockIdx.x"), + Self::BlockIdxY => f.write_str("blockIdx.y"), + Self::BlockIdxZ => f.write_str("blockIdx.z"), + Self::BlockDimGlobal => f.write_str("blockDimGlobal"), + Self::BlockDimX => f.write_str("blockDim.x"), + Self::BlockDimY => f.write_str("blockDim.y"), + Self::BlockDimZ => f.write_str("blockDim.z"), + Self::IdxGlobal => f.write_str("idxGlobal"), + Self::GridDimX => f.write_str("gridDim.x"), + Self::GridDimY => f.write_str("gridDim.y"), + Self::GridDimZ => f.write_str("gridDim.z"), + Self::AbsoluteIdxX => f.write_str("absoluteIdx.x"), + Self::AbsoluteIdxY => f.write_str("absoluteIdx.y"), + Self::AbsoluteIdxZ => f.write_str("absoluteIdx.z"), + Self::LocalArray(id, _item, depth, _size) => { write!(f, "l_arr_{}_{}", id, depth) } - Variable::WarpSize => f.write_str("warpSize"), - Variable::WmmaFragment { - id: index, - frag: _, - depth, + Self::WarpSize => f.write_str("warpSize"), + Self::WmmaFragment { + id: index, depth, .. } => write!(f, "frag_{index}_{depth}"), - Variable::GridDimGlobal => f.write_str("gridDimGlobal"), + Self::GridDimGlobal => f.write_str("gridDimGlobal"), Self::Tmp { id, .. } => write!(f, "_tmp_{id}"), } } @@ -322,42 +317,38 @@ impl Variable { pub fn optimized(&self) -> Self { match self { - Variable::GlobalInputArray(id, item) => { - Variable::GlobalInputArray(*id, item.optimized()) - } - Variable::GlobalOutputArray(id, item) => { - Variable::GlobalOutputArray(*id, item.optimized()) - } - Variable::Local { id, item, depth } => Variable::Local { + Self::GlobalInputArray(id, item) => Self::GlobalInputArray(*id, item.optimized()), + Self::GlobalOutputArray(id, item) => Self::GlobalOutputArray(*id, item.optimized()), + Self::Local { id, item, depth } => Self::Local { id: *id, item: item.optimized(), depth: *depth, }, - Variable::ConstLocal { id, item, depth } => Variable::ConstLocal { + Self::ConstLocal { id, item, depth } => Self::ConstLocal { id: *id, item: item.optimized(), depth: *depth, }, - Variable::Slice { id, item, depth } => Variable::Slice { + Self::Slice { id, item, depth } => Self::Slice { id: *id, item: item.optimized(), depth: *depth, }, - Variable::SharedMemory(id, item, size) => { + Self::SharedMemory(id, item, size) => { let before = item.vectorization; let item = item.optimized(); let after = item.vectorization; let scaling = (before / after) as u32; - Variable::SharedMemory(*id, item, size / scaling) + Self::SharedMemory(*id, item, size / scaling) } - Variable::LocalArray(id, item, vec, size) => { + Self::LocalArray(id, item, vec, size) => { let before = item.vectorization; let item = item.optimized(); let after = item.vectorization; let scaling = (before / after) as u32; - Variable::LocalArray(*id, item.optimized(), *vec, size / scaling) + Self::LocalArray(*id, item.optimized(), *vec, size / scaling) } _ => *self, } @@ -365,40 +356,40 @@ impl Variable { pub fn is_always_scalar(&self) -> bool { match self { - Variable::GlobalScalar(_, _, _) => true, - Variable::ConstantScalar(_, _) => true, - Variable::IdxGlobal => true, - Variable::ThreadIdxGlobal => true, - Variable::ThreadIdxX => true, - Variable::ThreadIdxY => true, - Variable::ThreadIdxZ => true, - Variable::Rank => true, - Variable::GlobalInputArray(_, _) => false, - Variable::GlobalOutputArray(_, _) => false, - Variable::SharedMemory(_, _, _) => false, - Variable::ConstantArray(_, _, _) => false, - Variable::Local { .. } => false, - Variable::ConstLocal { .. } => false, - Variable::Slice { .. } => false, - Variable::BlockIdxX => true, - Variable::BlockIdxY => true, - Variable::BlockIdxZ => true, - Variable::AbsoluteIdxX => true, - Variable::AbsoluteIdxY => true, - Variable::AbsoluteIdxZ => true, - Variable::BlockDimX => true, - Variable::BlockDimY => true, - Variable::BlockDimZ => true, - Variable::GridDimX => true, - Variable::GridDimY => true, - Variable::GridDimZ => true, - Variable::LocalArray(_, _, _, _) => false, - Variable::WarpSize => true, - Variable::WmmaFragment { .. } => false, - Variable::BlockIdxGlobal => true, - Variable::BlockDimGlobal => true, - Variable::GridDimGlobal => true, - Variable::Tmp { .. } => false, + Self::ConstantArray(_, _, _) + | Self::ConstLocal { .. } + | Self::GlobalInputArray(_, _) + | Self::GlobalOutputArray(_, _) + | Self::Local { .. } + | Self::LocalArray(_, _, _, _) + | Self::SharedMemory(_, _, _) + | Self::Slice { .. } + | Self::Tmp { .. } + | Self::WmmaFragment { .. } => false, + Self::AbsoluteIdxX + | Self::AbsoluteIdxY + | Self::AbsoluteIdxZ + | Self::BlockDimGlobal + | Self::BlockDimX + | Self::BlockDimY + | Self::BlockDimZ + | Self::BlockIdxGlobal + | Self::BlockIdxX + | Self::BlockIdxY + | Self::BlockIdxZ + | Self::ConstantScalar(_, _) + | Self::GlobalScalar(_, _, _) + | Self::GridDimGlobal + | Self::GridDimX + | Self::GridDimY + | Self::GridDimZ + | Self::IdxGlobal + | Self::Rank + | Self::ThreadIdxGlobal + | Self::ThreadIdxX + | Self::ThreadIdxY + | Self::ThreadIdxZ + | Self::WarpSize => true, } } @@ -419,7 +410,7 @@ impl FmtLeft for Variable { fn fmt_left(&self) -> String { match self { Self::ConstLocal { item, .. } => format!("const {item} {self}"), - Variable::Tmp { item, .. } => format!("{item} {self}"), + Self::Tmp { item, .. } => format!("{item} {self}"), var => format!("{var}"), } } diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index b4b3f2f64..63fe311cd 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -185,17 +185,17 @@ pub enum Instruction { impl Display for Instruction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Instruction::Return => f.write_str("return;"), - Instruction::Break => f.write_str("break;"), - Instruction::DeclareVariable { var } => match var { + Self::Return => f.write_str("return;"), + Self::Break => f.write_str("break;"), + Self::DeclareVariable { var } => match var { Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"), _ => { let item = var.item(); writeln!(f, "{item} {var};") } }, - Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Slice { + Self::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out), + Self::Slice { input, start, end, @@ -205,17 +205,17 @@ impl Display for Instruction { writeln!(f, "const uint {out}_length = {end} - {start};")?; writeln!(f, "{item} *{out} = {input} + {start};") } - Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out), - Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Index(it) => Index::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::CheckedIndex { len, lhs, rhs, out } => { + Self::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out), + Self::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out), + Self::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out), + Self::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out), + Self::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out), + Self::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out), + Self::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out), + Self::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out), + Self::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out), + Self::Index(it) => Index::format(f, &it.lhs, &it.rhs, &it.out), + Self::CheckedIndex { len, lhs, rhs, out } => { let item_out = out.item(); if let Elem::Atomic(inner) = item_out.elem { write!(f, "{inner}* {out} = &{lhs}[{rhs}];") @@ -230,8 +230,8 @@ impl Display for Instruction { } } } - Instruction::IndexAssign(it) => IndexAssign::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Copy { + Self::IndexAssign(it) => IndexAssign::format(f, &it.lhs, &it.rhs, &it.out), + Self::Copy { input, in_index, out, @@ -239,7 +239,7 @@ impl Display for Instruction { } => { writeln!(f, "{out}[{out_index}] = {input}[{in_index}];") } - Instruction::CopyBulk { + Self::CopyBulk { input, in_index, out, @@ -251,8 +251,8 @@ impl Display for Instruction { } Ok(()) } - Instruction::Assign(it) => Assign::format(f, &it.input, &it.out), - Instruction::RangeLoop { + Self::Assign(it) => Assign::format(f, &it.input, &it.out), + Self::RangeLoop { i, start, end, @@ -279,21 +279,21 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ f.write_str("}\n") } - Instruction::Loop { instructions } => { + Self::Loop { instructions } => { writeln!(f, "while (true) {{")?; for i in instructions { write!(f, "{i}")?; } f.write_str("}\n") } - Instruction::If { cond, instructions } => { + Self::If { cond, instructions } => { writeln!(f, "if ({cond}) {{")?; for i in instructions { write!(f, "{i}")?; } f.write_str("}\n") } - Instruction::IfElse { + Self::IfElse { cond, instructions_if, instructions_else, @@ -308,7 +308,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ } f.write_str("}\n") } - Instruction::Select { + Self::Select { cond, then, or_else, @@ -350,7 +350,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ writeln!(f, "{out} = ({cond}) ? {then} : {or_else};") } } - Instruction::Switch { + Self::Switch { value, instructions_default, instructions_cases, @@ -369,51 +369,51 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ } f.write_str("}\n}\n") } - Instruction::Stride { dim, position, out } => { + Self::Stride { dim, position, out } => { let out = out.fmt_left(); writeln!(f, "{out} = info[({position} * rank_2) + {dim} + 1];") } - Instruction::Shape { dim, position, out } => { + Self::Shape { dim, position, out } => { let out = out.fmt_left(); writeln!(f, "{out} = info[({position} * rank_2) + rank + {dim} + 1];") } - Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Erf(it) => Erf::format(f, &it.input, &it.out), - Instruction::Abs(it) => Abs::format(f, &it.input, &it.out), - Instruction::Exp(it) => Exp::format(f, &it.input, &it.out), - Instruction::Log(it) => Log::format(f, &it.input, &it.out), - Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out), - Instruction::Cos(it) => Cos::format(f, &it.input, &it.out), - Instruction::Sin(it) => Sin::format(f, &it.input, &it.out), - Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out), - Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), - Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Not(it) => Not::format(f, &it.input, &it.out), - Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Clamp { + Self::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out), + Self::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out), + Self::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out), + Self::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out), + Self::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out), + Self::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out), + Self::Erf(it) => Erf::format(f, &it.input, &it.out), + Self::Abs(it) => Abs::format(f, &it.input, &it.out), + Self::Exp(it) => Exp::format(f, &it.input, &it.out), + Self::Log(it) => Log::format(f, &it.input, &it.out), + Self::Log1p(it) => Log1p::format(f, &it.input, &it.out), + Self::Cos(it) => Cos::format(f, &it.input, &it.out), + Self::Sin(it) => Sin::format(f, &it.input, &it.out), + Self::Tanh(it) => Tanh::format(f, &it.input, &it.out), + Self::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), + Self::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), + Self::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out), + Self::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out), + Self::Not(it) => Not::format(f, &it.input, &it.out), + Self::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out), + Self::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out), + Self::Clamp { input, min_value, max_value, out, } => Clamp::format(f, input, min_value, max_value, out), - Instruction::SyncThreads => f.write_str("__syncthreads();\n"), - Instruction::ThreadFence => f.write_str("__threadfence();\n"), - Instruction::Round(it) => Round::format(f, &it.input, &it.out), - Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out), - Instruction::Floor(it) => Floor::format(f, &it.input, &it.out), - Instruction::SliceLength { input, out } => { + Self::SyncThreads => f.write_str("__syncthreads();\n"), + Self::ThreadFence => f.write_str("__threadfence();\n"), + Self::Round(it) => Round::format(f, &it.input, &it.out), + Self::Ceil(it) => Ceil::format(f, &it.input, &it.out), + Self::Floor(it) => Floor::format(f, &it.input, &it.out), + Self::SliceLength { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = {input}_length;") } - Instruction::Length { + Self::Length { input, out, num_inputs, @@ -437,10 +437,10 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ "{out} = info[({offset} * 2 * info[0]) + {index}] / {factor};" ) } - Instruction::Wrap(it) => write!(f, "{it}"), - Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out), - Instruction::Wmma(it) => write!(f, "{it}"), - Instruction::Bitcast(UnaryInstruction { input, out }) => { + Self::Wrap(it) => write!(f, "{it}"), + Self::Fma { a, b, c, out } => Fma::format(f, a, b, c, out), + Self::Wmma(it) => write!(f, "{it}"), + Self::Bitcast(UnaryInstruction { input, out }) => { let out_elem = out.elem(); let out = out.fmt_left(); match (input.elem(), out_elem) { @@ -483,7 +483,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ _ => panic!("Unsupported type for bitcasting"), } } - Instruction::AtomicCAS { + Self::AtomicCAS { input, cmp, val, @@ -492,55 +492,55 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ let out = out.fmt_left(); writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});") } - Instruction::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => { + Self::AtomicSwap(BinaryInstruction { lhs, rhs, out }) => { let out = out.fmt_left(); writeln!(f, "{out} = atomicExch({lhs}, {rhs});") } - Instruction::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => { + Self::AtomicAdd(BinaryInstruction { lhs, rhs, out }) => { let out = out.fmt_left(); writeln!(f, "{out} = atomicAdd({lhs}, {rhs});") } - Instruction::AtomicSub(BinaryInstruction { lhs, rhs, out }) => { + Self::AtomicSub(BinaryInstruction { lhs, rhs, out }) => { let out = out.fmt_left(); writeln!(f, "{out} = atomicSub({lhs}, {rhs});") } - Instruction::AtomicMax(BinaryInstruction { lhs, rhs, out }) => { + Self::AtomicMax(BinaryInstruction { lhs, rhs, out }) => { let out = out.fmt_left(); writeln!(f, "{out} = atomicMax({lhs}, {rhs});") } - Instruction::AtomicMin(BinaryInstruction { lhs, rhs, out }) => { + Self::AtomicMin(BinaryInstruction { lhs, rhs, out }) => { let out = out.fmt_left(); writeln!(f, "{out} = atomicMin({lhs}, {rhs});") } - Instruction::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => { + Self::AtomicAnd(BinaryInstruction { lhs, rhs, out }) => { let out = out.fmt_left(); writeln!(f, "{out} = atomicAnd({lhs}, {rhs});") } - Instruction::AtomicOr(BinaryInstruction { lhs, rhs, out }) => { + Self::AtomicOr(BinaryInstruction { lhs, rhs, out }) => { let out = out.fmt_left(); writeln!(f, "{out} = atomicOr({lhs}, {rhs});") } - Instruction::AtomicXor(BinaryInstruction { lhs, rhs, out }) => { + Self::AtomicXor(BinaryInstruction { lhs, rhs, out }) => { let out = out.fmt_left(); writeln!(f, "{out} = atomicXor({lhs}, {rhs});") } - Instruction::AtomicLoad(UnaryInstruction { input, out }) => { + Self::AtomicLoad(UnaryInstruction { input, out }) => { let out = out.fmt_left(); writeln!(f, "{out} = atomicAdd({input}, 0);") } - Instruction::AtomicStore(UnaryInstruction { input, out }) => { + Self::AtomicStore(UnaryInstruction { input, out }) => { let out = out.fmt_left(); writeln!(f, "atomicExch({out}, {input});") } - Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out), - Instruction::Negate(UnaryInstruction { input, out }) => { + Self::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out), + Self::Negate(UnaryInstruction { input, out }) => { let out = out.fmt_left(); writeln!(f, "{out} = !{input};") } - Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out), - Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out), - Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out), - Instruction::VecInit { inputs, out } => { + Self::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out), + Self::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out), + Self::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out), + Self::VecInit { inputs, out } => { let item = out.item(); let inputs = inputs .iter() diff --git a/crates/cubecl-cpp/src/shared/mma.rs b/crates/cubecl-cpp/src/shared/mma.rs index cd5707bc4..521961e36 100644 --- a/crates/cubecl-cpp/src/shared/mma.rs +++ b/crates/cubecl-cpp/src/shared/mma.rs @@ -61,8 +61,8 @@ pub enum WmmaInstruction { impl Display for FragmentLayout { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - FragmentLayout::ColMajor => f.write_str("nvcuda::wmma::col_major"), - FragmentLayout::RowMajor => f.write_str("nvcuda::wmma::row_major"), + Self::ColMajor => f.write_str("nvcuda::wmma::col_major"), + Self::RowMajor => f.write_str("nvcuda::wmma::row_major"), } } } @@ -70,9 +70,9 @@ impl Display for FragmentLayout { impl Display for FragmentIdent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - FragmentIdent::A => f.write_str("nvcuda::wmma::matrix_a"), - FragmentIdent::B => f.write_str("nvcuda::wmma::matrix_b"), - FragmentIdent::Accumulator => f.write_str("nvcuda::wmma::accumulator"), + Self::A => f.write_str("nvcuda::wmma::matrix_a"), + Self::B => f.write_str("nvcuda::wmma::matrix_b"), + Self::Accumulator => f.write_str("nvcuda::wmma::accumulator"), } } } @@ -97,10 +97,10 @@ impl Display for Fragment { impl Display for WmmaInstruction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - WmmaInstruction::Fill { frag, value } => { + Self::Fill { frag, value } => { writeln!(f, "nvcuda::wmma::fill_fragment({frag}, {value});") } - WmmaInstruction::Load { + Self::Load { frag, value, stride, @@ -109,7 +109,7 @@ impl Display for WmmaInstruction { f, "nvcuda::wmma::load_matrix_sync({frag}, {value}, {stride});" ), - WmmaInstruction::Load { + Self::Load { frag, value, stride, @@ -124,7 +124,7 @@ impl Display for WmmaInstruction { "nvcuda::wmma::load_matrix_sync({frag}, {value}, {stride}, {layout});" ) } - WmmaInstruction::Execute { + Self::Execute { frag_a, frag_b, frag_c, @@ -133,7 +133,7 @@ impl Display for WmmaInstruction { f, "nvcuda::wmma::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});" ), - WmmaInstruction::Store { + Self::Store { output, frag, stride, diff --git a/crates/cubecl-cpp/src/shared/warp.rs b/crates/cubecl-cpp/src/shared/warp.rs index a731c3eed..53e81c323 100644 --- a/crates/cubecl-cpp/src/shared/warp.rs +++ b/crates/cubecl-cpp/src/shared/warp.rs @@ -41,9 +41,9 @@ pub enum WarpInstruction { impl Display for WarpInstruction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - WarpInstruction::ReduceSum { input, out } => reduce_operator(f, input, out, "+="), - WarpInstruction::ReduceProd { input, out } => reduce_operator(f, input, out, "*="), - WarpInstruction::ReduceMax { input, out } => write!( + Self::ReduceSum { input, out } => reduce_operator(f, input, out, "+="), + Self::ReduceProd { input, out } => reduce_operator(f, input, out, "*="), + Self::ReduceMax { input, out } => write!( f, " {out} = {input}; @@ -54,7 +54,7 @@ for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{ }} " ), - WarpInstruction::ReduceMin { input, out } => write!( + Self::ReduceMin { input, out } => write!( f, " {out} = {input}; @@ -65,7 +65,7 @@ for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{ }} " ), - WarpInstruction::Elect { out } => write!( + Self::Elect { out } => write!( f, " unsigned int mask = __activemask(); @@ -73,7 +73,7 @@ unsigned int leader = __ffs(mask) - 1; {out} = threadIdx.x % warpSize == leader; " ), - WarpInstruction::All { input, out } => write!( + Self::All { input, out } => write!( f, " {out} = {input}; @@ -82,7 +82,7 @@ unsigned int leader = __ffs(mask) - 1; }} " ), - WarpInstruction::Any { input, out } => write!( + Self::Any { input, out } => write!( f, " {out} = {input}; @@ -91,7 +91,7 @@ unsigned int leader = __ffs(mask) - 1; }} " ), - WarpInstruction::Broadcast { input, id, out } => write!( + Self::Broadcast { input, id, out } => write!( f, " {out} = __shfl_sync(0xFFFFFFFF, {input}, {id}); diff --git a/crates/cubecl-linalg/src/matmul/cmma/config/strategy.rs b/crates/cubecl-linalg/src/matmul/cmma/config/strategy.rs index 795b6b9ce..95ae1bb8c 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/config/strategy.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/config/strategy.rs @@ -31,10 +31,8 @@ impl RasterizationStrategy { let cubes_for_cols = f32::ceil(num_cols as f32 / b_n as f32) as u32; match self { - RasterizationStrategy::RowMajor | RasterizationStrategy::Swizzle => { - (cubes_for_cols, cubes_for_rows) - } - RasterizationStrategy::ColMajor => (cubes_for_rows, cubes_for_cols), + Self::RowMajor | Self::Swizzle => (cubes_for_cols, cubes_for_rows), + Self::ColMajor => (cubes_for_rows, cubes_for_cols), } } } @@ -86,16 +84,16 @@ pub enum MainLoopStrategy { impl MainLoopStrategy { pub(crate) fn get_num_load_planes(&self, num_compute_planes: u32) -> u32 { match self { - MainLoopStrategy::Standard => num_compute_planes, - MainLoopStrategy::Split(num_load_planes) => *num_load_planes, + Self::Standard => num_compute_planes, + Self::Split(num_load_planes) => *num_load_planes, } } pub(crate) fn get_num_planes(&self, num_compute_planes: u32) -> u32 { num_compute_planes + match self { - MainLoopStrategy::Standard => 0, - MainLoopStrategy::Split(num_load) => *num_load, + Self::Standard => 0, + Self::Split(num_load) => *num_load, } } } @@ -145,8 +143,8 @@ pub enum NumComputePlanesStrategy { impl NumComputePlanesStrategy { pub(crate) fn get_num_compute_planes(&self, num_tiles_m: u32, num_tiles_k: u32) -> u32 { match self { - NumComputePlanesStrategy::NumTilesLhs => num_tiles_m * num_tiles_k, - NumComputePlanesStrategy::NumTilesM => num_tiles_m, + Self::NumTilesLhs => num_tiles_m * num_tiles_k, + Self::NumTilesM => num_tiles_m, } } diff --git a/crates/cubecl-linalg/src/matmul/cmma/launch.rs b/crates/cubecl-linalg/src/matmul/cmma/launch.rs index ec9ff5193..9cd4406e3 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/launch.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/launch.rs @@ -38,11 +38,7 @@ pub fn matmul_cmma_ref( ) { let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) { MatrixLayout::Contiguous => true, - MatrixLayout::MildlyPermuted { - transposed: _, - batch_swap: _, - } => false, - MatrixLayout::HighlyPermuted => false, + MatrixLayout::MildlyPermuted { .. } | MatrixLayout::HighlyPermuted => false, }; let lhs_correct_layout = check_layout(&lhs); diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs index 59356a41a..d05e35e9d 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/launch.rs @@ -39,11 +39,7 @@ pub fn matmul_tiling_2d_ref( "Shared memory limit will be busted. " ); let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) { - MatrixLayout::Contiguous => true, - MatrixLayout::MildlyPermuted { - transposed: _, - batch_swap: _, - } => true, + MatrixLayout::Contiguous | MatrixLayout::MildlyPermuted { .. } => true, MatrixLayout::HighlyPermuted => false, }; let lhs_correct_layout = check_layout(&lhs); @@ -91,10 +87,7 @@ fn matmul_tiling_2d_ref_no_check( let check_layout = |strides: &[usize]| match matrix_layout(strides) { MatrixLayout::Contiguous => false, - MatrixLayout::MildlyPermuted { - transposed, - batch_swap: _, - } => transposed, + MatrixLayout::MildlyPermuted { transposed, .. } => transposed, MatrixLayout::HighlyPermuted => { panic!("Can't run on highly permuted tensor") } diff --git a/crates/cubecl-macros/src/expression.rs b/crates/cubecl-macros/src/expression.rs index 847f717f4..c92f74d23 100644 --- a/crates/cubecl-macros/src/expression.rs +++ b/crates/cubecl-macros/src/expression.rs @@ -159,105 +159,111 @@ pub struct Block { impl Expression { pub fn ty(&self) -> Option { match self { - Expression::Binary { ty, .. } => ty.clone(), - Expression::Unary { ty, .. } => ty.clone(), - Expression::Variable(var) => var.ty.clone(), - Expression::Literal { ty, .. } => Some(ty.clone()), - Expression::Assignment { ty, .. } => ty.clone(), - Expression::Verbatim { .. } => None, - Expression::Block(block) => block.ty.clone(), - Expression::FunctionCall { .. } => None, - Expression::Break { .. } => None, - Expression::Cast { to, .. } => Some(to.clone()), - Expression::Continue { .. } => None, - Expression::ForLoop { .. } => None, - Expression::FieldAccess { .. } => None, - Expression::MethodCall { .. } => None, - Expression::Path { .. } => None, - Expression::Range { start, .. } => start.ty(), - Expression::Loop { .. } => None, - Expression::If { then_block, .. } => then_block.ty.clone(), - Expression::Switch { default, .. } => default.ty.clone(), - Expression::Return { expr, .. } => expr.as_ref().and_then(|expr| expr.ty()), - Expression::Array { .. } => None, - Expression::Index { .. } => None, - Expression::Tuple { .. } => None, - Expression::Slice { expr, .. } => expr.ty(), - Expression::ArrayInit { init, .. } => init.ty(), - Expression::VerbatimTerminated { .. } => None, - Expression::Reference { inner } => inner.ty(), - Expression::StructInit { .. } => None, - Expression::Closure { .. } => None, - Expression::Keyword { .. } => None, - Expression::CompilerIntrinsic { .. } => None, - Expression::ConstMatch { .. } => None, + Self::Binary { ty, .. } | Self::Unary { ty, .. } | Self::Assignment { ty, .. } => { + ty.clone() + } + Self::Variable(var) => var.ty.clone(), + Self::Literal { ty, .. } => Some(ty.clone()), + + Self::Block(block) => block.ty.clone(), + + Self::Cast { to, .. } => Some(to.clone()), + + Self::Range { start, .. } => start.ty(), + + Self::If { then_block, .. } => then_block.ty.clone(), + Self::Switch { default, .. } => default.ty.clone(), + Self::Return { expr, .. } => expr.as_ref().and_then(|expr| expr.ty()), + + Self::Slice { expr, .. } => expr.ty(), + Self::ArrayInit { init, .. } => init.ty(), + + Self::Reference { inner } => inner.ty(), + Self::Verbatim { .. } + | Self::FunctionCall { .. } + | Self::Break { .. } + | Self::Array { .. } + | Self::Index { .. } + | Self::Tuple { .. } + | Self::Continue { .. } + | Self::ForLoop { .. } + | Self::FieldAccess { .. } + | Self::MethodCall { .. } + | Self::Path { .. } + | Self::Loop { .. } + | Self::VerbatimTerminated { .. } + | Self::StructInit { .. } + | Self::Closure { .. } + | Self::Keyword { .. } + | Self::CompilerIntrinsic { .. } + | Self::ConstMatch { .. } => None, } } pub fn is_const(&self) -> bool { match self { - Expression::Literal { .. } => true, - Expression::Path { .. } => true, - Expression::Verbatim { .. } => true, - Expression::VerbatimTerminated { .. } => true, - Expression::Variable(var) => var.is_const, - Expression::FieldAccess { base, .. } => base.is_const(), - Expression::Reference { inner } => inner.is_const(), - Expression::Array { elements, .. } => elements.iter().all(|it| it.is_const()), - Expression::Tuple { elements, .. } => elements.iter().all(|it| it.is_const()), - Expression::CompilerIntrinsic { .. } => true, + Self::Literal { .. } + | Self::Path { .. } + | Self::Verbatim { .. } + | Self::VerbatimTerminated { .. } + | Self::CompilerIntrinsic { .. } => true, + Self::Variable(var) => var.is_const, + Self::FieldAccess { base, .. } => base.is_const(), + Self::Reference { inner } => inner.is_const(), + Self::Array { elements, .. } | Self::Tuple { elements, .. } => { + elements.iter().all(|it| it.is_const()) + } + _ => false, } } pub fn as_const(&self, context: &mut Context) -> Option { match self { - Expression::Literal { value, .. } => Some(quote![#value]), - Expression::Verbatim { tokens, .. } => Some(tokens.clone()), - Expression::VerbatimTerminated { tokens, .. } => Some(tokens.clone()), - Expression::Variable(ManagedVar { + Self::Literal { value, .. } => Some(quote![#value]), + Self::Verbatim { tokens, .. } => Some(tokens.clone()), + Self::VerbatimTerminated { tokens, .. } => Some(tokens.clone()), + Self::Variable(ManagedVar { name, is_const: true, .. }) => Some(quote![#name.clone()]), - Expression::Path { path, .. } => Some(quote![#path]), - Expression::Array { elements, .. } => { + Self::Path { path, .. } => Some(quote![#path]), + Self::Array { elements, .. } => { let elements = elements .iter() .map(|it| it.as_const(context)) .collect::>>()?; Some(quote![[#(#elements),*]]) } - Expression::Tuple { elements, .. } => { + Self::Tuple { elements, .. } => { let elements = elements .iter() .map(|it| it.as_const(context)) .collect::>>()?; Some(quote![(#(#elements),*)]) } - Expression::FieldAccess { base, field, .. } => { + Self::FieldAccess { base, field, .. } => { base.as_const(context).map(|base| quote![#base.#field]) } - Expression::Reference { inner } => inner.as_const(context).map(|base| quote![&#base]), - Expression::MethodCall { .. } if self.is_const() => Some(self.to_tokens(context)), + Self::Reference { inner } => inner.as_const(context).map(|base| quote![&#base]), + Self::MethodCall { .. } if self.is_const() => Some(self.to_tokens(context)), _ => None, } } pub fn as_index(&self) -> Option<(&Expression, &Expression)> { match self { - Expression::Index { expr, index, .. } => Some((&**expr, &**index)), + Self::Index { expr, index, .. } => Some((&**expr, &**index)), _ => None, } } pub fn needs_terminator(&self) -> bool { match self { - Expression::If { then_block, .. } => then_block.ret.is_some(), - Expression::Block(block) => block.ret.is_some(), - Expression::ForLoop { .. } => false, - Expression::Loop { .. } => false, - Expression::VerbatimTerminated { .. } => false, + Self::If { then_block, .. } => then_block.ret.is_some(), + Self::Block(block) => block.ret.is_some(), + Self::ForLoop { .. } | Self::Loop { .. } | Self::VerbatimTerminated { .. } => false, _ => true, } } diff --git a/crates/cubecl-macros/src/generate/cube_type/generate.rs b/crates/cubecl-macros/src/generate/cube_type/generate.rs index 6ed6c72c6..a54e9b067 100644 --- a/crates/cubecl-macros/src/generate/cube_type/generate.rs +++ b/crates/cubecl-macros/src/generate/cube_type/generate.rs @@ -5,8 +5,8 @@ use crate::parse::cube_type::CubeType; impl CubeType { pub fn generate(&self, with_launch: bool) -> TokenStream { match self { - CubeType::Enum(data) => data.generate(with_launch), - CubeType::Struct(data) => data.generate(with_launch), + Self::Enum(data) => data.generate(with_launch), + Self::Struct(data) => data.generate(with_launch), } } } diff --git a/crates/cubecl-macros/src/generate/expression.rs b/crates/cubecl-macros/src/generate/expression.rs index ae81a7cef..b7a6d4936 100644 --- a/crates/cubecl-macros/src/generate/expression.rs +++ b/crates/cubecl-macros/src/generate/expression.rs @@ -18,12 +18,12 @@ macro_rules! error { impl Expression { pub fn to_tokens(&self, context: &mut Context) -> TokenStream { match self { - Expression::Binary { + Self::Binary { left, operator, right, .. - } if operator.is_assign() && matches!(**left, Expression::Index { .. }) => { + } if operator.is_assign() && matches!(**left, Self::Index { .. }) => { let elem = frontend_type("ExpandElementTyped"); let frontend_path = frontend_path(); let (array, index) = left.as_index().unwrap(); @@ -46,7 +46,7 @@ impl Expression { } } } - Expression::Binary { + Self::Binary { left, operator, right, @@ -64,12 +64,12 @@ impl Expression { } } } - Expression::Unary { + Self::Unary { input, operator: Operator::Deref, .. } => input.to_tokens(context), - Expression::Unary { + Self::Unary { input, operator, .. } => { let frontend_path = frontend_path(); @@ -82,15 +82,15 @@ impl Expression { } } } - Expression::Keyword { name } => { + Self::Keyword { name } => { quote![#name::expand(context)] } - Expression::Variable(var) if var.is_const => { + Self::Variable(var) if var.is_const => { let name = &var.name; let expand_elem = frontend_type("ExpandElementTyped"); quote![#expand_elem::from_lit(#name)] } - Expression::Variable(var) => { + Self::Variable(var) => { let name = &var.name; if var.try_consume(context) { quote![#name] @@ -99,20 +99,18 @@ impl Expression { } } - Expression::FieldAccess { base, field, .. } => { + Self::FieldAccess { base, field, .. } => { let base = base .as_const(context) .unwrap_or_else(|| base.to_tokens(context)); quote![#base.#field.clone()] } - Expression::Literal { value, .. } => { + Self::Literal { value, .. } => { let expand_elem = frontend_type("ExpandElementTyped"); quote![#expand_elem::from_lit(#value)] } - Expression::Assignment { left, right, .. } - if matches!(**left, Expression::Index { .. }) => - { + Self::Assignment { left, right, .. } if matches!(**left, Self::Index { .. }) => { let (array, index) = left.as_index().unwrap(); let array = array.to_tokens(context); let index = index.to_tokens(context); @@ -127,7 +125,7 @@ impl Expression { } } } - Expression::Assignment { left, right, .. } => { + Self::Assignment { left, right, .. } => { let frontend_path = frontend_path(); let left = left.to_tokens(context); let right = right.to_tokens(context); @@ -139,7 +137,7 @@ impl Expression { } } } - Expression::Index { expr, index } => { + Self::Index { expr, index } => { let expr = expr.to_tokens(context); let index = index.to_tokens(context); let index_fn = frontend_type("index"); @@ -151,7 +149,7 @@ impl Expression { } } } - Expression::FunctionCall { + Self::FunctionCall { func, args, associated_type: None, @@ -166,7 +164,7 @@ impl Expression { } } } - Expression::CompilerIntrinsic { func, args } => { + Self::CompilerIntrinsic { func, args } => { let (args, arg_names) = map_args(args, context); let mut path = func.clone(); let generics = core::mem::replace( @@ -180,7 +178,7 @@ impl Expression { } } } - Expression::FunctionCall { + Self::FunctionCall { args, associated_type: Some((ty_path, func)), .. @@ -195,7 +193,7 @@ impl Expression { } } } - Expression::MethodCall { + Self::MethodCall { receiver, method, generics, @@ -214,19 +212,19 @@ impl Expression { } } } - Expression::Break => { + Self::Break => { let path = frontend_path(); quote![#path::branch::break_expand(context);] } - Expression::Continue(span) => error!(*span, "Continue not supported yet"), - Expression::Return { expr, span, .. } => { + Self::Continue(span) => error!(*span, "Continue not supported yet"), + Self::Return { expr, span, .. } => { if expr.is_some() { error!(*span, "Only void return is supported.") } else { quote![cubecl::frontend::branch::return_expand(context);] } } - Expression::Cast { from, to } => { + Self::Cast { from, to } => { let cast = prelude_type("Cast"); let from = from.to_tokens(context); let to = quote_spanned![to.span()=> <#to as #cast>]; @@ -235,7 +233,7 @@ impl Expression { #to::__expand_cast_from(context, __from) }} } - Expression::ForLoop { + Self::ForLoop { range, unroll, var_name, @@ -261,13 +259,13 @@ impl Expression { } } } - Expression::Loop { block, scope } => { + Self::Loop { block, scope } => { let loop_ty = frontend_type("branch"); let block = context.in_fn_mut(scope, |ctx| block.to_tokens(ctx)); quote![#loop_ty::loop_expand(context, |context| #block);] } - Expression::If { + Self::If { condition, then_block, else_branch, @@ -280,7 +278,7 @@ impl Expression { .map(|it| quote![else #it]); quote![if #as_const #then_block #else_branch] } - Expression::If { + Self::If { condition, then_block, else_branch: Some(else_branch), @@ -296,7 +294,7 @@ impl Expression { } } } - Expression::If { + Self::If { condition, then_block, else_branch: Some(else_branch), @@ -312,7 +310,7 @@ impl Expression { } } } - Expression::If { + Self::If { condition, then_block, .. @@ -327,7 +325,7 @@ impl Expression { } } } - Expression::Switch { + Self::Switch { value, cases, default, @@ -355,8 +353,8 @@ impl Expression { } } } - Expression::Path { path, .. } => quote![#path], - Expression::Range { + Self::Path { path, .. } => quote![#path], + Self::Range { start, end, inclusive, @@ -382,14 +380,14 @@ impl Expression { } } - Expression::Array { span, .. } => { + Self::Array { span, .. } => { if let Some(constant) = self.as_const(context) { constant } else { error!(*span, "Array expressions can't be used at runtime") } } - Expression::Tuple { elements, .. } => { + Self::Tuple { elements, .. } => { if let Some(constant) = self.as_const(context) { constant } else { @@ -398,18 +396,18 @@ impl Expression { } } - Expression::Slice { span, .. } => { + Self::Slice { span, .. } => { error!(*span, "Slice expressions not yet implemented") } - Expression::ArrayInit { init, len } => { + Self::ArrayInit { init, len } => { let init_ty = frontend_type("ArrayInit"); let init = init.to_tokens(context); let len = len.to_tokens(context); quote![#init_ty::new(#len, #init)] } - Expression::VerbatimTerminated { tokens } => tokens.clone(), - Expression::Reference { inner } => { + Self::VerbatimTerminated { tokens } => tokens.clone(), + Self::Reference { inner } => { if let Some(as_const) = inner.as_const(context) { quote![&#as_const] } else { @@ -417,7 +415,7 @@ impl Expression { quote![#inner] } } - Expression::StructInit { path, fields } => { + Self::StructInit { path, fields } => { let cube_type = prelude_type("CubeType"); let fields = init_fields(fields, context); let path_last = path.segments.last().unwrap(); @@ -443,7 +441,7 @@ impl Expression { } } } - Expression::Closure { + Self::Closure { params, body, scope, @@ -452,9 +450,9 @@ impl Expression { let body = context.in_fn_mut(scope, |ctx| body.to_tokens(ctx)); quote![|context, #(#params),*| #body] } - Expression::Verbatim { tokens, .. } => tokens.clone(), - Expression::Block(block) => block.to_tokens(context), - Expression::ConstMatch { const_expr, arms } => { + Self::Verbatim { tokens, .. } => tokens.clone(), + Self::Block(block) => block.to_tokens(context), + Self::ConstMatch { const_expr, arms } => { let arms = arms.iter().map(|arm| arm.to_tokens(context)); quote! { diff --git a/crates/cubecl-macros/src/generate/statement.rs b/crates/cubecl-macros/src/generate/statement.rs index 46a925143..cb2a45d72 100644 --- a/crates/cubecl-macros/src/generate/statement.rs +++ b/crates/cubecl-macros/src/generate/statement.rs @@ -7,7 +7,7 @@ use crate::{expression::Expression, paths::frontend_type, scope::Context, statem impl Statement { pub fn to_tokens(&self, context: &mut Context) -> TokenStream { match self { - Statement::Local { variable, init } => { + Self::Local { variable, init } => { let cube_type = frontend_type("CubeType"); let name = &variable.name; let is_mut = variable.is_mut || init.as_deref().map(is_mut_owned).unwrap_or(false); @@ -51,7 +51,7 @@ impl Statement { quote![let #mutable #name #ty;] } } - Statement::Expression { + Self::Expression { expression, terminated, } => { @@ -63,7 +63,7 @@ impl Statement { quote![#expression #terminator] } } - Statement::Skip => TokenStream::new(), + Self::Skip => TokenStream::new(), } } } diff --git a/crates/cubecl-macros/src/parse/autotune.rs b/crates/cubecl-macros/src/parse/autotune.rs index bfdf7435e..b3dbdce2e 100644 --- a/crates/cubecl-macros/src/parse/autotune.rs +++ b/crates/cubecl-macros/src/parse/autotune.rs @@ -43,8 +43,8 @@ pub enum Anchor { impl Anchor { pub fn max(&self) -> TokenStream { match self { - Anchor::Unlimited => quote![None], - Anchor::Max(value) => quote![Some(#value)], + Self::Unlimited => quote![None], + Self::Max(value) => quote![Some(#value)], } } } diff --git a/crates/cubecl-macros/src/parse/cube_impl.rs b/crates/cubecl-macros/src/parse/cube_impl.rs index 69278a0ae..9736358e9 100644 --- a/crates/cubecl-macros/src/parse/cube_impl.rs +++ b/crates/cubecl-macros/src/parse/cube_impl.rs @@ -64,21 +64,21 @@ impl CubeImplItem { pub fn as_func(&mut self) -> Option<&mut KernelFn> { match self { - CubeImplItem::Fn(func) => Some(func), + Self::Fn(func) => Some(func), _ => None, } } pub fn as_func_expand(&mut self) -> Option<&mut KernelFn> { match self { - CubeImplItem::FnExpand(func) => Some(func), + Self::FnExpand(func) => Some(func), _ => None, } } pub fn as_method_expand(&mut self) -> Option<&mut KernelFn> { match self { - CubeImplItem::MethodExpand(func) => Some(func), + Self::MethodExpand(func) => Some(func), _ => None, } } diff --git a/crates/cubecl-macros/src/parse/cube_trait.rs b/crates/cubecl-macros/src/parse/cube_trait.rs index cbd663559..10ee1275c 100644 --- a/crates/cubecl-macros/src/parse/cube_trait.rs +++ b/crates/cubecl-macros/src/parse/cube_trait.rs @@ -54,8 +54,8 @@ impl CubeTraitItem { pub fn func(&self) -> Option<&KernelSignature> { match self { - CubeTraitItem::Fn(func) => Some(func), - CubeTraitItem::Other => None, + Self::Fn(func) => Some(func), + Self::Other => None, } } } @@ -75,8 +75,8 @@ impl CubeTraitImplItem { pub fn func(&mut self) -> Option<&mut KernelFn> { match self { - CubeTraitImplItem::Fn(func) => Some(func), - CubeTraitImplItem::Other => None, + Self::Fn(func) => Some(func), + Self::Other => None, } } } diff --git a/crates/cubecl-macros/src/parse/kernel.rs b/crates/cubecl-macros/src/parse/kernel.rs index 6a4b62de0..f41c382fa 100644 --- a/crates/cubecl-macros/src/parse/kernel.rs +++ b/crates/cubecl-macros/src/parse/kernel.rs @@ -67,8 +67,7 @@ pub enum KernelReturns { impl KernelReturns { pub fn ty(&self) -> Type { match self { - KernelReturns::ExpandType(ty) => ty.clone(), - KernelReturns::Plain(ty) => ty.clone(), + Self::ExpandType(ty) | Self::Plain(ty) => ty.clone(), } } } diff --git a/crates/cubecl-opt/src/debug.rs b/crates/cubecl-opt/src/debug.rs index 2d45b7316..cb6904e7f 100644 --- a/crates/cubecl-opt/src/debug.rs +++ b/crates/cubecl-opt/src/debug.rs @@ -223,14 +223,14 @@ impl Display for ValueTable { impl Display for Value { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Value::Constant(constant) => write!(f, "{constant}"), - Value::Local(local) => write!(f, "{local}"), - Value::Input(id, _) => write!(f, "input({id})"), - Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"), - Value::ConstArray(id, _, _) => write!(f, "const_array({id})"), - Value::Builtin(builtin) => write!(f, "{builtin:?}"), - Value::Output(id, _) => write!(f, "output({id})"), - Value::Slice(id, depth, _) => write!(f, "slice({id}, {depth})"), + Self::Constant(constant) => write!(f, "{constant}"), + Self::Local(local) => write!(f, "{local}"), + Self::Input(id, _) => write!(f, "input({id})"), + Self::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"), + Self::ConstArray(id, _, _) => write!(f, "const_array({id})"), + Self::Builtin(builtin) => write!(f, "{builtin:?}"), + Self::Output(id, _) => write!(f, "output({id})"), + Self::Slice(id, depth, _) => write!(f, "slice({id}, {depth})"), } } } @@ -247,14 +247,14 @@ impl Display for Local { impl Display for Constant { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Constant::Int(val, IntKind::I32) => write!(f, "{val}i32"), - Constant::Int(val, IntKind::I64) => write!(f, "{val}i64"), - Constant::Float(val, FloatKind::BF16) => write!(f, "{}bf16", val.0), - Constant::Float(val, FloatKind::F16) => write!(f, "{}f16", val.0), - Constant::Float(val, FloatKind::F32) => write!(f, "{}f32", val.0), - Constant::Float(val, FloatKind::F64) => write!(f, "{}f64", val.0), - Constant::UInt(val) => write!(f, "{val}u32"), - Constant::Bool(val) => write!(f, "{val}"), + Self::Int(val, IntKind::I32) => write!(f, "{val}i32"), + Self::Int(val, IntKind::I64) => write!(f, "{val}i64"), + Self::Float(val, FloatKind::BF16) => write!(f, "{}bf16", val.0), + Self::Float(val, FloatKind::F16) => write!(f, "{}f16", val.0), + Self::Float(val, FloatKind::F32) => write!(f, "{}f32", val.0), + Self::Float(val, FloatKind::F64) => write!(f, "{}f64", val.0), + Self::UInt(val) => write!(f, "{val}u32"), + Self::Bool(val) => write!(f, "{val}"), } } } @@ -262,11 +262,11 @@ impl Display for Constant { impl Display for Expression { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Expression::Instruction(instruction) => write!(f, "{instruction}"), - Expression::Copy(val, _) => write!(f, "copy({val})"), - Expression::Value(value) => write!(f, "{value}"), - Expression::Volatile(value) => write!(f, "volatile({value})"), - Expression::Phi(entries) => write!( + Self::Instruction(instruction) => write!(f, "{instruction}"), + Self::Copy(val, _) => write!(f, "copy({val})"), + Self::Value(value) => write!(f, "{value}"), + Self::Volatile(value) => write!(f, "volatile({value})"), + Self::Phi(entries) => write!( f, "phi({})", entries diff --git a/crates/cubecl-opt/src/gvn/base.rs b/crates/cubecl-opt/src/gvn/base.rs index ec0122c1a..9581ca74e 100644 --- a/crates/cubecl-opt/src/gvn/base.rs +++ b/crates/cubecl-opt/src/gvn/base.rs @@ -124,24 +124,24 @@ pub enum Expression { impl Expression { pub fn depends_on(&self) -> SmallVec<[u32; 4]> { match self { - Expression::Instruction(instruction) => instruction.args.clone(), - Expression::Copy(val, _) => SmallVec::from_slice(&[*val]), - Expression::Phi(_) | Expression::Volatile(_) | Expression::Value(_) => SmallVec::new(), + Self::Instruction(instruction) => instruction.args.clone(), + Self::Copy(val, _) => SmallVec::from_slice(&[*val]), + Self::Phi(_) | Self::Volatile(_) | Self::Value(_) => SmallVec::new(), } } /// Whether the expression is a trivial copy (which does not need to be hoisted since it's free) pub fn is_simple(&self) -> bool { - matches!(self, Expression::Copy(_, _)) + matches!(self, Self::Copy(_, _)) } pub fn item(&self) -> Item { match self { - Expression::Instruction(instruction) => instruction.item, - Expression::Copy(_, item) => *item, - Expression::Value(value) => value.item(), - Expression::Volatile(value) => value.item(), - Expression::Phi(entries) => entries[0].0.item(), + Self::Instruction(instruction) => instruction.item, + Self::Copy(_, item) => *item, + Self::Value(value) => value.item(), + Self::Volatile(value) => value.item(), + Self::Phi(entries) => entries[0].0.item(), } } } @@ -149,14 +149,14 @@ impl Expression { impl Value { pub fn item(&self) -> Item { match self { - Value::Constant(constant) => constant.item(), - Value::Local(local) => local.item, - Value::Input(_, item) => *item, - Value::Scalar(_, elem) => Item::new(*elem), - Value::ConstArray(_, item, _) => *item, - Value::Builtin(_) => Item::new(Elem::UInt), - Value::Output(_, item) => *item, - Value::Slice(_, _, item) => *item, + Self::Constant(constant) => constant.item(), + Self::Local(local) => local.item, + Self::Scalar(_, elem) => Item::new(*elem), + Self::Builtin(_) => Item::new(Elem::UInt), + Self::Input(_, item) + | Self::ConstArray(_, item, _) + | Self::Output(_, item) + | Self::Slice(_, _, item) => *item, } } } @@ -170,7 +170,7 @@ impl Constant { impl From for Expression { fn from(value: Instruction) -> Self { - Expression::Instruction(value) + Self::Instruction(value) } } diff --git a/crates/cubecl-opt/src/gvn/convert.rs b/crates/cubecl-opt/src/gvn/convert.rs index 2b039539d..e3e6c87e0 100644 --- a/crates/cubecl-opt/src/gvn/convert.rs +++ b/crates/cubecl-opt/src/gvn/convert.rs @@ -12,18 +12,16 @@ use super::{Builtin, Constant, Expression, Local, OpId, Value}; impl Expression { pub fn to_operation(&self, leaders: &HashMap, out: Variable) -> Operation { match self { - Expression::Copy(val, _) => { + Self::Copy(val, _) => { let input = leaders[val].as_var(); Operator::Assign(UnaryOperator { input, out }).into() } - Expression::Value(value) | Expression::Volatile(value) => { - Operator::Assign(UnaryOperator { - input: value.as_var(), - out, - }) - .into() - } - Expression::Instruction(instruction) => { + Self::Value(value) | Self::Volatile(value) => Operator::Assign(UnaryOperator { + input: value.as_var(), + out, + }) + .into(), + Self::Instruction(instruction) => { let args = instruction .args .iter() @@ -310,7 +308,7 @@ impl Expression { .into(), } } - Expression::Phi(_) => unreachable!("Phi can't be made into operation"), + Self::Phi(_) => unreachable!("Phi can't be made into operation"), } } } @@ -318,8 +316,8 @@ impl Expression { impl Value { pub(crate) fn as_var(&self) -> Variable { match self { - Value::Constant(val) => Variable::ConstantScalar((*val).into()), - Value::Local(Local { + Self::Constant(val) => Variable::ConstantScalar((*val).into()), + Self::Local(Local { id, depth, version: 0, @@ -329,7 +327,7 @@ impl Value { item: *item, depth: *depth, }, - Value::Local(Local { + Self::Local(Local { id, depth, version, @@ -340,25 +338,25 @@ impl Value { depth: *depth, version: *version, }, - Value::Input(id, item) => Variable::GlobalInputArray { + Self::Input(id, item) => Variable::GlobalInputArray { id: *id, item: *item, }, - Value::Scalar(id, elem) => Variable::GlobalScalar { + Self::Scalar(id, elem) => Variable::GlobalScalar { id: *id, elem: *elem, }, - Value::ConstArray(id, item, len) => Variable::ConstantArray { + Self::ConstArray(id, item, len) => Variable::ConstantArray { id: *id, item: *item, length: *len, }, - Value::Builtin(builtin) => builtin.as_var(), - Value::Output(id, item) => Variable::GlobalOutputArray { + Self::Builtin(builtin) => builtin.as_var(), + Self::Output(id, item) => Variable::GlobalOutputArray { id: *id, item: *item, }, - Value::Slice(id, depth, item) => Variable::Slice { + Self::Slice(id, depth, item) => Variable::Slice { id: *id, item: *item, depth: *depth, @@ -370,28 +368,28 @@ impl Value { impl Builtin { pub fn as_var(&self) -> Variable { match self { - Builtin::Rank => Variable::Rank, - Builtin::UnitPos => Variable::UnitPos, - Builtin::UnitPosX => Variable::UnitPosX, - Builtin::UnitPosY => Variable::UnitPosY, - Builtin::UnitPosZ => Variable::UnitPosZ, - Builtin::CubePos => Variable::CubePos, - Builtin::CubePosX => Variable::CubePosX, - Builtin::CubePosY => Variable::CubePosY, - Builtin::CubePosZ => Variable::CubePosZ, - Builtin::CubeDim => Variable::CubeDim, - Builtin::CubeDimX => Variable::CubeDimX, - Builtin::CubeDimY => Variable::CubeDimY, - Builtin::CubeDimZ => Variable::CubeDimZ, - Builtin::CubeCount => Variable::CubeCount, - Builtin::CubeCountX => Variable::CubeCountX, - Builtin::CubeCountY => Variable::CubeCountY, - Builtin::CubeCountZ => Variable::CubeCountZ, - Builtin::SubcubeDim => Variable::SubcubeDim, - Builtin::AbsolutePos => Variable::AbsolutePos, - Builtin::AbsolutePosX => Variable::AbsolutePosX, - Builtin::AbsolutePosY => Variable::AbsolutePosY, - Builtin::AbsolutePosZ => Variable::AbsolutePosZ, + Self::Rank => Variable::Rank, + Self::UnitPos => Variable::UnitPos, + Self::UnitPosX => Variable::UnitPosX, + Self::UnitPosY => Variable::UnitPosY, + Self::UnitPosZ => Variable::UnitPosZ, + Self::CubePos => Variable::CubePos, + Self::CubePosX => Variable::CubePosX, + Self::CubePosY => Variable::CubePosY, + Self::CubePosZ => Variable::CubePosZ, + Self::CubeDim => Variable::CubeDim, + Self::CubeDimX => Variable::CubeDimX, + Self::CubeDimY => Variable::CubeDimY, + Self::CubeDimZ => Variable::CubeDimZ, + Self::CubeCount => Variable::CubeCount, + Self::CubeCountX => Variable::CubeCountX, + Self::CubeCountY => Variable::CubeCountY, + Self::CubeCountZ => Variable::CubeCountZ, + Self::SubcubeDim => Variable::SubcubeDim, + Self::AbsolutePos => Variable::AbsolutePos, + Self::AbsolutePosX => Variable::AbsolutePosX, + Self::AbsolutePosY => Variable::AbsolutePosY, + Self::AbsolutePosZ => Variable::AbsolutePosZ, } } } diff --git a/crates/cubecl-runtime/src/debug.rs b/crates/cubecl-runtime/src/debug.rs index 6ce373bcc..e41e77637 100644 --- a/crates/cubecl-runtime/src/debug.rs +++ b/crates/cubecl-runtime/src/debug.rs @@ -323,10 +323,10 @@ impl DebugLoggerKind { fn profile_level(&self) -> Option { let option = match self { #[cfg(feature = "std")] - DebugLoggerKind::File(_, option) => option, + Self::File(_, option) => option, #[cfg(feature = "std")] - DebugLoggerKind::Stdout(option) => option, - DebugLoggerKind::None => { + Self::Stdout(option) => option, + Self::None => { return None; } }; @@ -355,7 +355,7 @@ impl DebugLoggerKind { { match self { #[cfg(feature = "std")] - DebugLoggerKind::File(file, option) => { + Self::File(file, option) => { match option { DebugOptions::Debug | DebugOptions::All(_) => { file.log(&arg); @@ -365,7 +365,7 @@ impl DebugLoggerKind { arg } #[cfg(feature = "std")] - DebugLoggerKind::Stdout(option) => { + Self::Stdout(option) => { match option { DebugOptions::Debug | DebugOptions::All(_) => { println!("{arg}"); @@ -374,7 +374,7 @@ impl DebugLoggerKind { }; arg } - DebugLoggerKind::None => arg, + Self::None => arg, } } } diff --git a/crates/cubecl-runtime/src/memory_management/memory_manage.rs b/crates/cubecl-runtime/src/memory_management/memory_manage.rs index d2794825c..55d976351 100644 --- a/crates/cubecl-runtime/src/memory_management/memory_manage.rs +++ b/crates/cubecl-runtime/src/memory_management/memory_manage.rs @@ -44,8 +44,8 @@ const MB: usize = 1024 * 1024; impl MemoryPool for DynamicPool { fn get(&self, binding: &SliceBinding) -> Option<&StorageHandle> { match self { - DynamicPool::Sliced(m) => m.get(binding), - DynamicPool::Exclusive(m) => m.get(binding), + Self::Sliced(m) => m.get(binding), + Self::Exclusive(m) => m.get(binding), } } @@ -56,35 +56,35 @@ impl MemoryPool for DynamicPool { locked: Option<&MemoryLock>, ) -> SliceHandle { match self { - DynamicPool::Sliced(m) => m.reserve(storage, size, locked), - DynamicPool::Exclusive(m) => m.reserve(storage, size, locked), + Self::Sliced(m) => m.reserve(storage, size, locked), + Self::Exclusive(m) => m.reserve(storage, size, locked), } } fn alloc(&mut self, storage: &mut Storage, size: u64) -> SliceHandle { match self { - DynamicPool::Sliced(m) => m.alloc(storage, size), - DynamicPool::Exclusive(m) => m.alloc(storage, size), + Self::Sliced(m) => m.alloc(storage, size), + Self::Exclusive(m) => m.alloc(storage, size), } } fn get_memory_usage(&self) -> MemoryUsage { match self { - DynamicPool::Sliced(m) => m.get_memory_usage(), - DynamicPool::Exclusive(m) => m.get_memory_usage(), + Self::Sliced(m) => m.get_memory_usage(), + Self::Exclusive(m) => m.get_memory_usage(), } } fn max_alloc_size(&self) -> u64 { match self { - DynamicPool::Sliced(m) => m.max_alloc_size(), - DynamicPool::Exclusive(m) => m.max_alloc_size(), + Self::Sliced(m) => m.max_alloc_size(), + Self::Exclusive(m) => m.max_alloc_size(), } } fn cleanup(&mut self, storage: &mut Storage, alloc_nr: u64) { match self { - DynamicPool::Sliced(m) => m.cleanup(storage, alloc_nr), - DynamicPool::Exclusive(m) => m.cleanup(storage, alloc_nr), + Self::Sliced(m) => m.cleanup(storage, alloc_nr), + Self::Exclusive(m) => m.cleanup(storage, alloc_nr), } } } diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index 281c05e44..d061883ef 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -170,8 +170,8 @@ pub enum CubeCount { impl Debug for CubeCount { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")), - CubeCount::Dynamic(_) => f.write_str("binding"), + Self::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")), + Self::Dynamic(_) => f.write_str("binding"), } } } diff --git a/crates/cubecl-spirv/src/item.rs b/crates/cubecl-spirv/src/item.rs index 444707ed2..9c691abbc 100644 --- a/crates/cubecl-spirv/src/item.rs +++ b/crates/cubecl-spirv/src/item.rs @@ -82,25 +82,23 @@ impl Item { pub fn size(&self) -> u32 { match self { - Item::Scalar(elem) => elem.size(), - Item::Vector(elem, factor) => elem.size() * *factor, - Item::Array(item, len) => item.size() * *len, - Item::RuntimeArray(item) => item.size(), - Item::Struct(vec) => vec.iter().map(|it| it.size()).sum(), - Item::Pointer(_, item) => item.size(), - Item::CoopMatrix { ty, .. } => ty.size(), + Self::Scalar(elem) => elem.size(), + Self::Vector(elem, factor) => elem.size() * *factor, + Self::Array(item, len) => item.size() * *len, + Self::RuntimeArray(item) => item.size(), + Self::Struct(vec) => vec.iter().map(|it| it.size()).sum(), + Self::Pointer(_, item) => item.size(), + Self::CoopMatrix { ty, .. } => ty.size(), } } pub fn elem(&self) -> Elem { match self { - Item::Scalar(elem) => *elem, - Item::Vector(elem, _) => *elem, - Item::Array(item, _) => item.elem(), - Item::RuntimeArray(item) => item.elem(), - Item::Struct(_) => Elem::Void, - Item::Pointer(_, item) => item.elem(), - Item::CoopMatrix { ty, .. } => *ty, + Self::Scalar(elem) | Self::Vector(elem, _) => *elem, + Self::Array(item, _) | Self::Pointer(_, item) | Self::RuntimeArray(item) => item.elem(), + Self::Struct(_) => Elem::Void, + + Self::CoopMatrix { ty, .. } => *ty, } } @@ -109,22 +107,22 @@ impl Item { b.get_or_insert_const(value, self.clone(), |b| { let ty = self.id(b); match self { - Item::Scalar(_) => scalar, - Item::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)), - Item::Array(item, len) => { + Self::Scalar(_) => scalar, + Self::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)), + Self::Array(item, len) => { let elem = item.constant(b, value); b.constant_composite(ty, (0..*len).map(|_| elem)) } - Item::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"), - Item::Struct(elems) => { + Self::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"), + Self::Struct(elems) => { let items = elems .iter() .map(|item| item.constant(b, value)) .collect::>(); b.constant_composite(ty, items) } - Item::Pointer(_, _) => unimplemented!("Can't create constant pointer"), - Item::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"), + Self::Pointer(_, _) => unimplemented!("Can't create constant pointer"), + Self::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"), } }) } @@ -273,10 +271,10 @@ pub enum Elem { impl Elem { pub fn id(&self, b: &mut SpirvCompiler) -> Word { let id = match self { - Elem::Void => b.type_void(), - Elem::Bool => b.type_bool(), - Elem::Int(width, _) => b.type_int(*width, 0), - Elem::Float(width) => b.type_float(*width), + Self::Void => b.type_void(), + Self::Bool => b.type_bool(), + Self::Int(width, _) => b.type_int(*width, 0), + Self::Float(width) => b.type_float(*width), }; if b.debug && !b.state.debug_types.contains(&id) { b.debug_name(id, format!("{self}")); @@ -287,10 +285,10 @@ impl Elem { pub fn size(&self) -> u32 { match self { - Elem::Void => 0, - Elem::Bool => 1, - Elem::Int(size, _) => *size / 8, - Elem::Float(size) => *size / 8, + Self::Void => 0, + Self::Bool => 1, + Self::Int(size, _) => *size / 8, + Self::Float(size) => *size / 8, } } @@ -298,13 +296,13 @@ impl Elem { b.get_or_insert_const(value, Item::Scalar(*self), |b| { let ty = self.id(b); match self { - Elem::Void => unreachable!(), - Elem::Bool if value.as_u64() == 1 => b.constant_true(ty), - Elem::Bool => b.constant_false(ty), - Elem::Int(64, _) => b.constant_bit64(ty, value.as_u64()), - Elem::Int(_, _) => b.constant_bit32(ty, value.as_u32()), - Elem::Float(64) => b.constant_bit64(ty, value.as_u64()), - Elem::Float(_) => b.constant_bit32(ty, value.as_u32()), + Self::Void => unreachable!(), + Self::Bool if value.as_u64() == 1 => b.constant_true(ty), + Self::Bool => b.constant_false(ty), + Self::Int(64, _) => b.constant_bit64(ty, value.as_u64()), + Self::Int(_, _) => b.constant_bit32(ty, value.as_u32()), + Self::Float(64) => b.constant_bit64(ty, value.as_u64()), + Self::Float(_) => b.constant_bit32(ty, value.as_u32()), } }) } @@ -501,19 +499,19 @@ impl SpirvCompiler { impl std::fmt::Display for Item { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Item::Scalar(elem) => write!(f, "{elem}"), - Item::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"), - Item::Array(item, len) => write!(f, "array<{item}, {len}>"), - Item::RuntimeArray(item) => write!(f, "array<{item}>"), - Item::Struct(members) => { + Self::Scalar(elem) => write!(f, "{elem}"), + Self::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"), + Self::Array(item, len) => write!(f, "array<{item}, {len}>"), + Self::RuntimeArray(item) => write!(f, "array<{item}>"), + Self::Struct(members) => { write!(f, "struct<")?; for item in members { write!(f, "{item}")?; } f.write_str(">") } - Item::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"), - Item::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"), + Self::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"), + Self::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"), } } } diff --git a/crates/cubecl-spirv/src/variable.rs b/crates/cubecl-spirv/src/variable.rs index 227d72183..2f5cadf4a 100644 --- a/crates/cubecl-spirv/src/variable.rs +++ b/crates/cubecl-spirv/src/variable.rs @@ -83,15 +83,15 @@ pub enum ConstVal { impl ConstVal { pub fn as_u64(&self) -> u64 { match self { - ConstVal::Bit32(val) => *val as u64, - ConstVal::Bit64(val) => *val, + Self::Bit32(val) => *val as u64, + Self::Bit64(val) => *val, } } pub fn as_u32(&self) -> u32 { match self { - ConstVal::Bit32(val) => *val, - ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"), + Self::Bit32(val) => *val, + Self::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"), } } } @@ -101,23 +101,23 @@ impl From for ConstVal { unsafe { match value { ConstantScalarValue::Int(val, IntKind::I32) => { - ConstVal::Bit32(transmute::(val as i32)) + Self::Bit32(transmute::(val as i32)) } ConstantScalarValue::Int(val, IntKind::I64) => { - ConstVal::Bit64(transmute::(val)) + Self::Bit64(transmute::(val)) } - ConstantScalarValue::Float(val, FloatKind::F64) => ConstVal::Bit64(val.to_bits()), + ConstantScalarValue::Float(val, FloatKind::F64) => Self::Bit64(val.to_bits()), ConstantScalarValue::Float(val, FloatKind::F32) => { - ConstVal::Bit32((val as f32).to_bits()) + Self::Bit32((val as f32).to_bits()) } ConstantScalarValue::Float(val, FloatKind::F16) => { - ConstVal::Bit32((val as f32).to_bits()) + Self::Bit32((val as f32).to_bits()) } ConstantScalarValue::Float(_, FloatKind::BF16) => { panic!("bf16 not supported in SPIR-V") } - ConstantScalarValue::UInt(val) => ConstVal::Bit32(val as u32), - ConstantScalarValue::Bool(val) => ConstVal::Bit32(val as u32), + ConstantScalarValue::UInt(val) => Self::Bit32(val as u32), + ConstantScalarValue::Bool(val) => Self::Bit32(val as u32), } } } @@ -125,89 +125,89 @@ impl From for ConstVal { impl From for ConstVal { fn from(value: u32) -> Self { - ConstVal::Bit32(value) + Self::Bit32(value) } } impl From for ConstVal { fn from(value: f32) -> Self { - ConstVal::Bit32(value.to_bits()) + Self::Bit32(value.to_bits()) } } impl Variable { pub fn id(&self, b: &mut SpirvCompiler) -> Word { match self { - Variable::GlobalInputArray(id, _) => *id, - Variable::GlobalOutputArray(id, _) => *id, - Variable::GlobalScalar(id, _) => *id, - Variable::ConstantScalar(id, _, _) => *id, - Variable::Local { id, .. } => *id, - Variable::Versioned { id, .. } => b.get_versioned(*id), - Variable::LocalBinding { id, .. } => b.get_binding(*id), - Variable::Raw(id, _) => *id, - Variable::Named { id, .. } => *id, - Variable::Slice { ptr, .. } => ptr.id(b), - Variable::SharedMemory(id, _, _) => *id, - Variable::ConstantArray(id, _, _) => *id, - Variable::LocalArray(id, _, _) => *id, - Variable::CoopMatrix(_, _, _) => unimplemented!("Can't get ID from matrix var"), - Variable::SubgroupSize(id) => *id, - Variable::Id(id) => *id, - Variable::LocalInvocationIndex(id) => *id, - Variable::LocalInvocationIdX(id) => *id, - Variable::LocalInvocationIdY(id) => *id, - Variable::LocalInvocationIdZ(id) => *id, - Variable::Rank(id) => *id, - Variable::WorkgroupId(id) => *id, - Variable::WorkgroupIdX(id) => *id, - Variable::WorkgroupIdY(id) => *id, - Variable::WorkgroupIdZ(id) => *id, - Variable::GlobalInvocationIndex(id) => *id, - Variable::GlobalInvocationIdX(id) => *id, - Variable::GlobalInvocationIdY(id) => *id, - Variable::GlobalInvocationIdZ(id) => *id, - Variable::WorkgroupSize(id) => *id, - Variable::WorkgroupSizeX(id) => *id, - Variable::WorkgroupSizeY(id) => *id, - Variable::WorkgroupSizeZ(id) => *id, - Variable::NumWorkgroups(id) => *id, - Variable::NumWorkgroupsX(id) => *id, - Variable::NumWorkgroupsY(id) => *id, - Variable::NumWorkgroupsZ(id) => *id, + Self::Versioned { id, .. } => b.get_versioned(*id), + Self::LocalBinding { id, .. } => b.get_binding(*id), + Self::Slice { ptr, .. } => ptr.id(b), + Self::Raw(id, _) + | Self::Named { id, .. } + | Self::GlobalInputArray(id, _) + | Self::GlobalOutputArray(id, _) + | Self::GlobalScalar(id, _) + | Self::ConstantScalar(id, _, _) + | Self::Local { id, .. } + | Self::SharedMemory(id, _, _) + | Self::ConstantArray(id, _, _) + | Self::LocalArray(id, _, _) + | Self::SubgroupSize(id) + | Self::Id(id) + | Self::LocalInvocationIndex(id) + | Self::LocalInvocationIdX(id) + | Self::LocalInvocationIdY(id) + | Self::LocalInvocationIdZ(id) + | Self::Rank(id) + | Self::WorkgroupId(id) + | Self::WorkgroupIdX(id) + | Self::WorkgroupIdY(id) + | Self::WorkgroupIdZ(id) + | Self::GlobalInvocationIndex(id) + | Self::GlobalInvocationIdX(id) + | Self::GlobalInvocationIdY(id) + | Self::GlobalInvocationIdZ(id) + | Self::WorkgroupSize(id) + | Self::WorkgroupSizeX(id) + | Self::WorkgroupSizeY(id) + | Self::WorkgroupSizeZ(id) + | Self::NumWorkgroups(id) + | Self::NumWorkgroupsX(id) + | Self::NumWorkgroupsY(id) + | Self::NumWorkgroupsZ(id) => *id, + Self::CoopMatrix(_, _, _) => unimplemented!("Can't get ID from matrix var"), } } pub fn item(&self) -> Item { match self { - Variable::GlobalInputArray(_, item) => item.clone(), - Variable::GlobalOutputArray(_, item) => item.clone(), - Variable::GlobalScalar(_, elem) => Item::Scalar(*elem), - Variable::ConstantScalar(_, _, elem) => Item::Scalar(*elem), - Variable::Local { item, .. } => item.clone(), - Variable::Versioned { item, .. } => item.clone(), - Variable::LocalBinding { item, .. } => item.clone(), - Variable::Named { item, .. } => item.clone(), - Variable::Slice { item, .. } => item.clone(), - Variable::SharedMemory(_, item, _) => item.clone(), - Variable::ConstantArray(_, item, _) => item.clone(), - Variable::LocalArray(_, item, _) => item.clone(), - Variable::CoopMatrix(_, _, elem) => Item::Scalar(*elem), + Self::GlobalScalar(_, elem) + | Self::ConstantScalar(_, _, elem) + | Self::CoopMatrix(_, _, elem) => Item::Scalar(*elem), + Self::GlobalInputArray(_, item) + | Self::GlobalOutputArray(_, item) + | Self::Local { item, .. } + | Self::Versioned { item, .. } + | Self::LocalBinding { item, .. } + | Self::Named { item, .. } + | Self::Slice { item, .. } + | Self::SharedMemory(_, item, _) + | Self::ConstantArray(_, item, _) + | Self::LocalArray(_, item, _) => item.clone(), _ => Item::Scalar(Elem::Int(32, false)), // builtin } } pub fn indexed_item(&self) -> Item { match self { - Variable::LocalBinding { + Self::LocalBinding { item: Item::Vector(elem, _), .. - } => Item::Scalar(*elem), - Variable::Local { + } + | Self::Local { item: Item::Vector(elem, _), .. - } => Item::Scalar(*elem), - Variable::Versioned { + } + | Self::Versioned { item: Item::Vector(elem, _), .. } => Item::Scalar(*elem), @@ -222,31 +222,31 @@ impl Variable { pub fn has_len(&self) -> bool { matches!( self, - Variable::GlobalInputArray(_, _) - | Variable::GlobalOutputArray(_, _) - | Variable::Named { + Self::GlobalInputArray(_, _) + | Self::GlobalOutputArray(_, _) + | Self::Named { is_array: false, .. } - | Variable::Slice { .. } - | Variable::SharedMemory(_, _, _) - | Variable::ConstantArray(_, _, _) - | Variable::LocalArray(_, _, _) + | Self::Slice { .. } + | Self::SharedMemory(_, _, _) + | Self::ConstantArray(_, _, _) + | Self::LocalArray(_, _, _) ) } pub fn as_const(&self) -> Option { - match self { - Self::ConstantScalar(_, val, _) => Some(*val), - _ => None, - } + let Self::ConstantScalar(_, val, _) = self else { + return None; + }; + Some(*val) } pub fn as_binding(&self) -> Option<(u16, u8)> { - match self { - Self::LocalBinding { id, .. } => Some(*id), - _ => None, - } + let Self::LocalBinding { id, .. } = self else { + return None; + }; + Some(*id) } } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/base.rs b/crates/cubecl-wgpu/src/compiler/wgsl/base.rs index d59964dbd..414076b77 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/base.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/base.rs @@ -84,42 +84,34 @@ pub struct IndexedVariable { impl Variable { pub fn is_always_scalar(&self) -> bool { - match self { - Variable::GlobalScalar(_, _, _) => true, - Variable::ConstantScalar(_, _) => true, - Variable::LocalScalar { .. } => true, - Variable::Id => true, - Variable::LocalInvocationIndex => true, - Variable::LocalInvocationIdX => true, - Variable::LocalInvocationIdY => true, - Variable::LocalInvocationIdZ => true, - Variable::Rank => true, - Variable::GlobalInputArray(_, _) => false, - Variable::GlobalOutputArray(_, _) => false, - Variable::SharedMemory(_, _, _) => false, - Variable::ConstantArray(_, _, _) => false, - Variable::LocalArray(_, _, _, _) => false, - Variable::Local { .. } => false, - Variable::LocalBinding { .. } => false, - Variable::Named { .. } => false, - Variable::Slice { .. } => false, - Variable::WorkgroupIdX => true, - Variable::WorkgroupIdY => true, - Variable::WorkgroupIdZ => true, - Variable::GlobalInvocationIdX => true, - Variable::GlobalInvocationIdY => true, - Variable::GlobalInvocationIdZ => true, - Variable::WorkgroupSizeX => true, - Variable::WorkgroupSizeY => true, - Variable::WorkgroupSizeZ => true, - Variable::NumWorkgroupsX => true, - Variable::NumWorkgroupsY => true, - Variable::NumWorkgroupsZ => true, - Variable::WorkgroupId => true, - Variable::WorkgroupSize => true, - Variable::NumWorkgroups => true, - Variable::SubgroupSize => true, - } + matches!( + self, + Self::GlobalScalar(_, _, _) + | Self::ConstantScalar(_, _) + | Self::LocalScalar { .. } + | Self::Id + | Self::LocalInvocationIndex + | Self::LocalInvocationIdX + | Self::LocalInvocationIdY + | Self::LocalInvocationIdZ + | Self::Rank + | Self::WorkgroupIdX + | Self::WorkgroupIdY + | Self::WorkgroupIdZ + | Self::GlobalInvocationIdX + | Self::GlobalInvocationIdY + | Self::GlobalInvocationIdZ + | Self::WorkgroupSizeX + | Self::WorkgroupSizeY + | Self::WorkgroupSizeZ + | Self::NumWorkgroupsX + | Self::NumWorkgroupsY + | Self::NumWorkgroupsZ + | Self::WorkgroupId + | Self::WorkgroupSize + | Self::NumWorkgroups + | Self::SubgroupSize + ) } pub fn index(&self, index: usize) -> IndexedVariable { IndexedVariable { @@ -129,55 +121,53 @@ impl Variable { } pub fn is_atomic(&self) -> bool { match self { - Variable::GlobalInputArray(_, item) => item.elem().is_atomic(), - Variable::GlobalOutputArray(_, item) => item.elem().is_atomic(), - Variable::GlobalScalar(_, elem, _) => elem.is_atomic(), - Variable::Local { item, .. } => item.elem().is_atomic(), - Variable::Named { item, .. } => item.elem().is_atomic(), - Variable::Slice { item, .. } => item.elem().is_atomic(), - Variable::LocalScalar { elem, .. } => elem.is_atomic(), - Variable::SharedMemory(_, item, _) => item.elem().is_atomic(), - Variable::LocalArray(_, item, _, _) => item.elem().is_atomic(), + Self::LocalScalar { elem, .. } | Self::GlobalScalar(_, elem, _) => elem.is_atomic(), + Self::GlobalInputArray(_, item) + | Self::GlobalOutputArray(_, item) + | Self::Local { item, .. } + | Self::Named { item, .. } + | Self::Slice { item, .. } + | Self::SharedMemory(_, item, _) + | Self::LocalArray(_, item, _, _) => item.elem().is_atomic(), _ => false, } } pub fn item(&self) -> Item { match self { - Self::GlobalInputArray(_, e) => *e, - Self::GlobalOutputArray(_, e) => *e, - Self::SharedMemory(_, e, _) => *e, - Self::ConstantArray(_, e, _) => *e, - Self::LocalArray(_, e, _, _) => *e, - Self::Local { item, .. } => *item, - Self::LocalBinding { item, .. } => *item, - Self::Slice { item, .. } => *item, - Self::Named { item, .. } => *item, - Self::ConstantScalar(_, e) => Item::Scalar(*e), - Self::GlobalScalar(_, e, _) => Item::Scalar(*e), - Self::Id => Item::Scalar(Elem::U32), - Self::LocalInvocationIndex => Item::Scalar(Elem::U32), - Self::LocalInvocationIdX => Item::Scalar(Elem::U32), - Self::LocalInvocationIdY => Item::Scalar(Elem::U32), - Self::LocalInvocationIdZ => Item::Scalar(Elem::U32), - Self::Rank => Item::Scalar(Elem::U32), + Self::GlobalInputArray(_, e) + | Self::GlobalOutputArray(_, e) + | Self::SharedMemory(_, e, _) + | Self::ConstantArray(_, e, _) + | Self::LocalArray(_, e, _, _) => *e, + Self::Local { item, .. } + | Self::LocalBinding { item, .. } + | Self::Slice { item, .. } + | Self::Named { item, .. } => *item, + Self::ConstantScalar(_, e) | Self::GlobalScalar(_, e, _) => Item::Scalar(*e), Self::LocalScalar { elem, .. } => Item::Scalar(*elem), - Self::WorkgroupId => Item::Scalar(Elem::U32), - Self::WorkgroupIdX => Item::Scalar(Elem::U32), - Self::WorkgroupIdY => Item::Scalar(Elem::U32), - Self::WorkgroupIdZ => Item::Scalar(Elem::U32), - Self::GlobalInvocationIdX => Item::Scalar(Elem::U32), - Self::GlobalInvocationIdY => Item::Scalar(Elem::U32), - Self::GlobalInvocationIdZ => Item::Scalar(Elem::U32), - Self::WorkgroupSize => Item::Scalar(Elem::U32), - Self::WorkgroupSizeX => Item::Scalar(Elem::U32), - Self::WorkgroupSizeY => Item::Scalar(Elem::U32), - Self::WorkgroupSizeZ => Item::Scalar(Elem::U32), - Self::NumWorkgroups => Item::Scalar(Elem::U32), - Self::NumWorkgroupsX => Item::Scalar(Elem::U32), - Self::NumWorkgroupsY => Item::Scalar(Elem::U32), - Self::NumWorkgroupsZ => Item::Scalar(Elem::U32), - Self::SubgroupSize => Item::Scalar(Elem::U32), + Self::Id + | Self::LocalInvocationIndex + | Self::LocalInvocationIdX + | Self::LocalInvocationIdY + | Self::LocalInvocationIdZ + | Self::Rank + | Self::WorkgroupId + | Self::WorkgroupIdX + | Self::WorkgroupIdY + | Self::WorkgroupIdZ + | Self::GlobalInvocationIdX + | Self::GlobalInvocationIdY + | Self::GlobalInvocationIdZ + | Self::WorkgroupSize + | Self::WorkgroupSizeX + | Self::WorkgroupSizeY + | Self::WorkgroupSizeZ + | Self::NumWorkgroups + | Self::NumWorkgroupsX + | Self::NumWorkgroupsY + | Self::NumWorkgroupsZ + | Self::SubgroupSize => Item::Scalar(Elem::U32), } } pub fn elem(&self) -> Elem { @@ -196,10 +186,7 @@ impl Variable { impl Item { pub fn elem(&self) -> &Elem { match self { - Item::Vec4(e) => e, - Item::Vec3(e) => e, - Item::Vec2(e) => e, - Item::Scalar(e) => e, + Item::Vec4(e) | Item::Vec3(e) | Item::Vec2(e) | Item::Scalar(e) => e, } } @@ -254,10 +241,10 @@ impl Display for Elem { impl Display for Item { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Item::Vec4(elem) => write!(f, "vec4<{elem}>"), - Item::Vec3(elem) => write!(f, "vec3<{elem}>"), - Item::Vec2(elem) => write!(f, "vec2<{elem}>"), - Item::Scalar(elem) => write!(f, "{elem}"), + Self::Vec4(elem) => write!(f, "vec4<{elem}>"), + Self::Vec3(elem) => write!(f, "vec3<{elem}>"), + Self::Vec2(elem) => write!(f, "vec2<{elem}>"), + Self::Scalar(elem) => write!(f, "{elem}"), } } } @@ -271,35 +258,36 @@ fn format_number(num: f64) -> String { impl Display for Variable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Variable::GlobalInputArray(number, _) => { + Self::GlobalInputArray(number, _) => { write!(f, "input_{number}_global") } - Variable::LocalScalar { + Self::LocalScalar { id: index, depth: scope_depth, .. } => write!(f, "s_{scope_depth}_{index}"), - Variable::Local { + Self::Local { id: index, depth: scope_depth, .. } => write!(f, "l_{scope_depth}_{index}"), - Variable::LocalBinding { id: index, .. } => write!(f, "_{index}"), - Variable::Named { name, .. } => f.write_str(name), - Variable::Slice { + Self::LocalBinding { id: index, .. } => write!(f, "_{index}"), + Self::Named { name, .. } => f.write_str(name), + Self::Slice { id: index, - item: _, + depth: scope_depth, + .. } => write!(f, "slice_{scope_depth}_{index}"), - Variable::GlobalOutputArray(number, _) => { + Self::GlobalOutputArray(number, _) => { write!(f, "output_{number}_global") } - Variable::GlobalScalar(number, _, elem) => { + Self::GlobalScalar(number, _, elem) => { write!(f, "scalars_{elem}[{number}]") } // We do the conversion in Rust and then render the number to avoid overflow or other // precision related problems. - Variable::ConstantScalar(number, _elem) => match number { + Self::ConstantScalar(number, _elem) => match number { ConstantScalarValue::Int(val, kind) => match kind { IntKind::I32 => write!(f, "{}i", *val as i32), IntKind::I64 => write!(f, "{}i", { *val }), @@ -316,35 +304,35 @@ impl Display for Variable { ConstantScalarValue::UInt(val) => write!(f, "{}u", *val as u32), ConstantScalarValue::Bool(val) => write!(f, "{}", val), }, - Variable::SharedMemory(number, _, _) => { + Self::SharedMemory(number, _, _) => { write!(f, "shared_memory_{number}") } - Variable::ConstantArray(number, _, _) => write!(f, "arrays_{number}"), - Variable::LocalArray(number, _, scope_depth, _) => { + Self::ConstantArray(number, _, _) => write!(f, "arrays_{number}"), + Self::LocalArray(number, _, scope_depth, _) => { write!(f, "a_{scope_depth}_{number}") } - Variable::Id => f.write_str("id"), - Variable::LocalInvocationIndex => f.write_str("local_idx"), - Variable::LocalInvocationIdX => f.write_str("local_invocation_id.x"), - Variable::LocalInvocationIdY => f.write_str("local_invocation_id.y"), - Variable::LocalInvocationIdZ => f.write_str("local_invocation_id.z"), - Variable::Rank => f.write_str("rank"), - Variable::WorkgroupId => f.write_str("workgroup_id_no_axis"), - Variable::WorkgroupIdX => f.write_str("workgroup_id.x"), - Variable::WorkgroupIdY => f.write_str("workgroup_id.y"), - Variable::WorkgroupIdZ => f.write_str("workgroup_id.z"), - Variable::GlobalInvocationIdX => f.write_str("global_id.x"), - Variable::GlobalInvocationIdY => f.write_str("global_id.y"), - Variable::GlobalInvocationIdZ => f.write_str("global_id.z"), - Variable::WorkgroupSizeX => f.write_str("WORKGROUP_SIZE_X"), - Variable::WorkgroupSizeY => f.write_str("WORKGROUP_SIZE_Y"), - Variable::WorkgroupSizeZ => f.write_str("WORKGROUP_SIZE_Z"), - Variable::NumWorkgroupsX => f.write_str("num_workgroups.x"), - Variable::NumWorkgroupsY => f.write_str("num_workgroups.y"), - Variable::NumWorkgroupsZ => f.write_str("num_workgroups.z"), - Variable::WorkgroupSize => f.write_str("workgroup_size_no_axis"), - Variable::NumWorkgroups => f.write_str("num_workgroups_no_axis"), - Variable::SubgroupSize => f.write_str("subgroup_size"), + Self::Id => f.write_str("id"), + Self::LocalInvocationIndex => f.write_str("local_idx"), + Self::LocalInvocationIdX => f.write_str("local_invocation_id.x"), + Self::LocalInvocationIdY => f.write_str("local_invocation_id.y"), + Self::LocalInvocationIdZ => f.write_str("local_invocation_id.z"), + Self::Rank => f.write_str("rank"), + Self::WorkgroupId => f.write_str("workgroup_id_no_axis"), + Self::WorkgroupIdX => f.write_str("workgroup_id.x"), + Self::WorkgroupIdY => f.write_str("workgroup_id.y"), + Self::WorkgroupIdZ => f.write_str("workgroup_id.z"), + Self::GlobalInvocationIdX => f.write_str("global_id.x"), + Self::GlobalInvocationIdY => f.write_str("global_id.y"), + Self::GlobalInvocationIdZ => f.write_str("global_id.z"), + Self::WorkgroupSizeX => f.write_str("WORKGROUP_SIZE_X"), + Self::WorkgroupSizeY => f.write_str("WORKGROUP_SIZE_Y"), + Self::WorkgroupSizeZ => f.write_str("WORKGROUP_SIZE_Z"), + Self::NumWorkgroupsX => f.write_str("num_workgroups.x"), + Self::NumWorkgroupsY => f.write_str("num_workgroups.y"), + Self::NumWorkgroupsZ => f.write_str("num_workgroups.z"), + Self::WorkgroupSize => f.write_str("workgroup_size_no_axis"), + Self::NumWorkgroups => f.write_str("num_workgroups_no_axis"), + Self::SubgroupSize => f.write_str("subgroup_size"), } } } @@ -366,7 +354,7 @@ impl Display for IndexedVariable { impl Variable { pub fn fmt_left(&self) -> String { match self { - Variable::LocalBinding { id, .. } => { + Self::LocalBinding { id, .. } => { format!("let _{id}") } var => format!("{}", var), diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 208ae3e76..dc3802651 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -934,7 +934,7 @@ fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec { + wgsl::Instruction::Powf { rhs, out, .. } => { register_extension(wgsl::Extension::PowfPrimitive(out.item())); if rhs.is_always_scalar() || rhs.item().vectorization_factor() == 1 { @@ -943,17 +943,14 @@ fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec { + wgsl::Instruction::Erf { input, .. } => { register_extension(wgsl::Extension::Erf(input.item())); } #[cfg(target_os = "macos")] - wgsl::Instruction::Tanh { input, out: _ } => { + wgsl::Instruction::Tanh { input, .. } => { register_extension(wgsl::Extension::SafeTanh(input.item())) } - wgsl::Instruction::If { - cond: _, - instructions, - } => { + wgsl::Instruction::If { instructions, .. } => { for extension in register_extensions(instructions) { register_extension(extension); } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs b/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs index 5492c94f2..d91234657 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs @@ -15,12 +15,12 @@ pub enum Extension { impl Display for Extension { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Extension::PowfScalar(elem) => format_powf_scalar(f, elem), - Extension::PowfPrimitive(elem) => format_powf_primitive(f, elem), - Extension::Powf(elem) => format_powf(f, elem), - Extension::Erf(elem) => format_erf(f, elem), + Self::PowfScalar(elem) => format_powf_scalar(f, elem), + Self::PowfPrimitive(elem) => format_powf_primitive(f, elem), + Self::Powf(elem) => format_powf(f, elem), + Self::Erf(elem) => format_erf(f, elem), #[cfg(target_os = "macos")] - Extension::SafeTanh(elem) => format_safe_tanh(f, elem), + Self::SafeTanh(elem) => format_safe_tanh(f, elem), } } } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 3beffc54e..5b12475a6 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -359,11 +359,11 @@ pub enum Instruction { impl Display for Instruction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Instruction::DeclareVariable { var } => { + Self::DeclareVariable { var } => { let item = var.item(); writeln!(f, "var {var}: {item};") } - Instruction::Add { lhs, rhs, out } => { + Self::Add { lhs, rhs, out } => { if out.is_atomic() { assert_eq!(lhs, out, "Can't use regular addition on atomic"); writeln!(f, "atomicAdd({out}, {rhs});") @@ -372,7 +372,7 @@ impl Display for Instruction { writeln!(f, "{out} = {lhs} + {rhs};") } } - Instruction::Slice { + Self::Slice { input, start, end, @@ -382,11 +382,11 @@ impl Display for Instruction { writeln!(f, "let {out}_length = {end} - {start};")?; writeln!(f, "let {out}_ptr = &{input};") } - Instruction::Fma { a, b, c, out } => { + Self::Fma { a, b, c, out } => { let out = out.fmt_left(); writeln!(f, "{out} = fma({a}, {b}, {c});") } - Instruction::Min { lhs, rhs, out } => { + Self::Min { lhs, rhs, out } => { if out.is_atomic() { assert_eq!(lhs, out, "Can't use regular min on atomic"); writeln!(f, "atomicMin({out}, {rhs});") @@ -395,7 +395,7 @@ impl Display for Instruction { writeln!(f, "{out} = min({lhs}, {rhs});") } } - Instruction::Max { lhs, rhs, out } => { + Self::Max { lhs, rhs, out } => { if out.is_atomic() { assert_eq!(lhs, out, "Can't use regular max on atomic"); writeln!(f, "atomicMax({out}, {rhs});") @@ -404,7 +404,7 @@ impl Display for Instruction { writeln!(f, "{out} = max({lhs}, {rhs});") } } - Instruction::And { lhs, rhs, out } => { + Self::And { lhs, rhs, out } => { if out.is_atomic() { assert_eq!(lhs, out, "Can't use regular and on atomic"); writeln!(f, "atomicAnd({out}, {rhs});") @@ -413,7 +413,7 @@ impl Display for Instruction { writeln!(f, "{out} = {lhs} && {rhs};") } } - Instruction::Or { lhs, rhs, out } => { + Self::Or { lhs, rhs, out } => { if out.is_atomic() { assert_eq!(lhs, out, "Can't use regular or on atomic"); writeln!(f, "atomicOr({out}, {rhs});") @@ -422,11 +422,11 @@ impl Display for Instruction { writeln!(f, "{out} = {lhs} || {rhs};") } } - Instruction::Not { input, out } => { + Self::Not { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = !{input};") } - Instruction::Index { lhs, rhs, out } => match lhs { + Self::Index { lhs, rhs, out } => match lhs { Variable::Slice { item, .. } => { let offset = Variable::Named { name: format!("{lhs}_offset"), @@ -442,7 +442,7 @@ impl Display for Instruction { } _ => index(f, lhs, rhs, out, None), }, - Instruction::Copy { + Self::Copy { input, in_index, out, @@ -462,7 +462,7 @@ impl Display for Instruction { }; writeln!(f, "{lhs} = {rhs};") } - Instruction::CopyBulk { + Self::CopyBulk { input, in_index, out, @@ -486,11 +486,11 @@ impl Display for Instruction { } Ok(()) } - Instruction::Modulo { lhs, rhs, out } => { + Self::Modulo { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = {lhs} % {rhs};") } - Instruction::Remainder { lhs, rhs, out } => { + Self::Remainder { lhs, rhs, out } => { let f_type = match lhs.item() { Item::Vec4(_) => Item::Vec4(Elem::F32), Item::Vec3(_) => Item::Vec3(Elem::F32), @@ -504,7 +504,7 @@ impl Display for Instruction { let floor = f_type.fmt_cast_to(ty, format!("floor({lhs} / {rhs})")); writeln!(f, "{out} = {lhs} - {rhs} * {floor};") } - Instruction::Sub { lhs, rhs, out } => { + Self::Sub { lhs, rhs, out } => { if out.is_atomic() { assert_eq!(lhs, out, "Can't use regular sub on atomic"); writeln!(f, "atomicSub({out}, {rhs});") @@ -513,27 +513,27 @@ impl Display for Instruction { writeln!(f, "{out} = {lhs} - {rhs};") } } - Instruction::Mul { lhs, rhs, out } => { + Self::Mul { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = {lhs} * {rhs};") } - Instruction::Div { lhs, rhs, out } => { + Self::Div { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = {lhs} / {rhs};") } - Instruction::Abs { input, out } => { + Self::Abs { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = abs({input});") } - Instruction::Exp { input, out } => { + Self::Exp { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = exp({input});") } - Instruction::Log { input, out } => { + Self::Log { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = log({input});") } - Instruction::Clamp { + Self::Clamp { input, min_value, max_value, @@ -544,7 +544,7 @@ impl Display for Instruction { let out = out.fmt_left(); writeln!(f, "{out} = clamp({input}, {min}, {max});") } - Instruction::Powf { lhs, rhs, out } => { + Self::Powf { lhs, rhs, out } => { if rhs.is_always_scalar() || rhs.item().vectorization_factor() == 1 { let out = out.fmt_left(); writeln!(f, "{out} = powf_scalar({lhs}, {rhs});") @@ -553,23 +553,23 @@ impl Display for Instruction { writeln!(f, "{out} = powf({lhs}, {rhs});") } } - Instruction::Sqrt { input, out } => { + Self::Sqrt { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = sqrt({input});") } - Instruction::Log1p { input, out } => { + Self::Log1p { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = log({input} + 1.0);") } - Instruction::Cos { input, out } => { + Self::Cos { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = cos({input});") } - Instruction::Sin { input, out } => { + Self::Sin { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = sin({input});") } - Instruction::Tanh { input, out } => { + Self::Tanh { input, out } => { let out = out.fmt_left(); #[cfg(target_os = "macos")] let result = writeln!(f, "{out} = safe_tanh({input});"); @@ -578,21 +578,21 @@ impl Display for Instruction { result } - Instruction::Erf { input, out } => { + Self::Erf { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = erf({input});") } - Instruction::Recip { input, out } => { + Self::Recip { input, out } => { let out = out.fmt_left(); write!(f, "{out} = 1.0 / {input};") } - Instruction::Equal { lhs, rhs, out } => comparison(lhs, rhs, out, "==", f), - Instruction::Lower { lhs, rhs, out } => comparison(lhs, rhs, out, "<", f), - Instruction::Greater { lhs, rhs, out } => comparison(lhs, rhs, out, ">", f), - Instruction::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f), - Instruction::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f), - Instruction::NotEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "!=", f), - Instruction::Assign { input, out } => { + Self::Equal { lhs, rhs, out } => comparison(lhs, rhs, out, "==", f), + Self::Lower { lhs, rhs, out } => comparison(lhs, rhs, out, "<", f), + Self::Greater { lhs, rhs, out } => comparison(lhs, rhs, out, ">", f), + Self::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f), + Self::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f), + Self::NotEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "!=", f), + Self::Assign { input, out } => { let vec_left = out.item().vectorization_factor(); let vec_right = input.item().vectorization_factor(); if out.elem().is_atomic() { @@ -616,18 +616,18 @@ impl Display for Instruction { writeln!(f, "{out} = {input};") } } - Instruction::Stride { dim, position, out } => { + Self::Stride { dim, position, out } => { let out = out.fmt_left(); writeln!(f, "{out} = info[({position}u * rank_2) + {dim} + 1u];") } - Instruction::Shape { dim, position, out } => { + Self::Shape { dim, position, out } => { let out = out.fmt_left(); writeln!( f, "{out} = info[({position}u * rank_2) + rank + {dim} + 1u];" ) } - Instruction::RangeLoop { + Self::RangeLoop { i, start, end, @@ -654,7 +654,7 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ f.write_str("}\n") } - Instruction::IndexAssign { lhs, rhs, out } => { + Self::IndexAssign { lhs, rhs, out } => { if let Variable::Slice { item, .. } = out { let offset = Variable::Named { name: format!("{out}_offset"), @@ -672,14 +672,14 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ index_assign(f, lhs, rhs, out, None) } } - Instruction::If { cond, instructions } => { + Self::If { cond, instructions } => { writeln!(f, "if {cond} {{")?; for i in instructions { write!(f, "{i}")?; } f.write_str("}\n") } - Instruction::IfElse { + Self::IfElse { cond, instructions_if, instructions_else, @@ -694,7 +694,7 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ } f.write_str("}\n") } - Instruction::Select { + Self::Select { cond, then, or_else, @@ -724,7 +724,7 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ writeln!(f, "{out} = select({or_else}, {then}, {cond});") } } - Instruction::Switch { + Self::Switch { value, instructions_default, cases, @@ -743,102 +743,102 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ } f.write_str("}\n}\n") } - Instruction::Return => f.write_str("return;\n"), - Instruction::Break => f.write_str("break;\n"), - Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"), - Instruction::StorageBarrier => f.write_str("storageBarrier();\n"), - Instruction::Length { var, out } => { + Self::Return => f.write_str("return;\n"), + Self::Break => f.write_str("break;\n"), + Self::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"), + Self::StorageBarrier => f.write_str("storageBarrier();\n"), + Self::Length { var, out } => { let out = out.fmt_left(); match var { Variable::Slice { .. } => writeln!(f, "{out} = {var}_length;"), _ => writeln!(f, "{out} = arrayLength(&{var});"), } } - Instruction::Loop { instructions } => { + Self::Loop { instructions } => { writeln!(f, "loop {{")?; for i in instructions { write!(f, "{i}")?; } f.write_str("}\n") } - Instruction::BitwiseOr { lhs, rhs, out } => { + Self::BitwiseOr { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = {lhs} | {rhs};") } - Instruction::BitwiseAnd { lhs, rhs, out } => { + Self::BitwiseAnd { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = {lhs} & {rhs};") } - Instruction::BitwiseXor { lhs, rhs, out } => { + Self::BitwiseXor { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = {lhs} ^ {rhs};") } - Instruction::ShiftLeft { lhs, rhs, out } => { + Self::ShiftLeft { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = {lhs} << {rhs};") } - Instruction::ShiftRight { lhs, rhs, out } => { + Self::ShiftRight { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = {lhs} >> {rhs};") } - Instruction::Round { input, out } => { + Self::Round { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = round({input});") } - Instruction::Floor { input, out } => { + Self::Floor { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = floor({input});") } - Instruction::Ceil { input, out } => { + Self::Ceil { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = ceil({input});") } - Instruction::Subgroup(op) => write!(f, "{op}"), - Instruction::Bitcast { input, out } => { + Self::Subgroup(op) => write!(f, "{op}"), + Self::Bitcast { input, out } => { let elem = out.item(); let out = out.fmt_left(); writeln!(f, "{out} = bitcast<{elem}>({input});") } - Instruction::AtomicLoad { input, out } => { + Self::AtomicLoad { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = atomicLoad({input});") } - Instruction::AtomicStore { input, out } => { + Self::AtomicStore { input, out } => { writeln!(f, "atomicStore({out},{input});") } - Instruction::AtomicSwap { lhs, rhs, out } => { + Self::AtomicSwap { lhs, rhs, out } => { let out = out.fmt_left(); write!(f, "{out} = atomicExchange({lhs}, {rhs});") } - Instruction::AtomicAdd { lhs, rhs, out } => { + Self::AtomicAdd { lhs, rhs, out } => { let out = out.fmt_left(); write!(f, "{out} = atomicAdd({lhs}, {rhs});") } - Instruction::AtomicSub { lhs, rhs, out } => { + Self::AtomicSub { lhs, rhs, out } => { let out = out.fmt_left(); write!(f, "{out} = atomicSub({lhs}, {rhs});") } - Instruction::AtomicMax { lhs, rhs, out } => { + Self::AtomicMax { lhs, rhs, out } => { let out = out.fmt_left(); write!(f, "{out} = atomicMax({lhs}, {rhs});") } - Instruction::AtomicMin { lhs, rhs, out } => { + Self::AtomicMin { lhs, rhs, out } => { let out = out.fmt_left(); write!(f, "{out} = atomicMin({lhs}, {rhs});") } - Instruction::AtomicAnd { lhs, rhs, out } => { + Self::AtomicAnd { lhs, rhs, out } => { let out = out.fmt_left(); write!(f, "{out} = atomicAnd({lhs}, {rhs});") } - Instruction::AtomicOr { lhs, rhs, out } => { + Self::AtomicOr { lhs, rhs, out } => { let out = out.fmt_left(); write!(f, "{out} = atomicOr({lhs}, {rhs});") } - Instruction::AtomicXor { lhs, rhs, out } => { + Self::AtomicXor { lhs, rhs, out } => { let out = out.fmt_left(); write!(f, "{out} = atomicXor({lhs}, {rhs});") } - Instruction::AtomicCompareExchangeWeak { + Self::AtomicCompareExchangeWeak { lhs, cmp, value, @@ -851,15 +851,15 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ "{out} = atomicCompareExchangeWeak({lhs}, {cmp}, {value}).old_value;" ) } - Instruction::Negate { input, out } => { + Self::Negate { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = -{input};") } - Instruction::Magnitude { input, out } => { + Self::Magnitude { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = length({input});") } - Instruction::Normalize { input, out } => { + Self::Normalize { input, out } => { if input.item().vectorization_factor() == 1 { // We need a check for vectorization factor 1 here, for compatibility with cuda. // You can almost use sign here, however that does not correctly handle the case for x == 0.0. @@ -872,7 +872,7 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ writeln!(f, "{out} = normalize({input});") } } - Instruction::Dot { lhs, rhs, out } => { + Self::Dot { lhs, rhs, out } => { let out = out.fmt_left(); if lhs.item().vectorization_factor() == 1 { writeln!(f, "{out} = {lhs} * {rhs};") @@ -880,7 +880,7 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ writeln!(f, "{out} = dot({lhs}, {rhs});") } } - Instruction::VecInit { inputs, out } => { + Self::VecInit { inputs, out } => { let item = out.item(); let inputs = inputs.iter().map(|var| var.to_string()).collect::>(); let out = out.fmt_left(); diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs b/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs index 95a177049..a250b8995 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs @@ -250,8 +250,8 @@ var<{}, {}> {}: {}; impl Display for Location { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Location::Storage => f.write_str("storage"), - Location::Workgroup => f.write_str("workgroup"), + Self::Storage => f.write_str("storage"), + Self::Workgroup => f.write_str("workgroup"), } } } @@ -260,7 +260,7 @@ impl Display for Visibility { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { #[cfg(exclusive_memory_only)] - Visibility::Read => f.write_str("read"), + Self::Read => f.write_str("read"), _ => f.write_str("read_write"), } } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs b/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs index d9c8f474d..6b2b00a5c 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs @@ -41,32 +41,32 @@ pub enum Subgroup { impl Display for Subgroup { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Subgroup::Elect { out } => writeln!(f, "{out} = subgroupElect();"), - Subgroup::All { input, out } => { + Self::Elect { out } => writeln!(f, "{out} = subgroupElect();"), + Self::All { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = subgroupAll({input});") } - Subgroup::Any { input, out } => { + Self::Any { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = subgroupAny({input});") } - Subgroup::Broadcast { lhs, rhs, out } => { + Self::Broadcast { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = subgroupBroadcast({lhs}, {rhs});") } - Subgroup::Sum { input, out } => { + Self::Sum { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = subgroupAdd({input});") } - Subgroup::Prod { input, out } => { + Self::Prod { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = subgroupMul({input});") } - Subgroup::Min { input, out } => { + Self::Min { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = subgroupMin({input});") } - Subgroup::Max { input, out } => { + Self::Max { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = subgroupMax({input});") } diff --git a/crates/cubecl-wgpu/src/device.rs b/crates/cubecl-wgpu/src/device.rs index e09645e7c..519ead20e 100644 --- a/crates/cubecl-wgpu/src/device.rs +++ b/crates/cubecl-wgpu/src/device.rs @@ -10,7 +10,7 @@ /// let device_gpu_1 = WgpuDevice::DiscreteGpu(0); // First discrete GPU found. /// let device_gpu_2 = WgpuDevice::DiscreteGpu(1); // Second discrete GPU found. /// ``` -#[derive(Clone, Debug, Hash, PartialEq, Eq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, Default)] pub enum WgpuDevice { /// Discrete GPU with the given index. The index is the index of the discrete GPU in the list /// of all discrete GPUs found on the system. @@ -40,6 +40,7 @@ pub enum WgpuDevice { /// /// A device might be identified as [Other](wgpu::DeviceType::Other) by [wgpu](wgpu), in this case, we chose this device over /// `IntegratedGpu` since it's often a discrete GPU. + #[default] BestAvailable, /// Use an externally created, existing, wgpu setup. This is helpful when using CubeCL in conjunction @@ -48,9 +49,3 @@ pub enum WgpuDevice { /// The device is indexed by the global wgpu [adapter ID](wgpu::Device::global_id). Existing(wgpu::Id), } - -impl Default for WgpuDevice { - fn default() -> Self { - Self::BestAvailable - } -}