Skip to content

Commit 35345de

Browse files
Feat/cube/slice (#2004)
* Refactor Variable types * Sice * Implement slice wgsl * handle lifetime correctly * Add cuda impl * Update cmma * Cleanup * Fix tests * Fix slice signature
1 parent c30ffcf commit 35345de

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+1663
-565
lines changed

crates/burn-cube/src/codegen/integrator.rs

+20-6
Original file line numberDiff line numberDiff line change
@@ -423,17 +423,24 @@ impl KernelIntegrator {
423423
} else {
424424
item
425425
};
426-
let elem_adapted = bool_item(item);
426+
let item_adapted = bool_item(item);
427427

428428
self.output_bindings.push(Binding {
429-
item: elem_adapted,
429+
item: item_adapted,
430430
visibility: Visibility::ReadWrite,
431431
location: Location::Storage,
432432
size: None,
433433
});
434434
self.expansion.scope.write_global(
435-
Variable::Local(local, item, self.expansion.scope.depth),
436-
Variable::GlobalOutputArray(index, elem_adapted),
435+
Variable::Local {
436+
id: local,
437+
item,
438+
depth: self.expansion.scope.depth,
439+
},
440+
Variable::GlobalOutputArray {
441+
id: index,
442+
item: item_adapted,
443+
},
437444
position,
438445
);
439446
index += 1;
@@ -451,8 +458,15 @@ impl KernelIntegrator {
451458
};
452459

453460
self.expansion.scope.write_global(
454-
Variable::Local(local, item, self.expansion.scope.depth),
455-
Variable::GlobalInputArray(input, bool_item(item)),
461+
Variable::Local {
462+
id: local,
463+
item,
464+
depth: self.expansion.scope.depth,
465+
},
466+
Variable::GlobalInputArray {
467+
id: input,
468+
item: bool_item(item),
469+
},
456470
position,
457471
);
458472
}

crates/burn-cube/src/frontend/branch.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ where
2727

2828
if unroll {
2929
let start = match start.deref() {
30-
Variable::ConstantScalar(val, _) => *val as usize,
30+
Variable::ConstantScalar { value, .. } => *value as usize,
3131
_ => panic!("Only constant start can be unrolled."),
3232
};
3333
let end = match end.deref() {
34-
Variable::ConstantScalar(val, _) => *val as usize,
34+
Variable::ConstantScalar { value, .. } => *value as usize,
3535
_ => panic!("Only constant end can be unrolled."),
3636
};
3737

crates/burn-cube/src/frontend/cmma.rs

+16-10
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
//! 16,
1616
//! 16,
1717
//! 16,
18-
//! cmma::MatrixLayout::ColMajor,
18+
//! cmma::MatrixLayout::RowMajor,
1919
//! );
2020
//! let b = cmma::Matrix::<F16>::new(
2121
//! cmma::MatrixIdent::B,
2222
//! 16,
2323
//! 16,
2424
//! 16,
25-
//! cmma::MatrixLayout::RowMajor,
25+
//! cmma::MatrixLayout::ColMajor,
2626
//! );
2727
//! let c = cmma::Matrix::<F32>::new(
2828
//! cmma::MatrixIdent::Accumulator,
@@ -32,12 +32,17 @@
3232
//! cmma::MatrixLayout::Undefined,
3333
//! );
3434
//! cmma::fill::<F32>(&c, F32::new(0.0));
35-
//! cmma::load::<F16>(&a, lhs, UInt::new(16));
36-
//! cmma::load::<F16>(&b, rhs, UInt::new(16));
35+
//! cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
36+
//! cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
3737
//!
3838
//! cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
3939
//!
40-
//! cmma::store::<F32>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
40+
//! cmma::store::<F32>(
41+
//! out.as_slice_mut(),
42+
//! &c,
43+
//! UInt::new(16),
44+
//! cmma::MatrixLayout::RowMajor,
45+
//! );
4146
//! }
4247
//! ```
4348
@@ -49,7 +54,8 @@ use crate::{
4954
};
5055

5156
use super::{
52-
Array, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, UInt,
57+
CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut,
58+
UInt,
5359
};
5460

5561
pub use ir::{MatrixIdent, MatrixLayout};
@@ -137,7 +143,7 @@ pub fn fill_expand<C: CubeType>(
137143

138144
/// Load the matrix with the provided array using the stride.
139145
#[allow(unused_variables)]
140-
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Array<C>, stride: UInt) {
146+
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Slice<'_, C>, stride: UInt) {
141147
unexpanded!()
142148
}
143149

@@ -146,7 +152,7 @@ pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Array<C>, stride: UInt) {
146152
pub fn load_expand<C: CubeType>(
147153
context: &mut CubeContext,
148154
mat: MatrixExpand,
149-
value: ExpandElementTyped<Array<C>>,
155+
value: ExpandElementTyped<Slice<'static, C>>,
150156
stride: ExpandElement,
151157
) {
152158
context.register(Operation::CoopMma(ir::CoopMma::Load {
@@ -159,7 +165,7 @@ pub fn load_expand<C: CubeType>(
159165
/// Store the matrix in the given array following the given stride and layout.
160166
#[allow(unused_variables)]
161167
pub fn store<C: CubePrimitive>(
162-
output: &Array<C>,
168+
output: &mut SliceMut<'_, C>,
163169
mat: &Matrix<C>,
164170
stride: UInt,
165171
layout: MatrixLayout,
@@ -171,7 +177,7 @@ pub fn store<C: CubePrimitive>(
171177
#[allow(unused_variables)]
172178
pub fn store_expand<C: CubePrimitive>(
173179
context: &mut CubeContext,
174-
output: ExpandElementTyped<Array<C>>,
180+
output: ExpandElementTyped<SliceMut<'static, C>>,
175181
mat: MatrixExpand,
176182
stride: ExpandElement,
177183
layout: MatrixLayout,

crates/burn-cube/src/frontend/context.rs

+14-16
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ use alloc::rc::Rc;
44
use core::cell::RefCell;
55
use std::collections::HashMap;
66

7-
use super::{CubePrimitive, SharedMemoryExpand};
8-
97
#[derive(Default, Clone)]
108
pub struct VariablePool {
119
map: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
@@ -114,34 +112,34 @@ impl CubeContext {
114112
ExpandElement::Plain(variable)
115113
}
116114

117-
pub fn create_shared<T: CubePrimitive>(
118-
&mut self,
119-
item: Item,
120-
size: u32,
121-
) -> SharedMemoryExpand<T> {
122-
SharedMemoryExpand {
123-
val: ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size)),
124-
}
115+
/// Create a new slice element.
116+
pub fn create_slice(&mut self, item: Item) -> ExpandElement {
117+
let variable = self.scope.borrow_mut().create_slice(item);
118+
ExpandElement::Plain(variable)
119+
}
120+
121+
pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement {
122+
ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size))
125123
}
126124

127125
pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
128126
ExpandElement::Plain(self.root.borrow_mut().create_local_array(item, size))
129127
}
130128

131129
/// Obtain the index-th input
132-
pub fn input(&mut self, index: u16, item: Item) -> ExpandElement {
133-
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item))
130+
pub fn input(&mut self, id: u16, item: Item) -> ExpandElement {
131+
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray { id, item })
134132
}
135133

136134
/// Obtain the index-th output
137-
pub fn output(&mut self, index: u16, item: Item) -> ExpandElement {
138-
let var = crate::ir::Variable::GlobalOutputArray(index, item);
135+
pub fn output(&mut self, id: u16, item: Item) -> ExpandElement {
136+
let var = crate::ir::Variable::GlobalOutputArray { id, item };
139137
self.scope.borrow_mut().write_global_custom(var);
140138
ExpandElement::Plain(var)
141139
}
142140

143141
/// Obtain the index-th scalar
144-
pub fn scalar(&self, index: u16, elem: Elem) -> ExpandElement {
145-
ExpandElement::Plain(crate::ir::Variable::GlobalScalar(index, elem))
142+
pub fn scalar(&self, id: u16, elem: Elem) -> ExpandElement {
143+
ExpandElement::Plain(crate::ir::Variable::GlobalScalar { id, elem })
146144
}
147145
}

crates/burn-cube/src/frontend/element/array.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
3636
) -> <Self as CubeType>::ExpandType {
3737
let size = size.value();
3838
let size = match size {
39-
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
39+
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
4040
_ => panic!("Array need constant initialization value"),
4141
};
4242
context
@@ -55,7 +55,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
5555
) -> <Self as CubeType>::ExpandType {
5656
let size = size.value();
5757
let size = match size {
58-
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
58+
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
5959
_ => panic!("Shared memory need constant initialization value"),
6060
};
6161
context

crates/burn-cube/src/frontend/element/base.rs

+11-10
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ impl ExpandElement {
160160
pub fn can_mut(&self) -> bool {
161161
match self {
162162
ExpandElement::Managed(var) => {
163-
if let Variable::Local(_, _, _) = var.as_ref() {
163+
if let Variable::Local { .. } = var.as_ref() {
164164
Rc::strong_count(var) <= 2
165165
} else {
166166
false
@@ -201,10 +201,10 @@ impl Init for ExpandElement {
201201
let mut init = |elem: Self| init_expand(context, elem, Operator::Assign);
202202

203203
match *self {
204-
Variable::GlobalScalar(_, _) => init(self),
205-
Variable::LocalScalar(_, _, _) => init(self),
206-
Variable::ConstantScalar(_, _) => init(self),
207-
Variable::Local(_, _, _) => init(self),
204+
Variable::GlobalScalar { .. } => init(self),
205+
Variable::LocalScalar { .. } => init(self),
206+
Variable::ConstantScalar { .. } => init(self),
207+
Variable::Local { .. } => init(self),
208208
// Constant should be initialized since the new variable can be mutated afterward.
209209
// And it is assumed those values are cloned.
210210
Variable::Rank
@@ -230,11 +230,12 @@ impl Init for ExpandElement {
230230
| Variable::AbsolutePosY
231231
| Variable::AbsolutePosZ => init(self),
232232
// Array types can't be copied, so we should simply return the same variable.
233-
Variable::SharedMemory(_, _, _)
234-
| Variable::GlobalInputArray(_, _)
235-
| Variable::GlobalOutputArray(_, _)
236-
| Variable::LocalArray(_, _, _, _)
237-
| Variable::Matrix(_, _) => self,
233+
Variable::SharedMemory { .. }
234+
| Variable::GlobalInputArray { .. }
235+
| Variable::GlobalOutputArray { .. }
236+
| Variable::LocalArray { .. }
237+
| Variable::Slice { .. }
238+
| Variable::Matrix { .. } => self,
238239
}
239240
}
240241
}

crates/burn-cube/src/frontend/element/cube_elem.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ impl_into_expand_element!(i64);
4141
/// Useful for Comptime
4242
impl From<UInt> for ExpandElement {
4343
fn from(value: UInt) -> Self {
44-
ExpandElement::Plain(crate::ir::Variable::ConstantScalar(
45-
value.val as f64,
46-
UInt::as_elem(),
47-
))
44+
ExpandElement::Plain(crate::ir::Variable::ConstantScalar {
45+
value: value.val as f64,
46+
elem: UInt::as_elem(),
47+
})
4848
}
4949
}

crates/burn-cube/src/frontend/element/float.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ macro_rules! impl_float {
7575
}
7676

7777
fn new_expand(_context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType {
78-
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
78+
let new_var = Variable::ConstantScalar {
79+
value: val as f64,
80+
elem: Self::as_elem(),
81+
};
7982
ExpandElement::Plain(new_var)
8083
}
8184

crates/burn-cube/src/frontend/element/int.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ macro_rules! impl_int {
4949
}
5050

5151
fn new_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
52-
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
52+
let new_var = Variable::ConstantScalar {
53+
value: val as f64,
54+
elem: Self::as_elem(),
55+
};
5356
ExpandElement::Plain(new_var)
5457
}
5558

crates/burn-cube/src/frontend/element/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod float;
77
mod int;
88
mod numeric;
99
mod shared_memory;
10+
mod slice;
1011
mod tensor;
1112
mod uint;
1213
mod vectorized;
@@ -19,6 +20,7 @@ pub use float::*;
1920
pub use int::*;
2021
pub use numeric::*;
2122
pub use shared_memory::*;
23+
pub use slice::*;
2224
pub use tensor::*;
2325
pub use uint::*;
2426
pub use vectorized::*;

crates/burn-cube/src/frontend/element/numeric.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ pub trait Numeric:
4747

4848
/// Expand version of from_int
4949
fn from_int_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
50-
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
50+
let new_var = Variable::ConstantScalar {
51+
value: val as f64,
52+
elem: Self::as_elem(),
53+
};
5154
ExpandElement::Plain(new_var)
5255
}
5356

crates/burn-cube/src/frontend/element/shared_memory.rs

+10-19
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,21 @@ use crate::{
55
ir::Item,
66
};
77

8-
use super::{ExpandElement, Init, UInt};
8+
use super::{ExpandElementTyped, Init, UInt};
99

1010
#[derive(Clone, Copy)]
1111
pub struct SharedMemory<T: CubeType> {
1212
_val: PhantomData<T>,
1313
}
1414

15-
#[derive(Clone)]
16-
pub struct SharedMemoryExpand<T: CubePrimitive> {
17-
pub val: <T as CubeType>::ExpandType,
18-
}
19-
20-
impl<T: CubePrimitive> From<SharedMemoryExpand<T>> for ExpandElement {
21-
fn from(shared_memory_expand: SharedMemoryExpand<T>) -> Self {
22-
shared_memory_expand.val
23-
}
24-
}
25-
26-
impl<T: CubePrimitive> Init for SharedMemoryExpand<T> {
15+
impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
2716
fn init(self, _context: &mut CubeContext) -> Self {
2817
self
2918
}
3019
}
3120

3221
impl<T: CubePrimitive> CubeType for SharedMemory<T> {
33-
type ExpandType = SharedMemoryExpand<T>;
22+
type ExpandType = ExpandElementTyped<SharedMemory<T>>;
3423
}
3524

3625
impl<T: CubePrimitive + Clone> SharedMemory<T> {
@@ -44,10 +33,11 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
4433
) -> <Self as CubeType>::ExpandType {
4534
let size = size.value();
4635
let size = match size {
47-
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
36+
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
4837
_ => panic!("Shared memory need constant initialization value"),
4938
};
50-
context.create_shared(Item::new(T::as_elem()), size)
39+
let var = context.create_shared(Item::new(T::as_elem()), size);
40+
ExpandElementTyped::new(var)
5141
}
5242

5343
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
@@ -61,12 +51,13 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
6151
) -> <Self as CubeType>::ExpandType {
6252
let size = size.value();
6353
let size = match size {
64-
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
54+
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
6555
_ => panic!("Shared memory need constant initialization value"),
6656
};
67-
context.create_shared(
57+
let var = context.create_shared(
6858
Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
6959
size,
70-
)
60+
);
61+
ExpandElementTyped::new(var)
7162
}
7263
}

0 commit comments

Comments
 (0)