Skip to content

Commit

Permalink
Migrate wgsl index shaders to gpu representation (#1378)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Feb 28, 2024
1 parent 330552a commit 40bf392
Show file tree
Hide file tree
Showing 13 changed files with 862 additions and 527 deletions.
4 changes: 4 additions & 0 deletions crates/burn-wgpu/src/kernel/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ mod gather;
mod repeat;
mod scatter;
mod select;
mod select_assign;
mod slice;
mod slice_assign;

pub use repeat::*;
pub use select::*;
pub use select_assign::*;
pub use slice::*;
pub use slice_assign::*;

pub(crate) use gather::*;
pub(crate) use scatter::*;
131 changes: 115 additions & 16 deletions crates/burn-wgpu/src/kernel/index/repeat.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,107 @@
use crate::{
compute::StaticKernel,
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
kernel::{self, DynamicKernelSource, SourceTemplate},
tensor::JitTensor,
Runtime,
};
use std::marker::PhantomData;

kernel_wgsl!(RepeatRaw, "../../template/index/repeat.wgsl");
pub struct RepeatComputeShader {
input: Variable,
output: Variable,
dim: usize,
rank: usize,
}

#[derive(new)]
struct RepeatEagerKernel<R: Runtime, E: JitElement> {
dim: usize,
rank: usize,
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}

impl RepeatComputeShader {
pub fn expand(self, scope: &mut Scope) {
let input = self.input;
let output = self.output;
let id = Variable::Id;

let offset_input = scope.zero(Elem::UInt);
let offset_local = scope.zero(Elem::UInt);

let stride_input = scope.create_local(Elem::UInt);
let stride_output = scope.create_local(Elem::UInt);
let shape_output = scope.create_local(Elem::UInt);

for i in 0..self.rank {
if i != self.dim {
gpu!(scope, stride_input = stride(input, i));
gpu!(scope, stride_output = stride(output, i));
gpu!(scope, shape_output = shape(output, i));

gpu!(scope, offset_local = id / stride_output);
gpu!(scope, offset_local = offset_local % shape_output);
gpu!(scope, offset_local = offset_local * stride_input);
gpu!(scope, offset_input += offset_local);
}
}

let result = scope.create_local(input.item());
gpu!(scope, result = input[offset_input]);
gpu!(scope, output[id] = result);
}
}
impl<R: Runtime, E: JitElement> DynamicKernelSource for RepeatEagerKernel<R, E> {
fn source(&self) -> kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();

let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);

scope.write_global_custom(output);

RepeatComputeShader {
input,
output,
rank: self.rank,
dim: self.dim,
}
.expand(&mut scope);

let input = InputInfo::Array {
item,
visibility: Visibility::Read,
};
let output = OutputInfo::Array { item };

let info = CompilationInfo {
inputs: vec![input],
outputs: vec![output],
scope,
};

let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}

fn id(&self) -> String {
format!(
"{:?}d={}r={}",
core::any::TypeId::of::<Self>(),
self.dim,
self.rank
)
}
}

pub(crate) fn repeat<R: Runtime, E: JitElement, const D1: usize>(
input: JitTensor<R, E, D1>,
Expand All @@ -32,19 +126,24 @@ pub(crate) fn repeat<R: Runtime, E: JitElement, const D1: usize>(
handle,
);

let mut info = build_info(&[&input, &output]);
info.push(dim as u32);
let info_handle = input.client.create(bytemuck::cast_slice(&info));

let kernel = StaticKernel::<
KernelSettings<RepeatRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT));

input.client.execute(
Box::new(kernel),
&[&input.handle, &output.handle, &info_handle],
let kernel = RepeatEagerKernel::new(dim, D1);

execute_dynamic::<R, RepeatEagerKernel<R, E>, E>(
&[EagerHandle::new(
&input.handle,
&input.strides,
&input.shape.dims,
)],
&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
input.client,
);

output
}

Expand Down
Loading

0 comments on commit 40bf392

Please sign in to comment.