Skip to content

Commit 76fe0ed

Browse files
authored
Refactor/cube/vectorization (#1781)
1 parent 499ff0d commit 76fe0ed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+433
-277
lines changed

crates/burn-cube-macros/src/analysis.rs

+9
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,15 @@ impl CodeAnalysisBuilder {
239239
}
240240
syn::Expr::Break(_) => {}
241241
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
242+
syn::Expr::Array(expr) => {
243+
for element in expr.elems.iter() {
244+
match element {
245+
syn::Expr::Lit(_) => {}
246+
_ => todo!("Analysis: only array of literals is supported"),
247+
}
248+
}
249+
}
250+
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
242251
_ => todo!("Analysis: unsupported expr {expr:?}"),
243252
}
244253
}

crates/burn-cube-macros/src/codegen/base.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use super::{
66
branch::{codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_while_loop},
77
function::{codegen_call, codegen_closure, codegen_expr_method_call},
88
operation::codegen_binary,
9-
variable::{codegen_assign, codegen_index, codegen_lit, codegen_local, codegen_path_rhs},
9+
variable::{
10+
codegen_array_lit, codegen_assign, codegen_index, codegen_lit, codegen_local,
11+
codegen_path_rhs,
12+
},
1013
};
1114

1215
/// Codegen for a statement (generally one line)
@@ -59,6 +62,15 @@ pub(crate) fn codegen_expr_block(
5962
codegen_block(&block.block, loop_level, variable_analyses)
6063
}
6164

65+
pub(crate) fn codegen_ref(
66+
reference: &syn::ExprReference,
67+
loop_level: usize,
68+
variable_analyses: &mut CodeAnalysis,
69+
) -> TokenStream {
70+
let inner = codegen_expr(&reference.expr, loop_level, variable_analyses);
71+
quote::quote! { & #inner }
72+
}
73+
6274
/// Codegen for expressions
6375
/// There are many variants of expression, treated differently
6476
pub(crate) fn codegen_expr(
@@ -84,6 +96,8 @@ pub(crate) fn codegen_expr(
8496
syn::Expr::MethodCall(call) => codegen_expr_method_call(call),
8597
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses),
8698
syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses),
99+
syn::Expr::Array(array) => codegen_array_lit(array),
100+
syn::Expr::Reference(reference) => codegen_ref(reference, loop_level, variable_analyses),
87101
_ => panic!("Codegen: Unsupported {:?}", expr),
88102
}
89103
}

crates/burn-cube-macros/src/codegen/function.rs

+1-6
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,7 @@ pub(crate) fn codegen_closure(
3434
}
3535

3636
/// Codegen for a function call
37-
/// Supports:
38-
/// func()
39-
/// func::<T>()
40-
/// T::func()
41-
///
42-
/// Should map:
37+
/// Maps
4338
/// [A[::<...>]?::]^* func[::<...>] (args)
4439
/// to
4540
/// [A[::<...>]?::]^* func_expand[::<...>] (context, args)

crates/burn-cube-macros/src/codegen/variable.rs

+13
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,19 @@ pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream {
1919
}
2020
}
2121

22+
/// Codegen for arrays of literals
23+
pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
24+
let mut tokens = quote::quote! {};
25+
for element in array.elems.iter() {
26+
let token = match element {
27+
syn::Expr::Lit(lit) => codegen_lit(lit),
28+
_ => todo!("Codegen: Only arrays of literals are supported"),
29+
};
30+
tokens.extend(quote::quote! { #token, });
31+
}
32+
quote::quote! { [ #tokens ] }
33+
}
34+
2235
/// Codegen for a local declaration (let ...)
2336
/// Supports:
2437
/// let x = ...

crates/burn-cube/src/codegen/compilation.rs

+7-14
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,7 @@ impl core::fmt::Display for CompilationSettings {
8181
}
8282

8383
match self.vectorization {
84-
Some(vectorization) => match vectorization {
85-
Vectorization::Vec4 => f.write_str("v4"),
86-
Vectorization::Vec3 => f.write_str("v3"),
87-
Vectorization::Vec2 => f.write_str("v2"),
88-
Vectorization::Scalar => f.write_str("v1"),
89-
}?,
84+
Some(vectorization) => f.write_fmt(format_args!("v{}", vectorization))?,
9085
None => f.write_str("vn")?,
9186
};
9287

@@ -154,7 +149,7 @@ impl InputInfo {
154149
item,
155150
visibility: _,
156151
} => *item,
157-
InputInfo::Scalar { elem, size: _ } => Item::Scalar(*elem),
152+
InputInfo::Scalar { elem, size: _ } => Item::new(*elem),
158153
}
159154
}
160155
}
@@ -252,7 +247,7 @@ impl Compilation {
252247
named.push((
253248
"info".to_string(),
254249
Binding {
255-
item: Item::Scalar(Elem::UInt),
250+
item: Item::new(Elem::UInt),
256251
visibility: Visibility::Read,
257252
location: Location::Storage,
258253
size: None, // We avoid putting the length here since it will force a new kernel
@@ -300,7 +295,7 @@ impl Compilation {
300295
self.named_bindings.push((
301296
format!("scalars_{}", elem),
302297
Binding {
303-
item: Item::Scalar(elem),
298+
item: Item::new(elem),
304299
visibility: Visibility::Read,
305300
location: Location::Storage,
306301
size: Some(size),
@@ -440,11 +435,9 @@ impl Compilation {
440435
}
441436

442437
fn bool_item(ty: Item) -> Item {
443-
match ty {
444-
Item::Vec4(elem) => Item::Vec4(bool_elem(elem)),
445-
Item::Vec3(elem) => Item::Vec3(bool_elem(elem)),
446-
Item::Vec2(elem) => Item::Vec2(bool_elem(elem)),
447-
Item::Scalar(elem) => Item::Scalar(bool_elem(elem)),
438+
Item {
439+
elem: bool_elem(ty.elem),
440+
vectorization: ty.vectorization,
448441
}
449442
}
450443

crates/burn-cube/src/codegen/dialect/branch.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ impl RangeLoop {
9494
func: F,
9595
) {
9696
let mut scope = parent_scope.child();
97-
let index_ty = Item::Scalar(Elem::UInt);
97+
let index_ty = Item::new(Elem::UInt);
9898
let i = scope.create_local_undeclared(index_ty);
9999

100100
func(i, &mut scope);

crates/burn-cube/src/codegen/dialect/procedure/assign.rs

+21-26
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use crate::codegen::dialect::{macros::cpa, Item, Scope, Variable, Vectorization};
1+
use crate::{
2+
branch::range,
3+
codegen::dialect::{macros::cpa, Scope, Variable, Vectorization},
4+
};
25
use serde::{Deserialize, Serialize};
36

47
/// Assign value to a variable based on a given condition.
@@ -19,14 +22,15 @@ impl ConditionalAssign {
1922
let rhs = self.rhs;
2023
let out = self.out;
2124

22-
let index_var = |scope: &mut Scope, var: Variable, index: usize| match var.item() {
23-
Item::Scalar(_) => var,
24-
_ => {
25-
let out = scope.create_local(var.item().elem());
26-
cpa!(scope, out = var[index]);
27-
out
28-
}
29-
};
25+
let index_var =
26+
|scope: &mut Scope, var: Variable, index: usize| match var.item().vectorization == 1 {
27+
true => var,
28+
false => {
29+
let out = scope.create_local(var.item().elem());
30+
cpa!(scope, out = var[index]);
31+
out
32+
}
33+
};
3034

3135
let mut assign_index = |index: usize| {
3236
let cond = index_var(scope, cond, index);
@@ -42,29 +46,20 @@ impl ConditionalAssign {
4246
}));
4347
};
4448

45-
match out.item() {
46-
Item::Vec4(_) => {
47-
assign_index(0);
48-
assign_index(1);
49-
assign_index(2);
50-
assign_index(3);
51-
}
52-
Item::Vec3(_) => {
53-
assign_index(0);
54-
assign_index(1);
55-
assign_index(2);
56-
}
57-
Item::Vec2(_) => {
58-
assign_index(0);
59-
assign_index(1);
60-
}
61-
Item::Scalar(_) => {
49+
let vectorization = out.item().vectorization;
50+
match vectorization == 1 {
51+
true => {
6252
cpa!(scope, if (cond).then(|scope| {
6353
cpa!(scope, out = lhs);
6454
}).else(|scope| {
6555
cpa!(scope, out = rhs);
6656
}));
6757
}
58+
false => {
59+
for i in range(0u32, vectorization as u32, true) {
60+
assign_index(i);
61+
}
62+
}
6863
};
6964
}
7065

crates/burn-cube/src/codegen/dialect/procedure/index.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ impl CheckedIndex {
1919
let lhs = self.lhs;
2020
let rhs = self.rhs;
2121
let out = self.out;
22-
let array_len = scope.create_local(Item::Scalar(crate::dialect::Elem::UInt));
23-
let inside_bound = scope.create_local(Item::Scalar(crate::dialect::Elem::Bool));
22+
let array_len = scope.create_local(Item::new(crate::dialect::Elem::UInt));
23+
let inside_bound = scope.create_local(Item::new(crate::dialect::Elem::Bool));
2424

2525
cpa!(scope, array_len = len(lhs));
2626
cpa!(scope, inside_bound = rhs < array_len);
@@ -56,8 +56,8 @@ impl CheckedIndexAssign {
5656
let lhs = self.lhs;
5757
let rhs = self.rhs;
5858
let out = self.out;
59-
let array_len = scope.create_local(Item::Scalar(Elem::UInt));
60-
let inside_bound = scope.create_local(Item::Scalar(Elem::Bool));
59+
let array_len = scope.create_local(Item::new(Elem::UInt));
60+
let inside_bound = scope.create_local(Item::new(Elem::Bool));
6161

6262
cpa!(scope, array_len = len(out));
6363
cpa!(scope, inside_bound = lhs < array_len);

crates/burn-cube/src/codegen/dialect/procedure/read.rs

+3-9
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,11 @@ impl IndexOffsetGlobalWithLayout {
140140
#[allow(missing_docs)]
141141
pub fn expand(self, scope: &mut Scope) {
142142
let layout = self.layout;
143-
let index_item_ty = Item::Scalar(Elem::UInt);
143+
let index_item_ty = Item::new(Elem::UInt);
144144
let offset_ref = self.position;
145145
let zero: Variable = 0u32.into();
146-
let vectorization_factor: Variable = match self.tensors[0].item() {
147-
Item::Vec4(_) => 4u32,
148-
Item::Vec3(_) => 3u32,
149-
Item::Vec2(_) => 2u32,
150-
Item::Scalar(_) => 1u32,
151-
}
152-
.into();
153-
146+
let vectorization_factor: u8 = self.tensors[0].item().vectorization;
147+
let vectorization_factor: Variable = (vectorization_factor as u32).into();
154148
for index in self.indexes.iter() {
155149
cpa!(scope, index = zero);
156150
}

crates/burn-cube/src/codegen/dialect/scope.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -336,11 +336,9 @@ impl Scope {
336336
position: Variable,
337337
) -> Variable {
338338
let item_global = match item.elem() {
339-
Elem::Bool => match item {
340-
Item::Vec4(_) => Item::Vec4(Elem::UInt),
341-
Item::Vec3(_) => Item::Vec3(Elem::UInt),
342-
Item::Vec2(_) => Item::Vec2(Elem::UInt),
343-
Item::Scalar(_) => Item::Scalar(Elem::UInt),
339+
Elem::Bool => Item {
340+
elem: Elem::UInt,
341+
vectorization: item.vectorization,
344342
},
345343
_ => item,
346344
};

crates/burn-cube/src/codegen/dialect/shader.rs

+21-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use super::Scope;
1+
use super::{Scope, Vectorization};
22
use crate::WORKGROUP_DEFAULT;
33
use serde::{Deserialize, Serialize};
44
use std::fmt::Display;
@@ -44,7 +44,7 @@ pub enum Elem {
4444

4545
impl From<Elem> for Item {
4646
fn from(val: Elem) -> Self {
47-
Item::Scalar(val)
47+
Item::new(val)
4848
}
4949
}
5050

@@ -81,22 +81,30 @@ impl Display for Elem {
8181
}
8282

8383
#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash)]
84-
#[allow(missing_docs)]
85-
pub enum Item {
86-
Vec4(Elem),
87-
Vec3(Elem),
88-
Vec2(Elem),
89-
Scalar(Elem),
84+
pub struct Item {
85+
pub elem: Elem,
86+
pub vectorization: Vectorization,
9087
}
9188

9289
impl Item {
9390
/// Fetch the elem of the item.
9491
pub fn elem(&self) -> Elem {
95-
match self {
96-
Self::Vec4(elem) => *elem,
97-
Self::Vec3(elem) => *elem,
98-
Self::Vec2(elem) => *elem,
99-
Self::Scalar(elem) => *elem,
92+
self.elem
93+
}
94+
95+
/// Create a new item without vectorization
96+
pub fn new(elem: Elem) -> Self {
97+
Self {
98+
elem,
99+
vectorization: 1,
100+
}
101+
}
102+
103+
/// Create a new item with vectorization
104+
pub fn vectorized(elem: Elem, vectorization: Vectorization) -> Self {
105+
Self {
106+
elem,
107+
vectorization,
100108
}
101109
}
102110
}

crates/burn-cube/src/codegen/dialect/variable.rs

+21-21
Original file line numberDiff line numberDiff line change
@@ -69,30 +69,30 @@ impl Variable {
6969
match self {
7070
Variable::GlobalInputArray(_, item) => *item,
7171
Variable::GlobalOutputArray(_, item) => *item,
72-
Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
72+
Variable::GlobalScalar(_, elem) => Item::new(*elem),
7373
Variable::Local(_, item, _) => *item,
74-
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
75-
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
74+
Variable::LocalScalar(_, elem, _) => Item::new(*elem),
75+
Variable::ConstantScalar(_, elem) => Item::new(*elem),
7676
Variable::SharedMemory(_, item, _) => *item,
7777
Variable::LocalArray(_, item, _, _) => *item,
78-
Variable::Id => Item::Scalar(Elem::UInt),
79-
Variable::Rank => Item::Scalar(Elem::UInt),
80-
Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),
81-
Variable::LocalInvocationIdX => Item::Scalar(Elem::UInt),
82-
Variable::LocalInvocationIdY => Item::Scalar(Elem::UInt),
83-
Variable::LocalInvocationIdZ => Item::Scalar(Elem::UInt),
84-
Variable::WorkgroupIdX => Item::Scalar(Elem::UInt),
85-
Variable::WorkgroupIdY => Item::Scalar(Elem::UInt),
86-
Variable::WorkgroupIdZ => Item::Scalar(Elem::UInt),
87-
Variable::GlobalInvocationIdX => Item::Scalar(Elem::UInt),
88-
Variable::GlobalInvocationIdY => Item::Scalar(Elem::UInt),
89-
Variable::GlobalInvocationIdZ => Item::Scalar(Elem::UInt),
90-
Variable::WorkgroupSizeX => Item::Scalar(Elem::UInt),
91-
Variable::WorkgroupSizeY => Item::Scalar(Elem::UInt),
92-
Variable::WorkgroupSizeZ => Item::Scalar(Elem::UInt),
93-
Variable::NumWorkgroupsX => Item::Scalar(Elem::UInt),
94-
Variable::NumWorkgroupsY => Item::Scalar(Elem::UInt),
95-
Variable::NumWorkgroupsZ => Item::Scalar(Elem::UInt),
78+
Variable::Id => Item::new(Elem::UInt),
79+
Variable::Rank => Item::new(Elem::UInt),
80+
Variable::LocalInvocationIndex => Item::new(Elem::UInt),
81+
Variable::LocalInvocationIdX => Item::new(Elem::UInt),
82+
Variable::LocalInvocationIdY => Item::new(Elem::UInt),
83+
Variable::LocalInvocationIdZ => Item::new(Elem::UInt),
84+
Variable::WorkgroupIdX => Item::new(Elem::UInt),
85+
Variable::WorkgroupIdY => Item::new(Elem::UInt),
86+
Variable::WorkgroupIdZ => Item::new(Elem::UInt),
87+
Variable::GlobalInvocationIdX => Item::new(Elem::UInt),
88+
Variable::GlobalInvocationIdY => Item::new(Elem::UInt),
89+
Variable::GlobalInvocationIdZ => Item::new(Elem::UInt),
90+
Variable::WorkgroupSizeX => Item::new(Elem::UInt),
91+
Variable::WorkgroupSizeY => Item::new(Elem::UInt),
92+
Variable::WorkgroupSizeZ => Item::new(Elem::UInt),
93+
Variable::NumWorkgroupsX => Item::new(Elem::UInt),
94+
Variable::NumWorkgroupsY => Item::new(Elem::UInt),
95+
Variable::NumWorkgroupsZ => Item::new(Elem::UInt),
9696
}
9797
}
9898
}

0 commit comments

Comments
 (0)