Skip to content

Commit 843dd49

Browse files
[Refactor] Just-In-Time Compilation Pipeline (#1313)
1 parent 2428723 commit 843dd49

37 files changed

+2600
-1677
lines changed

burn-wgpu/src/codegen/compilation.rs

+261
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
use super::dialect::gpu;
2+
use crate::codegen::dialect::gpu::{
3+
Binding, ComputeShader, Elem, Item, Location, Variable, Vectorization, Visibility,
4+
WorkgroupSize,
5+
};
6+
7+
/// The compilation struct allows you to create a [compute shader](ComputeShader) based on
8+
/// [compilation info](CompilationInfo) and [compilation settings](CompilationSettings).
9+
#[derive(Clone)]
10+
pub struct Compilation {
11+
info: CompilationInfo,
12+
input_bindings: Vec<Binding>,
13+
output_bindings: Vec<Binding>,
14+
named_bindings: Vec<(String, Binding)>,
15+
}
16+
17+
/// The information necessary to compile a [compute shader](ComputeShader).
18+
#[derive(Clone)]
19+
pub struct CompilationInfo {
20+
pub inputs: Vec<InputInfo>,
21+
pub outputs: Vec<OutputInfo>,
22+
pub scope: gpu::Scope,
23+
pub mappings: Vec<InplaceMapping>,
24+
}
25+
26+
/// Simply indicate the output that can be replaced by the input.
27+
#[derive(new, Clone, Copy)]
28+
pub struct InplaceMapping {
29+
/// Input position.
30+
pub pos_input: usize,
31+
/// Output position.
32+
pub pos_output: usize,
33+
}
34+
35+
#[derive(Default)]
36+
pub struct CompilationSettings {
37+
vectorization: Vectorization,
38+
inplace_available: bool,
39+
workgroup_size: WorkgroupSize,
40+
}
41+
42+
impl CompilationSettings {
43+
/// Compile the shader with vectorization enabled.
44+
#[allow(dead_code)]
45+
pub fn vectorize(mut self, vectorization: Vectorization) -> Self {
46+
self.vectorization = vectorization;
47+
self
48+
}
49+
/// Compile the shader with inplace enabled.
50+
///
51+
/// Notes:
52+
///
53+
/// This won't guarantee that the shader will use input arrays as outputs, since it is only
54+
/// possible when [inplace mappings](InplaceMapping) are provided as [compilation info](CompilationInfo)
55+
pub fn inplace(mut self, available: bool) -> Self {
56+
self.inplace_available = available;
57+
self
58+
}
59+
/// Set the grid size.
60+
#[allow(dead_code)] // Only used for fusion for now.
61+
pub fn workgroup_size(mut self, workgroup_size: WorkgroupSize) -> Self {
62+
self.workgroup_size = workgroup_size;
63+
self
64+
}
65+
}
66+
67+
/// Information related to an input.
68+
#[derive(Clone)]
69+
pub enum InputInfo {
70+
Array { item: Item, visibility: Visibility },
71+
Scalar { elem: Elem, size: usize },
72+
}
73+
74+
/// Information related to an output.
75+
#[derive(Clone)]
76+
pub enum OutputInfo {
77+
/// Write the local variable to a new array.
78+
///
79+
/// This will create a new binding in the [compute shader](ComputeShader).
80+
Array { item: Item, local: u16 },
81+
/// Write the local variable to an existing input binding.
82+
Input { item: Item, input: u16, local: u16 },
83+
}
84+
85+
impl Compilation {
86+
/// Starts a new compilation.
87+
pub fn new(info: CompilationInfo) -> Self {
88+
Self {
89+
info,
90+
input_bindings: Default::default(),
91+
output_bindings: Default::default(),
92+
named_bindings: Default::default(),
93+
}
94+
}
95+
96+
/// Performs the compilation with the provided [settings](CompilationSettings).
97+
pub fn compile(mut self, settings: CompilationSettings) -> ComputeShader {
98+
self.info.scope.vectorize(settings.vectorization);
99+
100+
self.register_inputs(&settings);
101+
self.register_outputs(&settings);
102+
103+
let inputs = self.input_bindings;
104+
let outputs = self.output_bindings;
105+
let mut named = Vec::with_capacity(2);
106+
107+
named.push((
108+
"info".to_string(),
109+
Binding {
110+
item: Item::Scalar(Elem::UInt),
111+
visibility: Visibility::Read,
112+
location: Location::Storage,
113+
size: None, // We avoid putting the length here since it will force a new kernel
114+
// for each tensor rank.
115+
},
116+
));
117+
118+
for (name, binding) in self.named_bindings.into_iter() {
119+
named.push((name, binding));
120+
}
121+
122+
ComputeShader {
123+
inputs,
124+
outputs,
125+
named,
126+
workgroup_size: settings.workgroup_size,
127+
body: self.info.scope,
128+
num_workgroups: true,
129+
global_invocation_id: true,
130+
}
131+
}
132+
133+
fn register_inputs(&mut self, settings: &CompilationSettings) {
134+
for input in self.info.inputs.drain(..) {
135+
match input {
136+
InputInfo::Array { item, visibility } => {
137+
let item = item.vectorize(settings.vectorization);
138+
139+
self.input_bindings.push(Binding {
140+
item: bool_item(item),
141+
visibility,
142+
location: Location::Storage,
143+
size: None,
144+
});
145+
}
146+
InputInfo::Scalar { elem, size } => {
147+
let elem = bool_elem(elem);
148+
149+
self.named_bindings.push((
150+
format!("scalars_{}", elem),
151+
Binding {
152+
item: Item::Scalar(elem),
153+
visibility: Visibility::Read,
154+
location: Location::Storage,
155+
size: Some(size),
156+
},
157+
));
158+
}
159+
}
160+
}
161+
}
162+
163+
fn register_outputs(&mut self, settings: &CompilationSettings) {
164+
let mut index = 0;
165+
166+
if settings.inplace_available {
167+
let mut mappings = Vec::new();
168+
core::mem::swap(&mut self.info.mappings, &mut mappings);
169+
170+
for mapping in mappings {
171+
self.register_inplace_mapping(mapping);
172+
}
173+
}
174+
175+
for array in self.info.outputs.drain(..) {
176+
match array {
177+
OutputInfo::Array { item, local } => {
178+
let item = item.vectorize(settings.vectorization);
179+
let elem_adapted = bool_item(item);
180+
181+
self.output_bindings.push(Binding {
182+
item: elem_adapted,
183+
visibility: Visibility::ReadWrite,
184+
location: Location::Storage,
185+
size: None,
186+
});
187+
self.info.scope.write_global(
188+
Variable::Local(local, item, self.info.scope.depth),
189+
Variable::GlobalOutputArray(index, elem_adapted),
190+
);
191+
index += 1;
192+
}
193+
OutputInfo::Input { item, input, local } => {
194+
let item = item.vectorize(settings.vectorization);
195+
196+
self.info.scope.write_global(
197+
Variable::Local(local, item, self.info.scope.depth),
198+
Variable::GlobalInputArray(input, bool_item(item)),
199+
);
200+
}
201+
}
202+
}
203+
}
204+
205+
fn register_inplace_mapping(&mut self, mapping: InplaceMapping) {
206+
let output = match self.info.outputs.get_mut(mapping.pos_output) {
207+
Some(output) => output,
208+
None => return, // No output to update.
209+
};
210+
211+
let (item, local) = match output {
212+
OutputInfo::Array { item, local } => (item, local),
213+
OutputInfo::Input {
214+
item: _,
215+
input: _,
216+
local: _,
217+
} => return, // Output already updated.
218+
};
219+
220+
let item = match self.input_bindings.get_mut(mapping.pos_input) {
221+
Some(binding) => {
222+
// Update input visibility.
223+
binding.visibility = Visibility::ReadWrite;
224+
// Inputs modified inplace should be read without any specified layout.
225+
self.info
226+
.scope
227+
.update_read(mapping.pos_input as u16, gpu::ReadingStrategy::Plain);
228+
229+
// Use the same item as the input.
230+
//
231+
// The output can be different (i.e inplace boolean operations on float bindings).
232+
binding.item
233+
}
234+
None => *item,
235+
};
236+
237+
// Update the output.
238+
*output = OutputInfo::Input {
239+
item,
240+
input: mapping.pos_input as u16,
241+
local: *local,
242+
};
243+
}
244+
}
245+
246+
fn bool_item(ty: Item) -> Item {
247+
match ty {
248+
Item::Vec4(elem) => Item::Vec4(bool_elem(elem)),
249+
Item::Vec3(elem) => Item::Vec3(bool_elem(elem)),
250+
Item::Vec2(elem) => Item::Vec2(bool_elem(elem)),
251+
Item::Scalar(elem) => Item::Scalar(bool_elem(elem)),
252+
}
253+
}
254+
255+
fn bool_elem(elem: Elem) -> Elem {
256+
match elem {
257+
// U32 are used for bool tensors
258+
Elem::Bool => Elem::UInt,
259+
_ => elem,
260+
}
261+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use super::{
2+
gpu, Elem, Item, Metadata, Operator, ReadGlobalAlgo, ReadGlobalWithLayoutAlgo, Scope, Variable,
3+
};
4+
use crate::codegen::dialect::gpu::BinaryOperator;
5+
6+
impl ReadGlobalAlgo {
7+
pub fn expand(self, scope: &mut Scope) {
8+
scope.register(Operator::Index(BinaryOperator {
9+
lhs: self.global,
10+
rhs: Variable::Id,
11+
out: self.out,
12+
}));
13+
}
14+
}
15+
16+
impl ReadGlobalWithLayoutAlgo {
17+
pub fn expand(self, scope: &mut Scope) {
18+
let out = self.out;
19+
let tensor = self.global;
20+
let layout = self.layout;
21+
let index_item_ty = Item::Scalar(Elem::UInt);
22+
let index_local = scope.create_local(index_item_ty);
23+
let zero: Variable = 0u32.into();
24+
let id = Variable::Id;
25+
let offset: Variable = match self.global.item() {
26+
Item::Vec4(_) => 4u32,
27+
Item::Vec3(_) => 3u32,
28+
Item::Vec2(_) => 2u32,
29+
Item::Scalar(_) => 1u32,
30+
}
31+
.into();
32+
33+
gpu!(scope, index_local = zero);
34+
gpu!(
35+
scope,
36+
range(zero, Variable::Rank).for_each(|i, scope| {
37+
let stride = scope.create_local(index_item_ty);
38+
let stride_layout = scope.create_local(index_item_ty);
39+
let shape = scope.create_local(index_item_ty);
40+
let tmp = scope.create_local(index_item_ty);
41+
42+
gpu!(scope, stride = stride(tensor, i));
43+
gpu!(scope, shape = shape(tensor, i));
44+
gpu!(scope, stride_layout = stride(layout, i));
45+
46+
gpu!(scope, tmp = id * offset);
47+
gpu!(scope, tmp = tmp / stride_layout);
48+
gpu!(scope, tmp = tmp % shape);
49+
gpu!(scope, tmp = tmp * stride);
50+
gpu!(scope, index_local = index_local + tmp);
51+
})
52+
);
53+
54+
gpu!(scope, index_local = index_local / offset);
55+
gpu!(scope, out = tensor[index_local]);
56+
}
57+
}

burn-wgpu/src/codegen/dialect/gpu/body.rs

-7
This file was deleted.

0 commit comments

Comments
 (0)