Add reshape optimization
nathanielsimard committed Feb 3, 2025
1 parent b9bf504 commit 6ce7c3e
Showing 3 changed files with 144 additions and 44 deletions.
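The change teaches the fusion planner to treat an output that is a reshape of a contiguous input as an alias of that input's buffer, with freshly computed strides, instead of writing it from the kernel. A minimal sketch of the stride reasoning (an illustrative helper, not the crate's API):

// Row-major (contiguous) strides for a shape, e.g. [2, 3, 4] -> [12, 4, 1].
// A reshape of a contiguous buffer only needs new strides over the same bytes.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

fn main() {
    // A contiguous [2, 3] tensor reshaped to [3, 2] keeps its 6 elements:
    assert_eq!(contiguous_strides(&[2, 3]), vec![3, 1]);
    assert_eq!(contiguous_strides(&[3, 2]), vec![2, 1]);
}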
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/on_write/trace/base.rs
@@ -58,7 +58,7 @@ impl FuseOnWriteTrace {
)
.run(context, &mut plan);

OutputsPlanner::<R>::new(&self.inputs, &self.outputs)
OutputsPlanner::<R>::new(&self.inputs, &self.outputs, &self.reshapes)
.run::<BT>(client, device, context, &mut plan);

VectorizationPlanner::<R>::new(&self.reshapes, &self.reads, &self.settings)
17 changes: 13 additions & 4 deletions crates/burn-jit/src/fusion/on_write/trace/executor.rs
@@ -44,6 +44,18 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> {
context: &mut Context<'_, JitFusionHandle<R>>,
plan: LaunchPlan<'a, R>,
) -> Result<(), (Runner::Error, Vec<HandleInput<R>>, Vec<HandleOutput<R>>)> {
let reference = match plan.reference {
Some(reference) => reference,
None => {
if plan.writes.is_empty() {
// Nothing to write, can skip execution.
return Ok(());
} else {
panic!("An output should exist for the fused kernel")
}
}
};

let inputs = self.register_inputs(context, &plan.handle_inputs);
let outputs = self.register_outputs::<BT>(&plan.handle_outputs);

@@ -65,10 +77,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> {

let config = ElemwiseConfig {
rank: plan.rank as u32,
ref_layout: plan
.reference
.expect("An output should exist for the fused kernel")
.layout,
ref_layout: reference.layout,
ops,
};

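With reshape aliasing, a plan can end up with no writes at all, so the executor above now returns early instead of panicking on a missing reference layout. A distilled sketch of that guard (stand-in types, not the crate's):

// A missing reference is fine only when there is nothing left to write;
// otherwise it is a bug in the fused kernel's planning.
fn resolve_reference<T>(
    reference: Option<T>,
    writes_empty: bool,
) -> Result<Option<T>, &'static str> {
    match reference {
        Some(reference) => Ok(Some(reference)),
        None if writes_empty => Ok(None), // nothing to write, skip execution
        None => Err("An output should exist for the fused kernel"),
    }
}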
169 changes: 130 additions & 39 deletions crates/burn-jit/src/fusion/on_write/trace/outputs.rs
@@ -7,14 +7,18 @@ use crate::{
on_write::ir::{Arg, ElemwiseOp, LayoutInfo},
strides_dyn_rank, JitFusionHandle,
},
tensor::is_contiguous,
BoolElement, JitRuntime,
};

use super::{super::ir::ElemwisePrecision, HandleOutput, LaunchPlan, Reference, RegisteredTensors};
use super::{
super::ir::ElemwisePrecision, HandleOutput, LaunchPlan, Reference, RegisteredTensors, Reshape,
};
use std::collections::BTreeMap;

pub struct OutputsPlanner<'a, R: JitRuntime> {
inputs: &'a RegisteredTensors,
reshapes: &'a Vec<Reshape>,
outputs_sorted: Vec<OutputSorted<'a>>,
handles: Vec<Option<HandleOutput<R>>>,
globals: Vec<Option<TensorDescription>>,
@@ -27,8 +31,18 @@ struct OutputSorted<'a> {
tensor_relative: &'a TensorDescription,
}

enum OutputKind {
Normal,
Inplace { input_pos: usize },
Reshaped { reshape: Reshape },
}

impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
pub fn new(inputs: &'a RegisteredTensors, outputs: &'a RegisteredTensors) -> Self {
pub fn new(
inputs: &'a RegisteredTensors,
outputs: &'a RegisteredTensors,
reshapes: &'a Vec<Reshape>,
) -> Self {
let mut mapper = OutputPositionMapper::default();
let mut outputs_sorted: Vec<_> = outputs
.iter()
@@ -61,6 +75,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
Self {
inputs,
outputs_sorted,
reshapes,
handles,
globals,
mapper,
@@ -72,7 +87,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
client: &ComputeClient<R::Server, R::Channel>,
device: &R::Device,
context: &mut Context<'_, JitFusionHandle<R>>,
analysis: &mut LaunchPlan<'a, R>,
plan: &mut LaunchPlan<'a, R>,
) {
// So that we can borrow self during the iteration.
let mut outputs = Vec::new();
@@ -86,30 +101,42 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
.clone();
let strides = strides_dyn_rank(&tensor_global.shape);

match Self::select_input_inplace(analysis, &tensor_global, &output, &strides) {
Some(index) => {
self.analyse_inplace(context, analysis, output, tensor_global, index);
match self.output_kind(plan, &tensor_global, &output, &strides) {
OutputKind::Inplace { input_pos } => {
self.inplace_output(context, plan, output, tensor_global, input_pos);
}
OutputKind::Normal => {
self.normal_output::<BT>(
client,
device,
context,
plan,
output,
tensor_global,
strides,
);
}
None => {
self.analyse_output::<BT>(
OutputKind::Reshaped { reshape } => {
self.reshaped_output::<BT>(
client,
device,
context,
analysis,
plan,
output,
tensor_global,
strides,
reshape,
);
}
}
}

for (handle, global) in self.handles.into_iter().zip(self.globals.into_iter()) {
analysis.handle_outputs.push(handle.unwrap());
analysis.global_outputs.push(global.unwrap());
plan.handle_outputs.push(handle.unwrap());
plan.global_outputs.push(global.unwrap());
}

Self::add_layout_info_inputs(analysis);
Self::add_layout_info_inputs(plan);
}

fn add_layout_info_inputs(analysis: &mut LaunchPlan<'_, R>) {
@@ -128,14 +155,24 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
}
}

fn select_input_inplace(
analysis: &mut LaunchPlan<'a, R>,
fn output_kind(
&self,
plan: &mut LaunchPlan<'a, R>,
tensor_global: &TensorDescription,
output: &OutputSorted,
strides: &[usize],
) -> Option<usize> {
analysis
.potential_inplaces
) -> OutputKind {
if let Some(reshape) = self
.reshapes
.iter()
.find(|r| r.reshaped == output.tensor_relative.id)
{
return OutputKind::Reshaped {
reshape: reshape.clone(),
};
}

plan.potential_inplaces
.iter()
.enumerate()
.find(|(_pos, pi)| {
@@ -144,45 +181,42 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
&& pi.strides == strides
})
.map(|(pos, _)| pos)
.map(|input_pos| OutputKind::Inplace { input_pos })
.unwrap_or(OutputKind::Normal)
}

fn analyse_inplace(
fn inplace_output(
&mut self,
context: &mut Context<'_, JitFusionHandle<R>>,
analysis: &mut LaunchPlan<'a, R>,
plan: &mut LaunchPlan<'a, R>,
output: OutputSorted,
tensor_global: TensorDescription,
input_index: usize,
) {
let potential_inplace = analysis.potential_inplaces.remove(input_index);
let handle_input = analysis
.handle_inputs
.get(potential_inplace.input_pos)
.unwrap();
let potential_inplace = plan.potential_inplaces.remove(input_index);
let handle_input = plan.handle_inputs.get(potential_inplace.input_pos).unwrap();

if analysis.reference.is_none() {
if plan.reference.is_none() {
let index_input = self
.inputs
.get_index(output.precision, potential_inplace.tensor_relative.id)
.unwrap();

analysis.reference = Some(Reference {
plan.reference = Some(Reference {
layout: Arg::Input(index_input as u32, output.precision, LayoutInfo::IsRef),
shape: tensor_global.shape.clone(),
strides: handle_input.handle.strides.clone(),
});

if let Some(ops) = analysis.reads.get_mut(&handle_input.relative_id) {
if let Some(ops) = plan.reads.get_mut(&handle_input.relative_id) {
for op in ops.iter_mut() {
if let ElemwiseOp::Assign(op) = op {
op.input.add_layout_info(LayoutInfo::IsRef);
};
}
}

if let Some(ElemwiseOp::Assign(op)) =
analysis.writes.get_mut(&output.tensor_relative.id)
{
if let Some(ElemwiseOp::Assign(op)) = plan.writes.get_mut(&output.tensor_relative.id) {
op.out.add_layout_info(LayoutInfo::IsRef);
};
}
@@ -198,35 +232,34 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
self.globals[output.pos_original] = Some(tensor_global);
}

fn analyse_output<BT: BoolElement>(
fn normal_output<BT: BoolElement>(
&mut self,
client: &ComputeClient<R::Server, R::Channel>,
device: &R::Device,
context: &mut Context<'_, JitFusionHandle<R>>,
analysis: &mut LaunchPlan<'a, R>,
plan: &mut LaunchPlan<'a, R>,
output: OutputSorted,
tensor_global: TensorDescription,
strides: Vec<usize>,
) {
if analysis.reference.is_none() {
if plan.reference.is_none() {
let position = self
.mapper
.resolve_index(&output.precision, output.pos_original);
analysis.reference = Some(Reference {
plan.reference = Some(Reference {
layout: Arg::Output(position, output.precision, LayoutInfo::IsRef),
shape: tensor_global.shape.clone(),
strides: strides.clone(),
});

if let ElemwiseOp::Assign(op) =
analysis.writes.get_mut(&output.tensor_relative.id).unwrap()
if let ElemwiseOp::Assign(op) = plan.writes.get_mut(&output.tensor_relative.id).unwrap()
{
op.out.add_layout_info(LayoutInfo::IsRef);
};
} else if let Some(reference) = analysis.reference.as_ref() {
} else if let Some(reference) = plan.reference.as_ref() {
if reference.strides == strides && reference.shape == tensor_global.shape {
if let ElemwiseOp::Assign(op) =
analysis.writes.get_mut(&output.tensor_relative.id).unwrap()
plan.writes.get_mut(&output.tensor_relative.id).unwrap()
{
op.out.add_layout_info(LayoutInfo::SameAsRef);
};
@@ -248,7 +281,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
dtype,
};

analysis.rank = usize::max(tensor_global.shape.len(), analysis.rank);
plan.rank = usize::max(tensor_global.shape.len(), plan.rank);
context
.handles
.register_handle(tensor_global.id, handle.clone());
@@ -262,6 +295,64 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> {
});
self.globals[output.pos_original] = Some(tensor_global);
}

fn reshaped_output<BT: BoolElement>(
&mut self,
client: &ComputeClient<R::Server, R::Channel>,
device: &R::Device,
context: &mut Context<'_, JitFusionHandle<R>>,
plan: &mut LaunchPlan<'a, R>,
output: OutputSorted,
tensor_global: TensorDescription,
strides: Vec<usize>,
reshape: Reshape,
) {
let original_handle = plan
.handle_inputs
.iter()
.find(|handle| handle.relative_id == reshape.original)
.unwrap();

// We encode bool tensors as `BT`.
let dtype = match tensor_global.dtype {
DType::Bool => BT::dtype(),
_ => tensor_global.dtype,
};

if is_contiguous(
&original_handle.global_shape,
&original_handle.handle.strides,
) {
plan.writes.remove(&output.tensor_relative.id);

let handle = JitFusionHandle {
client: client.clone(),
handle: original_handle.handle.handle.clone(),
device: device.clone(),
strides,
dtype,
};
context
.handles
.register_handle(tensor_global.id, handle.clone());
// It will never be accessed; this is just a way to keep the original position working.
self.handles[output.pos_original] = Some(HandleOutput::Alias {
input_pos: 0,
precision: output.precision,
});
self.globals[output.pos_original] = Some(tensor_global);
} else {
self.normal_output::<BT>(
client,
device,
context,
plan,
output,
tensor_global,
strides,
);
}
}
}

/// Group output position by [element precision](ElemwisePrecision).
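Two details in this file are easy to miss. First, `output_kind` consults the reshape list before the in-place candidates, so an output that is both a registered reshape and in-place eligible is handled as a reshape. A toy mirror of that precedence (hypothetical ids and simplified types):

#[derive(Debug, PartialEq)]
enum Kind {
    Normal,
    Inplace { input_pos: usize },
    Reshaped,
}

// Simplified mirror of `OutputsPlanner::output_kind`: Reshaped wins over
// Inplace, which wins over Normal.
fn classify(reshaped: &[u64], inplace_candidates: &[u64], output_id: u64) -> Kind {
    if reshaped.contains(&output_id) {
        return Kind::Reshaped;
    }
    inplace_candidates
        .iter()
        .position(|id| *id == output_id)
        .map(|input_pos| Kind::Inplace { input_pos })
        .unwrap_or(Kind::Normal)
}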

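Second, the aliasing path in `reshaped_output` hinges on `is_contiguous`: when the check passes, the pending write is dropped from the plan and the registered handle reuses the input's buffer with the new strides; otherwise the planner falls back to `normal_output`. A common row-major definition looks like the sketch below (an assumption; the real helper in `crates/burn-jit` may treat edge cases such as size-1 dimensions differently):

// Row-major contiguity: strides must equal the contiguous strides of the
// shape, checked from the innermost dimension outward.
fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    let mut expected = 1;
    for (size, stride) in shape.iter().zip(strides.iter()).rev() {
        if *stride != expected {
            return false;
        }
        expected *= size;
    }
    true
}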