diff --git a/crates/burn-jit/src/fusion/on_write/trace/executor.rs b/crates/burn-jit/src/fusion/on_write/trace/executor.rs index 55c1c39912..9b1d719068 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/executor.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/executor.rs @@ -16,6 +16,7 @@ use crate::{ BoolElement, JitRuntime, }; +/// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context). pub struct LaunchPlanExecutor<'a, R: JitRuntime> { scalars: &'a BTreeMap, reshapes: &'a Vec, @@ -84,6 +85,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { Runner::run(runner, client, inputs, outputs, &config) .map_err(|err| (err, plan.handle_inputs, plan.handle_outputs)) } + fn register_inputs<'h>( &self, context: &mut Context<'_, JitFusionHandle>, @@ -146,6 +148,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { } } + // Reshape values are pushed in reverse in the same scalar buffer for all `u32` for relative in self.reshapes.iter().rev() { let global = context.tensors.get(&relative.reshaped).unwrap(); diff --git a/crates/burn-jit/src/fusion/on_write/trace/inputs.rs b/crates/burn-jit/src/fusion/on_write/trace/inputs.rs index 0ae74c708d..54791ec225 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/inputs.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/inputs.rs @@ -9,6 +9,8 @@ use std::marker::PhantomData; use super::{HandleInput, LaunchPlan, PotentialInplace, RegisteredTensors}; +/// Fetch and register [input handles](HandleInput) and identify potential inputs that +/// can be used inplace. 
pub struct InputsPlanner<'a, R: JitRuntime> { inputs: &'a RegisteredTensors, inputs_unhandled: &'a Vec, diff --git a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs b/crates/burn-jit/src/fusion/on_write/trace/outputs.rs index bb84558091..66624142fb 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/outputs.rs @@ -16,6 +16,9 @@ use super::{ }; use std::collections::BTreeMap; +/// Create or reuse handles for the outputs. +/// +/// It is also responsible for selecting the reference tensor. pub struct OutputsPlanner<'a, R: JitRuntime> { inputs: &'a RegisteredTensors, reshapes: &'a Vec, diff --git a/crates/burn-jit/src/fusion/on_write/trace/plan.rs b/crates/burn-jit/src/fusion/on_write/trace/plan.rs index 896b8d7f51..2fc68ba2be 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/plan.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/plan.rs @@ -9,6 +9,8 @@ use crate::{ }; use burn_tensor::repr::{TensorDescription, TensorId}; +/// The plan is responsible for keeping runtime information related to the launch of a fused kernel +/// in one place. #[derive(Debug)] pub(crate) struct LaunchPlan<'a, R: JitRuntime> { pub potential_inplaces: Vec>,