Skip to content

Commit

Permalink
Clippy + Fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Feb 3, 2025
1 parent 5b25f18 commit 651bb4c
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 54 deletions.
1 change: 1 addition & 0 deletions crates/burn-jit/src/fusion/on_write/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,7 @@ impl FuseOnWriteBuilder {

let mut updated = self.current_output_shape.clone();

#[allow(clippy::needless_range_loop)]
for i in 0..rank {
let curr = self.current_output_shape[i];
let new = out.shape[i];
Expand Down
36 changes: 18 additions & 18 deletions crates/burn-jit/src/fusion/on_write/trace/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use super::{
settings::FuseSettings,
},
executor::LaunchPlanExecutor,
inputs::InputsPlanner,
outputs::OutputsPlanner,
input::InputPlanner,
output::OutputPlanner,
vectorization::VectorizationPlanner,
HandleInput, HandleOutput, LaunchPlan, TraceRunner,
};
Expand All @@ -17,19 +17,19 @@ use cubecl::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;

#[derive(new, Clone, Serialize, Deserialize, Debug)]
#[derive(Clone, Serialize, Deserialize, Debug)]
/// Trace containing all element wise operations as well as reads and writes.
pub struct FuseOnWriteTrace {
outputs: RegisteredTensors,
inputs: RegisteredTensors,
settings: FuseSettings,
scalars: BTreeMap<ElemwisePrecision, u32>,
reshapes: Vec<Reshape>,
shape_ref: Vec<usize>,
ops: Vec<ElemwiseOp>,
reads: BTreeMap<TensorId, Vec<ElemwiseOp>>,
writes: BTreeMap<TensorId, ElemwiseOp>,
inputs_unhandled: Vec<TensorId>,
pub outputs: RegisteredTensors,
pub inputs: RegisteredTensors,
pub settings: FuseSettings,
pub scalars: BTreeMap<ElemwisePrecision, u32>,
pub reshapes: Vec<Reshape>,
pub shape_ref: Vec<usize>,
pub ops: Vec<ElemwiseOp>,
pub reads: BTreeMap<TensorId, Vec<ElemwiseOp>>,
pub writes: BTreeMap<TensorId, ElemwiseOp>,
pub inputs_unhandled: Vec<TensorId>,
}

#[derive(Clone, Serialize, Deserialize, Debug)]
Expand All @@ -49,7 +49,7 @@ impl FuseOnWriteTrace {
) -> Result<(), Runner::Error> {
let mut plan = LaunchPlan::new(&self.reads, &self.writes, self.shape_ref.len());

InputsPlanner::<R>::new(
InputPlanner::<R>::new(
&self.inputs,
&self.inputs_unhandled,
&self.reshapes,
Expand All @@ -58,7 +58,7 @@ impl FuseOnWriteTrace {
)
.run(context, &mut plan);

OutputsPlanner::<R>::new(&self.inputs, &self.outputs, &self.reshapes)
OutputPlanner::<R>::new(&self.inputs, &self.outputs, &self.reshapes)
.run::<BT>(client, device, context, &mut plan);

VectorizationPlanner::<R>::new(&self.reshapes, &self.reads, &self.settings)
Expand All @@ -67,9 +67,9 @@ impl FuseOnWriteTrace {
match LaunchPlanExecutor::<R>::new(&self.scalars, &self.reshapes, &self.ops)
.execute::<_, BT>(client, runner, context, plan)
{
Err((err, handle_inputs, handle_outputs)) => {
self.rollback(context, handle_inputs, handle_outputs);
Err(err)
Err(err) => {
self.rollback(context, err.handles_input, err.handles_output);
Err(err.runner_error)
}
Ok(val) => Ok(val),
}
Expand Down
34 changes: 19 additions & 15 deletions crates/burn-jit/src/fusion/on_write/trace/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ impl FuseOnWriteTraceBuilder {
let out = self.locals.create(precision, tensor.id);
let input = Arg::Input(new_input, precision_input, LayoutInfo::Unknown);

let reads = if !self.reads.contains_key(&tensor.id) {
self.reads.insert(tensor.id, Vec::with_capacity(1));
let reads = if let std::collections::btree_map::Entry::Vacant(e) =
self.reads.entry(tensor.id)
{
e.insert(Vec::with_capacity(1));
self.reads.get_mut(&tensor.id).unwrap()
} else {
self.reads.get_mut(&tensor.id).unwrap()
Expand Down Expand Up @@ -180,8 +182,8 @@ impl FuseOnWriteTraceBuilder {

let index = self.reshapes.len();
self.reshapes.push(Reshape {
reshaped: output.id.clone(),
original: tensor.id.clone(),
reshaped: output.id,
original: tensor.id,
});
let rank = output.shape.len();

Expand All @@ -195,12 +197,13 @@ impl FuseOnWriteTraceBuilder {
shape,
};

let reads = if !self.reads.contains_key(&tensor.id) {
self.reads.insert(tensor.id, Vec::with_capacity(1));
self.reads.get_mut(&tensor.id).unwrap()
} else {
self.reads.get_mut(&tensor.id).unwrap()
};
let reads =
if let std::collections::btree_map::Entry::Vacant(e) = self.reads.entry(tensor.id) {
e.insert(Vec::with_capacity(1));
self.reads.get_mut(&tensor.id).unwrap()
} else {
self.reads.get_mut(&tensor.id).unwrap()
};

reads.push(ElemwiseOp::Assign(UnaryElemwiseArgs {
input,
Expand All @@ -226,7 +229,7 @@ impl FuseOnWriteTraceBuilder {
Arg::Scalar(new_index, precision)
}

pub fn build(&self, shape: Vec<usize>) -> FuseOnWriteTrace {
pub fn build(&self, shape_ref: Vec<usize>) -> FuseOnWriteTrace {
let inputs = self.inputs.clone();
let outputs = self.output_tensors();
let ops = self.ops.clone();
Expand All @@ -250,19 +253,20 @@ impl FuseOnWriteTraceBuilder {

let reshapes = self.reshapes.clone();
let settings = self.settings;
let inputs_unhandled = self.inputs_unhandled.clone();

FuseOnWriteTrace::new(
FuseOnWriteTrace {
outputs,
inputs,
settings,
scalars,
reshapes,
shape,
shape_ref,
ops,
reads,
writes,
self.inputs_unhandled.clone(),
)
inputs_unhandled,
}
}

fn output_tensors(&self) -> RegisteredTensors {
Expand Down
11 changes: 9 additions & 2 deletions crates/burn-jit/src/fusion/on_write/trace/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ pub struct LaunchPlanExecutor<'a, R: JitRuntime> {
_r: PhantomData<R>,
}

#[derive(new)]
/// Error returned by [`LaunchPlanExecutor::execute`] when the runner fails.
///
/// Besides the underlying runner error, it carries back the input/output
/// handles from the failed launch plan so the caller can roll back the
/// context state (see the `rollback` call site in `FuseOnWriteTrace::run`).
pub struct ExecutionError<R: JitRuntime, Runner: TraceRunner<R>> {
    // The error produced by the trace runner itself.
    pub runner_error: Runner::Error,
    // Input handles registered for the failed plan; needed for rollback.
    pub handles_input: Vec<HandleInput<R>>,
    // Output handles created for the failed plan; needed for rollback.
    pub handles_output: Vec<HandleOutput<R>>,
}

impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> {
pub fn new(
scalars: &'a BTreeMap<ElemwisePrecision, u32>,
Expand All @@ -44,7 +51,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> {
runner: &Runner,
context: &mut Context<'_, JitFusionHandle<R>>,
plan: LaunchPlan<'a, R>,
) -> Result<(), (Runner::Error, Vec<HandleInput<R>>, Vec<HandleOutput<R>>)> {
) -> Result<(), ExecutionError<R, Runner>> {
let reference = match plan.reference {
Some(reference) => reference,
None => {
Expand Down Expand Up @@ -83,7 +90,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> {
};

Runner::run(runner, client, inputs, outputs, &config)
.map_err(|err| (err, plan.handle_inputs, plan.handle_outputs))
.map_err(|err| ExecutionError::new(err, plan.handle_inputs, plan.handle_outputs))
}

fn register_inputs<'h>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::{HandleInput, LaunchPlan, PotentialInplace, RegisteredTensors};

/// Fetch and register [input handles](HandleInput) and identify potential inputs that
/// can be used inplace.
pub struct InputsPlanner<'a, R: JitRuntime> {
pub struct InputPlanner<'a, R: JitRuntime> {
inputs: &'a RegisteredTensors,
inputs_unhandled: &'a Vec<TensorId>,
reshapes: &'a Vec<Reshape>,
Expand All @@ -20,7 +20,7 @@ pub struct InputsPlanner<'a, R: JitRuntime> {
_r: PhantomData<R>,
}

impl<'a, R: JitRuntime> InputsPlanner<'a, R> {
impl<'a, R: JitRuntime> InputPlanner<'a, R> {
pub fn new(
inputs: &'a RegisteredTensors,
inputs_unhandled: &'a Vec<TensorId>,
Expand Down Expand Up @@ -51,11 +51,10 @@ impl<'a, R: JitRuntime> InputsPlanner<'a, R> {
&& status == &TensorStatus::ReadWrite
&& handle.handle.can_mut()
&& !self.inputs_unhandled.contains(&tensor_relative.id)
&& self
&& !self
.reshapes
.iter()
.find(|r| r.reshaped == tensor_relative.id || r.original == tensor_relative.id)
.is_none()
.any(|r| r.reshaped == tensor_relative.id || r.original == tensor_relative.id)
&& self.shape_ref == &tensor_relative.shape
{
plan.potential_inplaces.push(PotentialInplace {
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/fusion/on_write/trace/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pub(crate) mod executor;
pub(crate) mod inputs;
pub(crate) mod outputs;
pub(crate) mod input;
pub(crate) mod output;
pub(crate) mod vectorization;

mod base;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::collections::BTreeMap;
/// Create or reuse handles for the outputs.
///
/// It is also responsible for selecting the reference tensor.
pub struct OutputsPlanner<'a, R: JitRuntime> {
pub struct OutputPlanner<'a, R: JitRuntime> {
inputs: &'a RegisteredTensors,
reshapes: &'a Vec<Reshape>,
outputs_sorted: Vec<OutputSorted<'a>>,
Expand All @@ -40,7 +40,7 @@ enum OutputKind {
Reshaped { reshape: Reshape },
}

impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
impl<'a, R: JitRuntime> OutputPlanner<'a, R> {
pub fn new(
inputs: &'a RegisteredTensors,
outputs: &'a RegisteredTensors,
Expand Down Expand Up @@ -235,6 +235,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
self.globals[output.pos_original] = Some(tensor_global);
}

#[allow(clippy::too_many_arguments)]
fn normal_output<BT: BoolElement>(
&mut self,
client: &ComputeClient<R::Server, R::Channel>,
Expand Down Expand Up @@ -299,6 +300,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
self.globals[output.pos_original] = Some(tensor_global);
}

#[allow(clippy::too_many_arguments)]
fn reshaped_output<BT: BoolElement>(
&mut self,
client: &ComputeClient<R::Server, R::Channel>,
Expand Down Expand Up @@ -377,7 +379,7 @@ impl OutputPositionMapper {
/// Returns the right position from the precision and the global position in all outputs.
pub fn resolve_index(&mut self, precision: &ElemwisePrecision, pos_handle: usize) -> u32 {
self.map
.get(&precision)
.get(precision)
.unwrap()
.iter()
.enumerate()
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/on_write/trace/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub(crate) struct LaunchPlan<'a, R: JitRuntime> {
pub rank: usize,
}

impl<'a, R: JitRuntime> LaunchPlan<'a, R> {
impl<R: JitRuntime> LaunchPlan<'_, R> {
pub fn new(
reads: &BTreeMap<TensorId, Vec<ElemwiseOp>>,
writes: &BTreeMap<TensorId, ElemwiseOp>,
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/on_write/trace/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ fn vectorization_default<'a, R: JitRuntime>(
let mut max_current = u8::MAX;

for (handle, tensor) in handles_inputs.zip(inputs) {
match vectorization_input(&handle, tensor) {
match vectorization_input(handle, tensor) {
Vect::Broadcated => vectorizations.insert(tensor.id, 1),
Vect::Max(val) => {
max_current = Ord::min(val, max_current);
Expand Down
14 changes: 7 additions & 7 deletions crates/burn-jit/src/fusion/on_write/trace/vectorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ impl<'a, R: JitRuntime> VectorizationPlanner<'a, R> {
handle.vectorization = *plan.vectorization.get(&handle.global_id).unwrap();
}
for handle in plan.handle_outputs.iter_mut() {
match handle {
HandleOutput::Owned {
vectorization,
global_id,
..
} => *vectorization = *plan.vectorization.get(&global_id).unwrap(),
_ => {}
if let HandleOutput::Owned {
vectorization,
global_id,
..
} = handle
{
*vectorization = *plan.vectorization.get(global_id).unwrap()
}
}
}
Expand Down

0 comments on commit 651bb4c

Please sign in to comment.