Refactor: split JitKernel and SourceKernel (#1569)
* refactor execute_dynamic into Execution

* minor change

* extension cfg

* jitkernel and sourcekernel

* add todo statement

* cleanup and docs

* update book

* fix server dependency on compiler

* refactor into shader information

* refactor to compile shader once

* clippy

* clippy

* clippy

* fix doc

* fix doc

* fmt

* rename feature flag

* refactor

* All broked

* compile at the right time

* todo done

* all dynamic

* all dynamic in template too

* fmt

* fix ci

---------

Co-authored-by: nathaniel <[email protected]>
louisfd and nathanielsimard authored Apr 5, 2024
1 parent 1239d9b commit f5159b6
Showing 54 changed files with 694 additions and 738 deletions.
54 changes: 27 additions & 27 deletions burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md
@@ -163,7 +163,8 @@ fn main(
 Now, let's move on to the next step, which involves implementing the remaining code to launch the
 kernel. The initial part entails loading the template and populating it with the appropriate
 variables. The `register(name, value)` method simply replaces occurrences of `{{ name }}` in the
-above WGSL code with some other string before it is compiled.
+above WGSL code with some other string before it is compiled. In order to use templating
+utilities, you will have to activate the `template` feature of Burn in your `Cargo.toml`.
 
 ```rust, ignore
 // Source the kernel written in WGSL.
@@ -172,24 +173,21 @@ kernel_wgsl!(FusedMatmulAddReluRaw, "./kernel.wgsl");
 // Define our kernel type with workgroup information.
 #[derive(new, Debug)]
 struct FusedMatmulAddRelu<E: FloatElement> {
-    workgroup_size_x: usize,
-    workgroup_size_y: usize,
+    workgroup_size: WorkgroupSize,
     _elem: PhantomData<E>,
 }
 
 // Implement the dynamic kernel trait for our kernel type.
-impl<E: FloatElement> DynamicKernel for FusedMatmulAddRelu<E> {
-    fn source_template(self) -> SourceTemplate {
+impl<E: FloatElement> KernelSource for FusedMatmulAddRelu<E> {
+    fn source(&self) -> SourceTemplate {
         // Extend our raw kernel with workgroup size information using the
         // `SourceTemplate` trait.
-        FusedMatmulAddReluRaw::source_template()
-            .register("workgroup_size_x", self.workgroup_size_x.to_string())
-            .register("workgroup_size_y", self.workgroup_size_y.to_string())
+        FusedMatmulAddReluRaw::new()
+            .source()
+            .register("workgroup_size_x", self.workgroup_size.x.to_string())
+            .register("workgroup_size_y", self.workgroup_size.y.to_string())
             .register("elem", E::type_name())
-    }
-
-    fn id(&self) -> String {
-        format!("{:?}", self)
+            .register("int", "i32")
     }
 }
 ```
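For intuition, the substitution that `register` performs can be sketched in a few lines. The `MiniTemplate` below is a self-contained stand-in for illustration only; Burn's real `SourceTemplate` differs in detail:

```rust
/// Minimal stand-in for a source template: replaces `{{ name }}` placeholders.
struct MiniTemplate {
    source: String,
}

impl MiniTemplate {
    fn new(source: &str) -> Self {
        Self { source: source.to_string() }
    }

    /// Replace every occurrence of `{{ name }}` with `value`.
    fn register(mut self, name: &str, value: impl Into<String>) -> Self {
        self.source = self
            .source
            .replace(&format!("{{{{ {name} }}}}"), &value.into());
        self
    }
}

fn main() {
    let wgsl = MiniTemplate::new("@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)")
        .register("workgroup_size_x", "16")
        .register("workgroup_size_y", "16");
    assert_eq!(wgsl.source, "@workgroup_size(16, 16, 1)");
}
```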
@@ -200,16 +198,16 @@ the raw `WgpuBackend` type.
 
 ```rust, ignore
 /// Implement our custom backend trait for the existing backend `WgpuBackend`.
-impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for WgpuBackend<G, F, I> {
+impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G, F, I>> {
     fn fused_matmul_add_relu<const D: usize>(
         lhs: FloatTensor<Self, D>,
         rhs: FloatTensor<Self, D>,
         bias: FloatTensor<Self, D>,
-    ) -> WgpuTensor<F, D> {
+    ) -> FloatTensor<Self, D> {
         // Define workgroup size, hardcoded for simplicity.
-        let workgroup_size_x = 16;
-        let workgroup_size_y = 16;
+        let workgroup_size = WorkgroupSize { x: 16, y: 16, z: 1 };
 
         // Specify the size of a workgroup for this kernel
         lhs.assert_is_on_same_device(&rhs);
         lhs.assert_is_on_same_device(&bias);
@@ -225,7 +223,7 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for WgpuBackend<G,
         // Compute shape of output, while tracking number of batches.
         let mut num_batches = 1;
         let mut shape_out = [0; D];
-        for i in 0..D - 2 {
+        for i in shape_out.into_iter().take(D - 2) {
             shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]);
             num_batches *= shape_out[i];
         }
@@ -235,29 +233,31 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for WgpuBackend<G,
 
         // Create a buffer for the output tensor.
         let buffer = lhs
-            .context
-            .create_buffer(shape_out.num_elements() * core::mem::size_of::<F>());
+            .client
+            .empty(shape_out.num_elements() * core::mem::size_of::<F>());
 
         // Create the output tensor primitive.
-        let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
+        let output = JitTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
 
         // Create the kernel.
-        let kernel = FusedMatmulAddRelu::<F>::new(workgroup_size_x, workgroup_size_y);
+        let kernel = FusedMatmulAddRelu::<F>::new(workgroup_size);
 
         // Build info buffer with tensor information needed by the kernel, such as shapes and strides.
         let info = build_info(&[&lhs, &rhs, &output]);
-        let info_buffer = lhs
-            .context
-            .create_buffer_with_data(bytemuck::cast_slice(&info));
+        let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
 
         // Declare the wgsl workgroup with the number of blocks in x, y and z.
-        let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
-        let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
+        let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size.x as f32) as u32;
+        let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size.y as f32) as u32;
         let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_batches as u32);
 
         // Execute lazily the kernel with the launch information and the given buffers.
         lhs.client.execute(
-            Box::new(DynamicKernel::new(kernel, workgroup)),
+            Kernel::Custom(Box::new(SourceKernel::new(
+                kernel,
+                workgroup,
+                workgroup_size,
+            ))),
             &[
                 &lhs.handle,
                 &rhs.handle,
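As an aside, the `blocks_needed_in_*` lines above are just a ceiling division: the grid must contain enough workgroups to cover every row and column. A small integer-only sketch of the same computation, using a hypothetical helper that is not part of this diff:

```rust
/// Workgroups needed to cover `n` items with groups of size `group_size`.
/// Integer ceiling division; matches `f32::ceil(n as f32 / group_size as f32) as u32`
/// for the tensor sizes involved here.
fn blocks_needed(n: u32, group_size: u32) -> u32 {
    (n + group_size - 1) / group_size
}

fn main() {
    assert_eq!(blocks_needed(1000, 16), 63); // 1000 rows, 16 threads per axis
    assert_eq!(blocks_needed(1024, 16), 64); // exact multiple
}
```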
3 changes: 2 additions & 1 deletion crates/burn-core/Cargo.toml
@@ -78,6 +78,7 @@ openblas = ["burn-ndarray?/blas-openblas"]
 openblas-system = ["burn-ndarray?/blas-openblas-system"]
 blas-netlib = ["burn-ndarray?/blas-netlib"]
 autotune = ["burn-wgpu?/autotune"]
+template = ["burn-wgpu?/template"]
 
 ndarray = ["burn-ndarray"]
 tch = ["burn-tch"]
@@ -129,7 +130,7 @@ rmp-serde = { workspace = true, optional = true }
 serde_json = { workspace = true, features = ["alloc"] } #Default enables std
 thiserror = { workspace = true, optional = true }
 regex = { workspace = true, optional = true }
-num-traits = {workspace = true, optional = true }
+num-traits = { workspace = true, optional = true }
 
 [dev-dependencies]
 tempfile = { workspace = true }
1 change: 1 addition & 0 deletions crates/burn-jit/Cargo.toml
@@ -15,6 +15,7 @@ default = ["autotune", "std", "burn-compute/default", "fusion"]
 std = []
 doc = ["default"]
 autotune = []
+template = []
 fusion = ["burn-fusion"]
 export_tests = [
     "burn-tensor-testgen",
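The new `template` flag is an ordinary additive Cargo feature, re-exported from burn-core as `template = ["burn-wgpu?/template"]` above. A sketch of how such a gate is typically consumed in crate code, shown as a generic pattern rather than burn-jit's actual source:

```rust
// Compiled only when the crate is built with `--features template`.
#[cfg(feature = "template")]
pub mod template {
    //! Templating utilities (e.g. kernels assembled from WGSL source strings)
    //! live behind this gate so that default builds skip the extra code.
}
```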
49 changes: 11 additions & 38 deletions crates/burn-jit/src/codegen/kernel.rs
@@ -1,9 +1,7 @@
-use crate::compute::{DynamicKernel, Kernel, StaticKernel, WorkGroup};
+use crate::compute::{FullCompilationPhase, Kernel, WorkGroup};
 use crate::element::JitElement;
 use crate::gpu::Elem;
-use crate::kernel::{
-    elemwise_workgroup, DynamicKernelSource, StaticKernelSource, WORKGROUP_DEFAULT,
-};
+use crate::kernel::{elemwise_workgroup, GpuComputeShaderPhase, WORKGROUP_DEFAULT};
 use crate::Runtime;
 use burn_compute::client::ComputeClient;
 use burn_compute::server::Handle;
@@ -22,33 +20,6 @@ pub enum WorkgroupLaunch {
     Custom(WorkGroup),
 }
 
-/// Execute a static kernel.
-pub fn execute_static<R, K, E>(
-    inputs: &[EagerHandle<R>],
-    outputs: &[EagerHandle<R>],
-    scalar_elems: Option<&[E]>,
-    launch: WorkgroupLaunch,
-    client: ComputeClient<R::Server, R::Channel>,
-) where
-    K: StaticKernelSource + 'static,
-    R: Runtime,
-    E: JitElement,
-{
-    let settings =
-        execute_settings::<R, E, E, E>(inputs, outputs, scalar_elems, None, None, launch, &client);
-    let mut handles = settings.handles_tensors;
-    let workgroup = settings.workgroup;
-
-    handles.push(&settings.handle_info);
-    for handle in settings.handles_scalars.iter() {
-        handles.push(handle);
-    }
-
-    let kernel = Box::new(StaticKernel::<K>::new(workgroup));
-
-    client.execute(kernel, &handles);
-}
-
 pub struct Execution<'h, K, R: Runtime, Scalars> {
     scalars: Scalars,
     client: ComputeClient<R::Server, R::Channel>,
@@ -95,7 +66,7 @@ impl<'h, K, R: Runtime> Execution<'h, K, R, ()> {
 
 impl<'h, K, R> Execution<'h, K, R, ()>
 where
-    K: DynamicKernelSource + 'static,
+    K: GpuComputeShaderPhase + 'static,
     R: Runtime,
 {
     pub fn with_scalars<E>(self, scalars: &[E]) -> Execution<'h, K, R, (&[E],)> {
@@ -125,7 +96,7 @@ where
 
 impl<'h, 'a, K, R, E> Execution<'h, K, R, (&'a [E],)>
 where
-    K: DynamicKernelSource + 'static,
+    K: GpuComputeShaderPhase + 'static,
     R: Runtime,
     E: JitElement,
 {
@@ -160,7 +131,7 @@ where
 
 impl<'h, 'a, 'b, K, R, E1, E2> Execution<'h, K, R, (&'a [E1], &'b [E2])>
 where
-    K: DynamicKernelSource + 'static,
+    K: GpuComputeShaderPhase + 'static,
     R: Runtime,
     E1: JitElement,
     E2: JitElement,
@@ -182,7 +153,7 @@ where
     #[allow(clippy::too_many_arguments)]
     pub fn execute(self, launch: WorkgroupLaunch)
     where
-        K: DynamicKernelSource + 'static,
+        K: GpuComputeShaderPhase + 'static,
         R: Runtime,
     {
         execute_dynamic::<R, K, E1, E2, f32>(
@@ -200,7 +171,7 @@ where
 
 impl<'h, 'a, 'b, 'c, K, R, E1, E2, E3> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])>
 where
-    K: DynamicKernelSource + 'static,
+    K: GpuComputeShaderPhase + 'static,
     R: Runtime,
     E1: JitElement,
     E2: JitElement,
@@ -233,7 +204,7 @@ fn execute_dynamic<R, K, E1, E2, E3>(
     launch: WorkgroupLaunch,
     client: ComputeClient<R::Server, R::Channel>,
 ) where
-    K: DynamicKernelSource + 'static,
+    K: GpuComputeShaderPhase + 'static,
     R: Runtime,
     E1: JitElement,
     E2: JitElement,
@@ -250,7 +221,9 @@ fn execute_dynamic<R, K, E1, E2, E3>(
         handles.push(handle);
     }
 
-    let kernel: Box<dyn Kernel> = Box::new(DynamicKernel::new(kernel, workgroup));
+    let kernel = Kernel::JitGpu(Box::new(FullCompilationPhase::<R::Compiler, K>::new(
+        kernel, workgroup,
+    )));
 
     client.execute(kernel, &handles);
 }
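For callers, the deleted `execute_static` path is subsumed by the `Execution` builder above. A hedged usage sketch follows; `Execution::start`, the `inputs`/`outputs` builder methods, and `EagerHandle::new` are assumed from surrounding context rather than shown in this diff, and may not match the exact API:

```rust, ignore
// `client` is the compute client, e.g. `lhs.client.clone()`.
let kernel = FusedMatmulAddRelu::<F>::new(workgroup_size);

Execution::start(kernel, client)
    .inputs(&[
        EagerHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
        EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
    ])
    .outputs(&[EagerHandle::new(
        &output.handle,
        &output.strides,
        &output.shape.dims,
    )])
    // Launch with an explicit workgroup, one of the `WorkgroupLaunch` variants above.
    .execute(WorkgroupLaunch::Custom(workgroup));
```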