Skip to content

Commit

Permalink
JIT migration: cast kernel (#1423)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Mar 7, 2024
1 parent 9eecc71 commit 040cd55
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 83 deletions.
115 changes: 115 additions & 0 deletions crates/burn-jit/src/kernel/cast/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use std::{any::TypeId, marker::PhantomData};

use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Scope, Variable, Visibility},
kernel::{DynamicKernelSource, SourceTemplate},
tensor::JitTensor,
Compiler, JitElement, Runtime,
};

/// Cast a tensor to the given element type.
///
/// Note: When input element is semantically a boolean, prefer bool_cast function.
pub fn cast<R: Runtime, EI: JitElement, EO: JitElement, const D: usize>(
tensor: JitTensor<R, EI, D>,
) -> JitTensor<R, EO, D> {
if TypeId::of::<EI>() == TypeId::of::<EO>() {
return JitTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
}

let kernel = CastEagerKernel::new();
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
let output = JitTensor::new(
tensor.client.clone(),
tensor.device,
tensor.shape.clone(),
buffer,
);

execute_dynamic::<R, CastEagerKernel<R, EI, EO>, u32>(
&[EagerHandle::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);

output
}

pub(crate) struct CastShader {
tensor: Variable,
output: Variable,
}

#[derive(new)]
pub(crate) struct CastEagerKernel<R: Runtime, EI: JitElement, EO: JitElement> {
_runtime: PhantomData<R>,
_elem_in: PhantomData<EI>,
_elem_out: PhantomData<EO>,
}

impl<R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSource
for CastEagerKernel<R, EI, EO>
{
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
let item_input = EI::gpu_elem().into();
let item_output = EO::gpu_elem().into();

let tensor = Variable::GlobalInputArray(0, item_input);
let output = Variable::GlobalOutputArray(0, item_output);

CastShader { tensor, output }.expand(&mut scope);

scope.write_global_custom(output);

let tensor = InputInfo::Array {
item: item_input,
visibility: Visibility::Read,
};

let out = OutputInfo::Array { item: item_output };

let info = CompilationInfo {
inputs: vec![tensor],
outputs: vec![out],
scope,
};

let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}

fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<Self>())
}
}

impl CastShader {
pub(crate) fn expand(self, scope: &mut Scope) {
let tensor = self.tensor;
let id = Variable::Id;
let output = self.output;

let value = scope.create_local(output.item());
gpu!(scope, value = tensor[id]);
gpu!(scope, output[id] = value);
}
}
Original file line number Diff line number Diff line change
@@ -1,74 +1,15 @@
use super::{
DynamicKernelSource, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
};
use std::marker::PhantomData;

use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, OutputInfo, WorkgroupLaunch,
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
OutputInfo, WorkgroupLaunch,
},
compute::StaticKernel,
element::JitElement,
kernel::elemwise_workgroup,
kernel_wgsl,
gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
kernel::{DynamicKernelSource, SourceTemplate},
tensor::JitTensor,
Runtime,
Compiler, JitElement, Runtime,
};
use std::{any::TypeId, marker::PhantomData};

kernel_wgsl!(CastRaw, "../template/cast.wgsl");

struct Cast<InputElem: JitElement, OutputElem: JitElement> {
_i: PhantomData<InputElem>,
_o: PhantomData<OutputElem>,
}

impl<InputElem: JitElement, OutputElem: JitElement> StaticKernelSource
for Cast<InputElem, OutputElem>
{
fn source() -> SourceTemplate {
CastRaw::source()
.register("input_elem", InputElem::type_name())
.register("output_elem", OutputElem::type_name())
}
}

/// Cast a tensor to the given element type.
pub fn cast<R: Runtime, InputElem: JitElement, OutputElem: JitElement, const D: usize>(
tensor: JitTensor<R, InputElem, D>,
) -> JitTensor<R, OutputElem, D> {
if TypeId::of::<InputElem>() == TypeId::of::<OutputElem>() {
return JitTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
}

let num_elems = tensor.shape.num_elements();
let kernel = StaticKernel::<
KernelSettings<
Cast<InputElem, OutputElem>,
f32,
i32,
WORKGROUP_DEFAULT,
WORKGROUP_DEFAULT,
1,
>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));

let handle = tensor
.client
.empty(num_elems * core::mem::size_of::<OutputElem>());
let output = JitTensor::new(
tensor.client.clone(),
tensor.device,
tensor.shape.clone(),
handle,
);

tensor
.client
.execute(Box::new(kernel), &[&tensor.handle, &output.handle]);

output
}

/// Cast a bool tensor to the given element type.
///
Expand Down
5 changes: 5 additions & 0 deletions crates/burn-jit/src/kernel/cast/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod base;
mod bool_cast;

pub use base::*;
pub use bool_cast::*;
17 changes: 0 additions & 17 deletions crates/burn-jit/src/template/cast.wgsl

This file was deleted.

0 comments on commit 040cd55

Please sign in to comment.