Skip to content

Commit

Permalink
Add some docs
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Feb 3, 2025
1 parent e127325 commit 1a84818
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 0 deletions.
3 changes: 3 additions & 0 deletions crates/burn-jit/src/fusion/on_write/trace/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{
BoolElement, JitRuntime,
};

/// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context).
pub struct LaunchPlanExecutor<'a, R: JitRuntime> {
scalars: &'a BTreeMap<ElemwisePrecision, u32>,
reshapes: &'a Vec<Reshape>,
Expand Down Expand Up @@ -84,6 +85,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> {
Runner::run(runner, client, inputs, outputs, &config)
.map_err(|err| (err, plan.handle_inputs, plan.handle_outputs))
}

fn register_inputs<'h>(
&self,
context: &mut Context<'_, JitFusionHandle<R>>,
Expand Down Expand Up @@ -146,6 +148,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> {
}
}

// Reshape values are pushed in reverse in the same scalar buffer for all `u32`
for relative in self.reshapes.iter().rev() {
let global = context.tensors.get(&relative.reshaped).unwrap();

Expand Down
2 changes: 2 additions & 0 deletions crates/burn-jit/src/fusion/on_write/trace/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use std::marker::PhantomData;

use super::{HandleInput, LaunchPlan, PotentialInplace, RegisteredTensors};

/// Fetch and register [input handles](HandleInput) and itendify potential inputs that
/// can be used inplace.
pub struct InputsPlanner<'a, R: JitRuntime> {
inputs: &'a RegisteredTensors,
inputs_unhandled: &'a Vec<TensorId>,
Expand Down
3 changes: 3 additions & 0 deletions crates/burn-jit/src/fusion/on_write/trace/outputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ use super::{
};
use std::collections::BTreeMap;

/// Create or reuse handles for the outputs.
///
/// It is also responsable to select the reference tensor.
pub struct OutputsPlanner<'a, R: JitRuntime> {
inputs: &'a RegisteredTensors,
reshapes: &'a Vec<Reshape>,
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-jit/src/fusion/on_write/trace/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use crate::{
};
use burn_tensor::repr::{TensorDescription, TensorId};

/// The plan is responsable to keep runtime information related to the launch of a fused kernel
/// at one place.
#[derive(Debug)]
pub(crate) struct LaunchPlan<'a, R: JitRuntime> {
pub potential_inplaces: Vec<PotentialInplace<'a>>,
Expand Down

0 comments on commit 1a84818

Please sign in to comment.