Skip to content

Commit b429cc3

Browse files
Split the JIT stuff from the Wgpu stuff (#1417)
1 parent 3ff6e71 commit b429cc3

File tree

210 files changed

+3791
-4746
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

210 files changed

+3791
-4746
lines changed

Cargo.lock

+24-8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/burn-jit/Cargo.toml

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
[package]
2+
authors = ["nathanielsimard <[email protected]>"]
3+
categories = ["science"]
4+
description = "Generic backend that can be compiled just-in-time to any shader language target"
5+
edition.workspace = true
6+
keywords = ["deep-learning", "machine-learning", "gpu"]
7+
license.workspace = true
8+
name = "burn-jit"
9+
readme.workspace = true
10+
repository = "https://github.com/tracel-ai/burn/tree/main/burn-jit"
11+
version.workspace = true
12+
13+
[features]
14+
default = ["autotune", "std", "burn-compute/default", "fusion"]
15+
std = []
16+
doc = ["default"]
17+
autotune = []
18+
fusion = ["burn-fusion"]
19+
export_tests = [
20+
"burn-tensor-testgen",
21+
"serial_test",
22+
"burn-autodiff/export_tests",
23+
"burn-tensor/export_tests",
24+
"burn-ndarray",
25+
"fusion",
26+
]
27+
28+
[dependencies]
29+
burn-common = { path = "../burn-common", version = "0.13.0" }
30+
burn-tensor = { path = "../burn-tensor", version = "0.13.0" }
31+
burn-fusion = { path = "../burn-fusion", version = "0.13.0", optional = true }
32+
33+
bytemuck = { workspace = true }
34+
derive-new = { workspace = true }
35+
log = { workspace = true }
36+
num-traits = { workspace = true }
37+
rand = { workspace = true }
38+
spin = { workspace = true }
39+
40+
# Template
41+
serde = { workspace = true }
42+
text_placeholder = { workspace = true, features = ["struct_context"] }
43+
44+
hashbrown = { workspace = true }
45+
burn-compute = { path = "../burn-compute", version = "0.13.0", default-features = false, features = [
46+
"channel-mutex",
47+
"std",
48+
] }
49+
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.13.0", optional = true }
50+
51+
# When exporting tests
52+
serial_test = { workspace = true, optional = true }
53+
burn-autodiff = { path = "../burn-autodiff", version = "0.13.0", default-features = false, optional = true }
54+
burn-ndarray = { path = "../burn-ndarray", version = "0.13.0", optional = true }
55+
56+
[package.metadata.docs.rs]
57+
features = ["doc"]
File renamed without changes.
File renamed without changes.

crates/burn-jit/README.md

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Burn JIT Backend
2+
3+
Generic backend that can be compiled just-in-time (JIT) to any shader language target.
4+
In progress: At the moment, only WGSL compilation is supported, and some kernels still rely on pure WGSL.

crates/burn-wgpu/src/backend.rs crates/burn-jit/src/backend.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::{marker::PhantomData, sync::Mutex};
55

66
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
77

8-
/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders.
8+
/// Generic tensor backend that can be compiled just-in-time to any shader runtime
99
#[derive(new)]
1010
pub struct JitBackend<R: Runtime> {
1111
_runtime: PhantomData<R>,

crates/burn-wgpu/src/codegen/compiler.rs crates/burn-jit/src/codegen/compiler.rs

+8
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,24 @@ use super::dialect::gpu;
22
use crate::{FloatElement, IntElement};
33
use std::fmt::Display;
44

5+
/// Compiles the [gpu representation](gpu::ComputeShader) into its own representation that can be
6+
/// formatted into tokens.
57
pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
8+
/// The representation for the compiled code.
69
type Representation: Display;
10+
/// The float element type used for compilation.
711
type Float: FloatElement;
12+
/// The int element type used for compilation.
813
type Int: IntElement;
14+
/// The compiler that can be used to generate full precision shaders.
915
type FullPrecisionCompiler: Compiler<
1016
Representation = Self::Representation,
1117
Float = f32,
1218
Int = i32,
1319
>;
1420

21+
/// Compiles the [gpu shader](gpu::ComputeShader) into the compiler's representation.
1522
fn compile(shader: gpu::ComputeShader) -> Self::Representation;
23+
/// The size of the given element in bytes.
1624
fn elem_size(elem: gpu::Elem) -> usize;
1725
}

crates/burn-wgpu/src/codegen/dialect/gpu/branch.rs crates/burn-jit/src/codegen/dialect/gpu/branch.rs

+10-6
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,37 @@ use serde::{Deserialize, Serialize};
44
/// All branching types.
55
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
66
pub enum Branch {
7-
// An if statement.
7+
/// An if statement.
88
If(If),
9-
// An if else statement.
9+
/// An if else statement.
1010
IfElse(IfElse),
11-
// A range loop.
11+
/// A range loop.
1212
RangeLoop(RangeLoop),
13-
// A loop.
13+
/// A loop.
1414
Loop(Loop),
15-
// A return statement.
15+
/// A return statement.
1616
Return,
17-
// A break statement.
17+
/// A break statement.
1818
Break,
1919
}
2020

2121
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
22+
#[allow(missing_docs)]
2223
pub struct If {
2324
pub cond: Variable,
2425
pub scope: Scope,
2526
}
2627

2728
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
29+
#[allow(missing_docs)]
2830
pub struct IfElse {
2931
pub cond: Variable,
3032
pub scope_if: Scope,
3133
pub scope_else: Scope,
3234
}
3335

3436
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37+
#[allow(missing_docs)]
3538
pub struct RangeLoop {
3639
pub i: Variable,
3740
pub start: Variable,
@@ -40,6 +43,7 @@ pub struct RangeLoop {
4043
}
4144

4245
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
46+
#[allow(missing_docs)]
4347
pub struct Loop {
4448
pub scope: Scope,
4549
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
mod branch;
2+
mod macros;
3+
mod operation;
4+
mod procedure;
5+
mod processing;
6+
mod scope;
7+
mod shader;
8+
mod synchronization;
9+
mod variable;
10+
mod vectorization;
11+
12+
pub use branch::*;
13+
pub use operation::*;
14+
pub use procedure::*;
15+
pub use scope::*;
16+
pub use shader::*;
17+
pub use synchronization::*;
18+
pub use variable::*;
19+
pub use vectorization::*;
20+
21+
pub(crate) use macros::gpu;

crates/burn-wgpu/src/codegen/dialect/gpu/operation.rs crates/burn-jit/src/codegen/dialect/gpu/operation.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize};
1010
///
1111
/// [Procedure] expansions can safely use all operation variants.
1212
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13-
#[allow(dead_code)] // Some variants might not be used with different flags
13+
#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags
1414
pub enum Operation {
1515
Operator(Operator),
1616
Procedure(Procedure),
@@ -21,7 +21,7 @@ pub enum Operation {
2121

2222
/// All operators that can be used in a GPU compute shader.
2323
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
24-
#[allow(dead_code)] // Some variants might not be used with different flags
24+
#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags
2525
pub enum Operator {
2626
Add(BinaryOperator),
2727
Sub(BinaryOperator),
@@ -57,6 +57,7 @@ pub enum Operator {
5757

5858
/// All metadata that can be accessed in a shader.
5959
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
60+
#[allow(missing_docs)]
6061
pub enum Metadata {
6162
/// The stride of an array at the given dimension.
6263
Stride {
@@ -77,19 +78,22 @@ pub enum Metadata {
7778
}
7879

7980
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
81+
#[allow(missing_docs)]
8082
pub struct BinaryOperator {
8183
pub lhs: Variable,
8284
pub rhs: Variable,
8385
pub out: Variable,
8486
}
8587

8688
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
89+
#[allow(missing_docs)]
8790
pub struct UnaryOperator {
8891
pub input: Variable,
8992
pub out: Variable,
9093
}
9194

9295
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
96+
#[allow(missing_docs)]
9397
pub struct ClampOperator {
9498
pub input: Variable,
9599
pub min_value: Variable,
@@ -98,11 +102,13 @@ pub struct ClampOperator {
98102
}
99103

100104
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
105+
#[allow(missing_docs)]
101106
pub struct ReadGlobalOperator {
102107
pub variable: Variable,
103108
}
104109

105110
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
111+
#[allow(missing_docs)]
106112
pub struct ReadGlobalWithLayoutOperator {
107113
pub variable: Variable,
108114
pub tensor_read_pos: usize,

crates/burn-wgpu/src/codegen/dialect/gpu/procedure/assign.rs crates/burn-jit/src/codegen/dialect/gpu/procedure/assign.rs

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
33

44
/// Assign value to a variable based on a given condition.
55
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
6+
#[allow(missing_docs)]
67
pub struct ConditionalAssign {
78
pub cond: Variable,
89
pub lhs: Variable,
@@ -11,6 +12,7 @@ pub struct ConditionalAssign {
1112
}
1213

1314
impl ConditionalAssign {
15+
#[allow(missing_docs)]
1416
pub fn expand(self, scope: &mut Scope) {
1517
let cond = self.cond;
1618
let lhs = self.lhs;

crates/burn-wgpu/src/codegen/dialect/gpu/procedure/base.rs crates/burn-jit/src/codegen/dialect/gpu/procedure/base.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
77
/// Tensor operations that can't be executed with a simple [operator](super::super::Operator) should use a
88
/// procedure.
99
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
10+
#[allow(missing_docs)]
1011
pub enum Procedure {
1112
ReadGlobalWithLayout(ReadGlobalWithLayout),
1213
IndexOffsetGlobalWithLayout(IndexOffsetGlobalWithLayout),
@@ -16,7 +17,7 @@ pub enum Procedure {
1617
}
1718

1819
impl Procedure {
19-
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
20+
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
2021
match self {
2122
Procedure::ReadGlobalWithLayout(op) => {
2223
Procedure::ReadGlobalWithLayout(op.vectorize(vectorization))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
mod assign;
2+
mod base;
3+
mod read;
4+
mod write;
5+
6+
pub use assign::*;
7+
pub use base::*;
8+
pub use read::*;
9+
pub use write::*;

crates/burn-wgpu/src/codegen/dialect/gpu/procedure/read.rs crates/burn-jit/src/codegen/dialect/gpu/procedure/read.rs

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub struct ReadGlobalWithLayout {
2323
}
2424

2525
impl ReadGlobal {
26+
#[allow(missing_docs)]
2627
pub fn expand(self, scope: &mut Scope) {
2728
scope.register(Operator::Index(BinaryOperator {
2829
lhs: self.global,
@@ -61,6 +62,7 @@ impl ReadGlobalWithLayout {
6162
})
6263
}
6364

65+
#[allow(missing_docs)]
6466
pub fn expand(self, scope: &mut Scope) {
6567
let outputs = self.outs;
6668
let tensors = self.globals;
@@ -107,6 +109,7 @@ impl ReadGlobalWithLayout {
107109

108110
/// Calculate the index offset for each of the provided tensor variables, compatible with the given layout.
109111
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
112+
#[allow(missing_docs)]
110113
pub struct IndexOffsetGlobalWithLayout {
111114
/// Tensor [variables](Variable), same length as [indexes](Self::indexes).
112115
pub tensors: Vec<Variable>,
@@ -123,6 +126,7 @@ pub struct IndexOffsetGlobalWithLayout {
123126
}
124127

125128
impl IndexOffsetGlobalWithLayout {
129+
#[allow(missing_docs)]
126130
pub fn expand(self, scope: &mut Scope) {
127131
let layout = self.layout;
128132
let index_item_ty = Item::Scalar(Elem::UInt);

crates/burn-wgpu/src/codegen/dialect/gpu/procedure/write.rs crates/burn-jit/src/codegen/dialect/gpu/procedure/write.rs

+2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ use serde::{Deserialize, Serialize};
33

44
/// Write to a global array.
55
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
6+
#[allow(missing_docs)]
67
pub struct WriteGlobal {
78
pub input: Variable,
89
pub global: Variable,
910
}
1011

1112
impl WriteGlobal {
13+
#[allow(missing_docs)]
1214
pub fn expand(self, scope: &mut Scope) {
1315
let output = self.global;
1416
let input = self.input;

0 commit comments

Comments
 (0)