From 651bb4ca1c55a46403e0e7ea655b27b279eb94be Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 11:52:47 -0500 Subject: [PATCH] Clippy + Fmt --- .../burn-jit/src/fusion/on_write/builder.rs | 1 + .../src/fusion/on_write/trace/base.rs | 36 +++++++++---------- .../src/fusion/on_write/trace/builder.rs | 34 ++++++++++-------- .../src/fusion/on_write/trace/executor.rs | 11 ++++-- .../on_write/trace/{inputs.rs => input.rs} | 9 +++-- .../burn-jit/src/fusion/on_write/trace/mod.rs | 4 +-- .../on_write/trace/{outputs.rs => output.rs} | 8 +++-- .../src/fusion/on_write/trace/plan.rs | 2 +- .../src/fusion/on_write/trace/runner.rs | 2 +- .../fusion/on_write/trace/vectorization.rs | 14 ++++---- 10 files changed, 67 insertions(+), 54 deletions(-) rename crates/burn-jit/src/fusion/on_write/trace/{inputs.rs => input.rs} (92%) rename crates/burn-jit/src/fusion/on_write/trace/{outputs.rs => output.rs} (98%) diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index bbfca0f3d2..ffb5dfcb79 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -512,6 +512,7 @@ impl FuseOnWriteBuilder { let mut updated = self.current_output_shape.clone(); + #[allow(clippy::needless_range_loop)] for i in 0..rank { let curr = self.current_output_shape[i]; let new = out.shape[i]; diff --git a/crates/burn-jit/src/fusion/on_write/trace/base.rs b/crates/burn-jit/src/fusion/on_write/trace/base.rs index c6298f4b4b..f017d6dc1b 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/base.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/base.rs @@ -6,8 +6,8 @@ use super::{ settings::FuseSettings, }, executor::LaunchPlanExecutor, - inputs::InputsPlanner, - outputs::OutputsPlanner, + input::InputPlanner, + output::OutputPlanner, vectorization::VectorizationPlanner, HandleInput, HandleOutput, LaunchPlan, TraceRunner, }; @@ -17,19 +17,19 @@ use cubecl::prelude::*; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; -#[derive(new, Clone, Serialize, Deserialize, Debug)] +#[derive(Clone, Serialize, Deserialize, Debug)] /// Trace containing all element wise operations as well as reads and writes. pub struct FuseOnWriteTrace { - outputs: RegisteredTensors, - inputs: RegisteredTensors, - settings: FuseSettings, - scalars: BTreeMap, - reshapes: Vec, - shape_ref: Vec, - ops: Vec, - reads: BTreeMap>, - writes: BTreeMap, - inputs_unhandled: Vec, + pub outputs: RegisteredTensors, + pub inputs: RegisteredTensors, + pub settings: FuseSettings, + pub scalars: BTreeMap, + pub reshapes: Vec, + pub shape_ref: Vec, + pub ops: Vec, + pub reads: BTreeMap>, + pub writes: BTreeMap, + pub inputs_unhandled: Vec, } #[derive(Clone, Serialize, Deserialize, Debug)] @@ -49,7 +49,7 @@ impl FuseOnWriteTrace { ) -> Result<(), Runner::Error> { let mut plan = LaunchPlan::new(&self.reads, &self.writes, self.shape_ref.len()); - InputsPlanner::::new( + InputPlanner::::new( &self.inputs, &self.inputs_unhandled, &self.reshapes, @@ -58,7 +58,7 @@ impl FuseOnWriteTrace { ) .run(context, &mut plan); - OutputsPlanner::::new(&self.inputs, &self.outputs, &self.reshapes) + OutputPlanner::::new(&self.inputs, &self.outputs, &self.reshapes) .run::(client, device, context, &mut plan); VectorizationPlanner::::new(&self.reshapes, &self.reads, &self.settings) @@ -67,9 +67,9 @@ impl FuseOnWriteTrace { match LaunchPlanExecutor::::new(&self.scalars, &self.reshapes, &self.ops) .execute::<_, BT>(client, runner, context, plan) { - Err((err, handle_inputs, handle_outputs)) => { - self.rollback(context, handle_inputs, handle_outputs); - Err(err) + Err(err) => { + self.rollback(context, err.handles_input, err.handles_output); + Err(err.runner_error) } Ok(val) => Ok(val), } diff --git a/crates/burn-jit/src/fusion/on_write/trace/builder.rs b/crates/burn-jit/src/fusion/on_write/trace/builder.rs index 8d30ec0283..896d7f3b1c 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/builder.rs @@ -102,8 +102,10 @@ impl FuseOnWriteTraceBuilder { let out = self.locals.create(precision, tensor.id); let input = Arg::Input(new_input, precision_input, LayoutInfo::Unknown); - let reads = if !self.reads.contains_key(&tensor.id) { - self.reads.insert(tensor.id, Vec::with_capacity(1)); + let reads = if let std::collections::btree_map::Entry::Vacant(e) = + self.reads.entry(tensor.id) + { + e.insert(Vec::with_capacity(1)); self.reads.get_mut(&tensor.id).unwrap() } else { self.reads.get_mut(&tensor.id).unwrap() @@ -180,8 +182,8 @@ impl FuseOnWriteTraceBuilder { let index = self.reshapes.len(); self.reshapes.push(Reshape { - reshaped: output.id.clone(), - original: tensor.id.clone(), + reshaped: output.id, + original: tensor.id, }); let rank = output.shape.len(); @@ -195,12 +197,13 @@ impl FuseOnWriteTraceBuilder { shape, }; - let reads = if !self.reads.contains_key(&tensor.id) { - self.reads.insert(tensor.id, Vec::with_capacity(1)); - self.reads.get_mut(&tensor.id).unwrap() - } else { - self.reads.get_mut(&tensor.id).unwrap() - }; + let reads = + if let std::collections::btree_map::Entry::Vacant(e) = self.reads.entry(tensor.id) { + e.insert(Vec::with_capacity(1)); + self.reads.get_mut(&tensor.id).unwrap() + } else { + self.reads.get_mut(&tensor.id).unwrap() + }; reads.push(ElemwiseOp::Assign(UnaryElemwiseArgs { input, @@ -226,7 +229,7 @@ impl FuseOnWriteTraceBuilder { Arg::Scalar(new_index, precision) } - pub fn build(&self, shape: Vec) -> FuseOnWriteTrace { + pub fn build(&self, shape_ref: Vec) -> FuseOnWriteTrace { let inputs = self.inputs.clone(); let outputs = self.output_tensors(); let ops = self.ops.clone(); @@ -250,19 +253,20 @@ impl FuseOnWriteTraceBuilder { let reshapes = self.reshapes.clone(); let settings = self.settings; + let inputs_unhandled = self.inputs_unhandled.clone(); - FuseOnWriteTrace::new( + FuseOnWriteTrace { outputs, inputs, settings, scalars, reshapes, - shape, + shape_ref, ops, reads, writes, - self.inputs_unhandled.clone(), - ) + inputs_unhandled, + } } fn output_tensors(&self) -> RegisteredTensors { diff --git a/crates/burn-jit/src/fusion/on_write/trace/executor.rs b/crates/burn-jit/src/fusion/on_write/trace/executor.rs index 9b1d719068..749e74340e 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/executor.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/executor.rs @@ -24,6 +24,13 @@ pub struct LaunchPlanExecutor<'a, R: JitRuntime> { _r: PhantomData, } +#[derive(new)] +pub struct ExecutionError> { + pub runner_error: Runner::Error, + pub handles_input: Vec>, + pub handles_output: Vec>, +} + impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { pub fn new( scalars: &'a BTreeMap, @@ -44,7 +51,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { runner: &Runner, context: &mut Context<'_, JitFusionHandle>, plan: LaunchPlan<'a, R>, - ) -> Result<(), (Runner::Error, Vec>, Vec>)> { + ) -> Result<(), ExecutionError> { let reference = match plan.reference { Some(reference) => reference, None => { @@ -83,7 +90,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { }; Runner::run(runner, client, inputs, outputs, &config) - .map_err(|err| (err, plan.handle_inputs, plan.handle_outputs)) + .map_err(|err| ExecutionError::new(err, plan.handle_inputs, plan.handle_outputs)) } fn register_inputs<'h>( diff --git a/crates/burn-jit/src/fusion/on_write/trace/inputs.rs b/crates/burn-jit/src/fusion/on_write/trace/input.rs similarity index 92% rename from crates/burn-jit/src/fusion/on_write/trace/inputs.rs rename to crates/burn-jit/src/fusion/on_write/trace/input.rs index 54791ec225..a243fddac1 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/inputs.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/input.rs @@ -11,7 +11,7 @@ use super::{HandleInput, LaunchPlan, PotentialInplace, RegisteredTensors}; /// Fetch and register [input handles](HandleInput) and itendify potential inputs that /// can be used inplace. -pub struct InputsPlanner<'a, R: JitRuntime> { +pub struct InputPlanner<'a, R: JitRuntime> { inputs: &'a RegisteredTensors, inputs_unhandled: &'a Vec, reshapes: &'a Vec, @@ -20,7 +20,7 @@ pub struct InputsPlanner<'a, R: JitRuntime> { _r: PhantomData, } -impl<'a, R: JitRuntime> InputsPlanner<'a, R> { +impl<'a, R: JitRuntime> InputPlanner<'a, R> { pub fn new( inputs: &'a RegisteredTensors, inputs_unhandled: &'a Vec, @@ -51,11 +51,10 @@ impl<'a, R: JitRuntime> InputsPlanner<'a, R> { && status == &TensorStatus::ReadWrite && handle.handle.can_mut() && !self.inputs_unhandled.contains(&tensor_relative.id) - && self + && !self .reshapes .iter() - .find(|r| r.reshaped == tensor_relative.id || r.original == tensor_relative.id) - .is_none() + .any(|r| r.reshaped == tensor_relative.id || r.original == tensor_relative.id) && self.shape_ref == &tensor_relative.shape { plan.potential_inplaces.push(PotentialInplace { diff --git a/crates/burn-jit/src/fusion/on_write/trace/mod.rs b/crates/burn-jit/src/fusion/on_write/trace/mod.rs index a3f0575299..64de887986 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/mod.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod executor; -pub(crate) mod inputs; -pub(crate) mod outputs; +pub(crate) mod input; +pub(crate) mod output; pub(crate) mod vectorization; mod base; diff --git a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs b/crates/burn-jit/src/fusion/on_write/trace/output.rs similarity index 98% rename from crates/burn-jit/src/fusion/on_write/trace/outputs.rs rename to crates/burn-jit/src/fusion/on_write/trace/output.rs index 66624142fb..0964974c7a 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/output.rs @@ -19,7 +19,7 @@ use std::collections::BTreeMap; /// Create or reuse handles for the outputs. /// /// It is also responsable to select the reference tensor. -pub struct OutputsPlanner<'a, R: JitRuntime> { +pub struct OutputPlanner<'a, R: JitRuntime> { inputs: &'a RegisteredTensors, reshapes: &'a Vec, outputs_sorted: Vec>, @@ -40,7 +40,7 @@ enum OutputKind { Reshaped { reshape: Reshape }, } -impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { +impl<'a, R: JitRuntime> OutputPlanner<'a, R> { pub fn new( inputs: &'a RegisteredTensors, outputs: &'a RegisteredTensors, @@ -235,6 +235,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { self.globals[output.pos_original] = Some(tensor_global); } + #[allow(clippy::too_many_arguments)] fn normal_output( &mut self, client: &ComputeClient, @@ -299,6 +300,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { self.globals[output.pos_original] = Some(tensor_global); } + #[allow(clippy::too_many_arguments)] fn reshaped_output( &mut self, client: &ComputeClient, @@ -377,7 +379,7 @@ impl OutputPositionMapper { /// Returns the right position from the precision and the global position in all outputs. pub fn resolve_index(&mut self, precision: &ElemwisePrecision, pos_handle: usize) -> u32 { self.map - .get(&precision) + .get(precision) .unwrap() .iter() .enumerate() diff --git a/crates/burn-jit/src/fusion/on_write/trace/plan.rs b/crates/burn-jit/src/fusion/on_write/trace/plan.rs index 2fc68ba2be..89a11a188c 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/plan.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/plan.rs @@ -25,7 +25,7 @@ pub(crate) struct LaunchPlan<'a, R: JitRuntime> { pub rank: usize, } -impl<'a, R: JitRuntime> LaunchPlan<'a, R> { +impl LaunchPlan<'_, R> { pub fn new( reads: &BTreeMap>, writes: &BTreeMap, diff --git a/crates/burn-jit/src/fusion/on_write/trace/runner.rs b/crates/burn-jit/src/fusion/on_write/trace/runner.rs index dc9e2a8f83..fc3109327d 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/runner.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/runner.rs @@ -113,7 +113,7 @@ fn vectorization_default<'a, R: JitRuntime>( let mut max_current = u8::MAX; for (handle, tensor) in handles_inputs.zip(inputs) { - match vectorization_input(&handle, tensor) { + match vectorization_input(handle, tensor) { Vect::Broadcated => vectorizations.insert(tensor.id, 1), Vect::Max(val) => { max_current = Ord::min(val, max_current); diff --git a/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs b/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs index e216b6e525..ff775e3327 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs @@ -70,13 +70,13 @@ impl<'a, R: JitRuntime> VectorizationPlanner<'a, R> { handle.vectorization = *plan.vectorization.get(&handle.global_id).unwrap(); } for handle in plan.handle_outputs.iter_mut() { - match handle { - HandleOutput::Owned { - vectorization, - global_id, - .. - } => *vectorization = *plan.vectorization.get(&global_id).unwrap(), - _ => {} + if let HandleOutput::Owned { + vectorization, + global_id, + .. + } = handle + { + *vectorization = *plan.vectorization.get(global_id).unwrap() } } }