Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactoring #206

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions crates/cubecl-common/src/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ pub enum TimingMethod {
impl Display for TimingMethod {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
TimingMethod::Full => f.write_str("full"),
TimingMethod::DeviceOnly => f.write_str("device_only"),
Self::Full => f.write_str("full"),
Self::DeviceOnly => f.write_str("device_only"),
}
}
}
Expand Down
44 changes: 10 additions & 34 deletions crates/cubecl-core/src/codegen/integrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,8 @@ impl InputInfo {
#[allow(dead_code)]
pub fn item(&self) -> Item {
match self {
InputInfo::Array {
item,
visibility: _,
} => *item,
InputInfo::Scalar { elem, size: _ } => Item::new(*elem),
Self::Array { item, .. } => *item,
Self::Scalar { elem, .. } => Item::new(*elem),
}
}
}
Expand All @@ -234,18 +231,9 @@ impl OutputInfo {
#[allow(dead_code)]
pub fn item(&self) -> Item {
match self {
OutputInfo::ArrayWrite {
item,
local: _,
position: _,
} => *item,
OutputInfo::InputArrayWrite {
item,
input: _,
local: _,
position: _,
} => *item,
OutputInfo::Array { item } => *item,
Self::ArrayWrite { item, .. }
| Self::InputArrayWrite { item, .. }
| Self::Array { item } => *item,
}
}
}
Expand Down Expand Up @@ -278,18 +266,9 @@ impl OutputInfo {
#[allow(dead_code)]
pub fn elem_size<R: Runtime>(&self) -> usize {
let elem = match self {
OutputInfo::ArrayWrite {
item,
local: _,
position: _,
} => bool_elem(item.elem()),
OutputInfo::InputArrayWrite {
item,
input: _,
local: _,
position: _,
} => bool_elem(item.elem()),
OutputInfo::Array { item } => bool_elem(item.elem()),
Self::ArrayWrite { item, .. }
| Self::InputArrayWrite { item, .. }
| Self::Array { item } => bool_elem(item.elem()),
};
<R::Compiler as Compiler>::elem_size(elem)
}
Expand Down Expand Up @@ -464,18 +443,15 @@ impl KernelIntegrator {
let (item, local, position) = match output {
OutputInfo::ArrayWrite { item, local, position } => (item, local, position),
OutputInfo::InputArrayWrite {
item: _,
input,
local: _,
position: _,
input, ..
} => {
assert_eq!(
*input, mapping.pos_input as u16,
"Can't use different inputs for the same output."
);
return;
}
OutputInfo::Array { item: _ } => panic!("Can't register an inplace operation for an array that isn't using a defined writing strategy."),
OutputInfo::Array { .. } => panic!("Can't register an inplace operation for an array that isn't using a defined writing strategy."),
};

let item = match self.input_bindings.get_mut(mapping.pos_input) {
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/compute/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl KernelBuilder {
pub fn scalar(&mut self, elem: Elem) -> ExpandElement {
let index = match self.indices.get_mut(&elem) {
Some(index) => match self.inputs.get_mut(*index).unwrap() {
InputInfo::Scalar { elem: _, size } => {
InputInfo::Scalar { size, .. } => {
*size += 1;
*size as u16 - 1
}
Expand Down
20 changes: 10 additions & 10 deletions crates/cubecl-core/src/compute/launcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ pub enum ScalarState<T> {
impl<R: Runtime> TensorState<R> {
/// Push a new tensor to the state.
pub fn push(&mut self, tensor: &TensorHandleRef<'_, R>) {
if let TensorState::Empty = self {
*self = TensorState::Some {
if let Self::Empty = self {
*self = Self::Some {
bindings: Vec::with_capacity(1),
metadata: Vec::new(),
lengths: Vec::new(),
Expand All @@ -189,12 +189,12 @@ impl<R: Runtime> TensorState<R> {
};

let (bindings, metadata, lengths) = match self {
TensorState::Empty => panic!("Should be init"),
TensorState::Some {
Self::Empty => panic!("Should be init"),
Self::Some {
bindings,
metadata,
lengths,
runtime: _,
..
} => (bindings, metadata, lengths),
};

Expand Down Expand Up @@ -277,7 +277,7 @@ impl<R: Runtime> TensorState<R> {
bindings,
mut metadata,
lengths,
runtime: _,
..
} = self
{
if R::require_array_lengths() {
Expand All @@ -296,8 +296,8 @@ impl<T: NoUninit> ScalarState<T> {
/// Add a new scalar value to the state.
pub fn push(&mut self, val: T) {
match self {
ScalarState::Empty => *self = Self::Some(vec![val]),
ScalarState::Some(values) => values.push(val),
Self::Empty => *self = Self::Some(vec![val]),
Self::Some(values) => values.push(val),
}
}

Expand All @@ -307,8 +307,8 @@ impl<T: NoUninit> ScalarState<T> {
bindings: &mut Vec<Binding>,
) {
match self {
ScalarState::Empty => (),
ScalarState::Some(values) => {
Self::Empty => (),
Self::Some(values) => {
let handle = client.create(bytemuck::cast_slice(values));
bindings.push(handle.binding());
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub struct RangeExpand<I: Int> {

impl<I: Int> RangeExpand<I> {
pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
RangeExpand {
Self {
start,
end,
inclusive,
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/frontend/container/array/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ mod new {
/// Create a new array of the given length.
#[allow(unused_variables)]
pub fn new<L: Index>(length: L) -> Self {
Array { _val: PhantomData }
Self { _val: PhantomData }
}

/// Create an array from data.
pub fn from_data<C: CubePrimitive>(_data: impl IntoIterator<Item = C>) -> Self {
Array { _val: PhantomData }
Self { _val: PhantomData }
}

/// Expand function of [new](Array::new).
Expand Down
8 changes: 2 additions & 6 deletions crates/cubecl-core/src/frontend/container/array/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,7 @@ pub enum ArrayArg<'a, R: Runtime> {

impl<'a, R: Runtime> ArgSettings<R> for ArrayArg<'a, R> {
fn register(&self, launcher: &mut KernelLauncher<R>) {
if let ArrayArg::Handle {
handle,
vectorization_factor: _,
} = self
{
if let Self::Handle { handle, .. } = self {
launcher.register_array(handle)
}
}
Expand Down Expand Up @@ -129,8 +125,8 @@ impl<C: CubePrimitive> LaunchArg for Array<C> {
fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
match runtime_arg {
ArrayArg::Handle {
handle: _,
vectorization_factor,
..
} => ArrayCompilationArg {
inplace: None,
vectorisation: Vectorization::Some(NonZero::new(*vectorization_factor).unwrap()),
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/container/shared_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl<T: CubePrimitive> CubeType for SharedMemory<T> {

impl<T: CubePrimitive + Clone> SharedMemory<T> {
pub fn new<S: Index>(_size: S) -> Self {
SharedMemory { _val: PhantomData }
Self { _val: PhantomData }
}

pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
Expand Down
8 changes: 2 additions & 6 deletions crates/cubecl-core/src/frontend/container/tensor/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ impl<C: CubePrimitive> LaunchArg for Tensor<C> {
fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
match runtime_arg {
TensorArg::Handle {
handle: _,
vectorization_factor,
..
} => TensorCompilationArg {
inplace: None,
vectorisation: Vectorization::Some(NonZero::new(*vectorization_factor).unwrap()),
Expand Down Expand Up @@ -127,11 +127,7 @@ impl<'a, R: Runtime> TensorArg<'a, R> {

impl<'a, R: Runtime> ArgSettings<R> for TensorArg<'a, R> {
fn register(&self, launcher: &mut KernelLauncher<R>) {
if let TensorArg::Handle {
handle,
vectorization_factor: _,
} = self
{
if let Self::Handle { handle, .. } = self {
launcher.register_tensor(handle)
}
}
Expand Down
8 changes: 4 additions & 4 deletions crates/cubecl-core/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,14 +321,14 @@ impl ExpandElement {
/// If the element can be mutated inplace, potentially reusing the register.
pub fn can_mut(&self) -> bool {
match self {
ExpandElement::Managed(var) => {
Self::Managed(var) => {
if let Variable::Local { .. } = var.as_ref() {
Rc::strong_count(var) <= 2
} else {
false
}
}
ExpandElement::Plain(_) => false,
Self::Plain(_) => false,
}
}

Expand All @@ -343,8 +343,8 @@ impl core::ops::Deref for ExpandElement {

fn deref(&self) -> &Self::Target {
match self {
ExpandElement::Managed(var) => var.as_ref(),
ExpandElement::Plain(var) => var,
Self::Managed(var) => var.as_ref(),
Self::Plain(var) => var,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/frontend/element/vectorized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ impl<T: CubeType> Vectorized for &mut Tensor<T> {
impl Vectorized for ExpandElement {
fn vectorization_factor(&self) -> u32 {
let var = match self {
ExpandElement::Managed(var) => var,
ExpandElement::Plain(var) => var,
Self::Managed(var) => var,
Self::Plain(var) => var,
};

var.item().vectorization.map(|it| it.get()).unwrap_or(1) as u32
Expand Down
16 changes: 8 additions & 8 deletions crates/cubecl-core/src/ir/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ pub enum Branch {
impl Display for Branch {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Branch::If(if_) => write!(f, "if({})", if_.cond),
Branch::IfElse(if_else) => write!(f, "if({})", if_else.cond),
Branch::Select(select) => write!(
Self::If(if_) => write!(f, "if({})", if_.cond),
Self::IfElse(if_else) => write!(f, "if({})", if_else.cond),
Self::Select(select) => write!(
f,
"{} = select({}, {}, {})",
select.out, select.cond, select.then, select.or_else
),
Branch::Switch(switch) => write!(
Self::Switch(switch) => write!(
f,
"switch({}) {:?}",
switch.value,
Expand All @@ -45,17 +45,17 @@ impl Display for Branch {
.map(|case| format!("{}", case.0))
.collect::<Vec<_>>(),
),
Branch::RangeLoop(range_loop) => write!(
Self::RangeLoop(range_loop) => write!(
f,
"for({} in {}{}{})",
range_loop.i,
range_loop.start,
if range_loop.inclusive { "..=" } else { ".." },
range_loop.end
),
Branch::Loop(_) => write!(f, "loop{{}}"),
Branch::Return => write!(f, "return"),
Branch::Break => write!(f, "break"),
Self::Loop(_) => write!(f, "loop{{}}"),
Self::Return => write!(f, "return"),
Self::Break => write!(f, "break"),
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions crates/cubecl-core/src/ir/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ pub enum CoopMma {
impl Display for CoopMma {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CoopMma::Fill { mat, value } => write!(f, "matrix_fill({}, {})", mat, value),
CoopMma::Load {
Self::Fill { mat, value } => write!(f, "matrix_fill({}, {})", mat, value),
Self::Load {
mat,
value,
stride,
Expand All @@ -81,7 +81,7 @@ impl Display for CoopMma {
mat, value, stride
)
}
CoopMma::Execute {
Self::Execute {
mat_a,
mat_b,
mat_c,
Expand All @@ -91,7 +91,7 @@ impl Display for CoopMma {
"{} = execute_cmma({}, {}, {})",
mat_d, mat_a, mat_b, mat_c
),
CoopMma::Store {
Self::Store {
output,
mat,
stride,
Expand Down
Loading