[Feat] Add new burn-vision crate with one initial op #2753

Status: Open

wingertge wants to merge 25 commits into base `main` from `feat/burn-vision`.

Commits (25)
e7a84b9 Update cubecl (wingertge, Jan 25, 2025)
267ec63 Update to scope merge (wingertge, Jan 26, 2025)
71584e0 Fix bitwise shift (wingertge, Jan 26, 2025)
5838364 Initial JIT implementation (wingertge, Jan 28, 2025)
d292f85 Merge branch 'main' into feat/burn-vision (wingertge, Jan 28, 2025)
9e65150 Move testgen to burn-jit (wingertge, Jan 28, 2025)
0484f51 Improve HA4/8 algo (wingertge, Jan 28, 2025)
f62a9ee Terminate units past the predefined 32 plane size (wingertge, Jan 28, 2025)
8edac2b move jit backend back into `burn-vision` and make tests work (wingertge, Jan 30, 2025)
05b40e3 Add initial CPU implementation without stats (wingertge, Jan 30, 2025)
7708993 Implement stats (wingertge, Jan 31, 2025)
aeea3a8 Implement all backends except fusion (wingertge, Jan 31, 2025)
a994ca7 Fix autodiff to use GPU when available (wingertge, Jan 31, 2025)
866307b Fixes and cleanup (wingertge, Jan 31, 2025)
a8e3994 Add docs (wingertge, Jan 31, 2025)
021360b Update cubecl (wingertge, Jan 31, 2025)
a1d727f Merge branch 'main' into feat/burn-vision (wingertge, Jan 31, 2025)
01ff01b Compact labels for JIT (wingertge, Feb 1, 2025)
d790113 Improve JIT backend implementation by adding label compaction (wingertge, Feb 2, 2025)
15c431c Use GPU reduction for max label (wingertge, Feb 2, 2025)
e3ec085 Manually fuse presence and prefix sum (wingertge, Feb 2, 2025)
11c8f1f Make prefix sum more generic over line size (wingertge, Feb 2, 2025)
ee5ad73 Merge branch 'main' into feat/burn-vision (wingertge, Feb 3, 2025)
e6126c8 Add vision tests to xtask (wingertge, Feb 3, 2025)
1bbf50a Fix CPU and other review stuff (wingertge, Feb 3, 2025)
Files changed (showing changes from 8 commits)
140 changes: 68 additions & 72 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
+# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
+# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
 ### For local development. ###
-# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
-# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
+cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
+cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
 ### For the release. ###
 # cubecl = { version = "0.4.0", default-features = false }
 # cubecl-common = { version = "0.4.0", default-features = false }
2 changes: 1 addition & 1 deletion crates/burn-cuda/Cargo.toml
@@ -12,8 +12,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda"
 version.workspace = true
 
 [features]
-default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"]
 autotune = ["burn-jit/autotune"]
+default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"]
 doc = ["burn-jit/doc"]
 fusion = ["burn-fusion", "burn-jit/fusion"]
 std = ["burn-jit/std", "cubecl/std"]
6 changes: 5 additions & 1 deletion crates/burn-jit/Cargo.toml
@@ -13,20 +13,23 @@ version.workspace = true
 
 [features]
 autotune = []
-default = ["autotune", "std", "fusion", "cubecl/default"]
+default = ["autotune", "std", "fusion", "vision", "cubecl/default"]
 doc = ["default"]
 export_tests = [
     "burn-tensor-testgen",
     "serial_test",
     "burn-autodiff/export_tests",
     "burn-tensor/export_tests",
+    "burn-vision/export_tests",
     "burn-ndarray",
     "fusion",
+    "vision",
     "paste",
 ]
 fusion = ["burn-fusion"]
 fusion-experimental = ["fusion"]
 std = ["cubecl/std", "burn-tensor/std"]
+vision = ["burn-vision"]
 
 template = []
@@ -37,6 +40,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features =
     "cubecl",
     "repr",
 ] }
+burn-vision = { path = "../burn-vision", version = "0.17.0", optional = true }
 cubecl = { workspace = true, features = ["linalg", "reduce"] }
 
 bytemuck = { workspace = true }
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/matmul/args.rs
@@ -247,7 +247,7 @@ impl CubeType for FusedMatmulState {
 }
 
 impl Init for FusedMatmulStateExpand {
-    fn init(self, _context: &mut CubeContext) -> Self {
+    fn init(self, _context: &mut Scope) -> Self {
         self
     }
 }
4 changes: 2 additions & 2 deletions crates/burn-jit/src/fusion/on_write/ir.rs
@@ -45,13 +45,13 @@ impl CubeType for Arg {
 }
 
 impl Init for Arg {
-    fn init(self, _context: &mut CubeContext) -> Self {
+    fn init(self, _context: &mut Scope) -> Self {
         self
     }
 }
 
 impl IntoRuntime for Arg {
-    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
+    fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType {
         self
     }
 }
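Both fusion diffs above apply the same mechanical migration: the cubecl update (the "Update to scope merge" commit in the list) replaces `&mut CubeContext` with `&mut Scope` in expansion-trait methods. A minimal, self-contained sketch of the pattern, where `Scope` and `Init` are local stand-ins for illustration rather than cubecl's real definitions:

// Stand-in for cubecl's `Scope` (illustrative only).
pub struct Scope;

// Stand-in for the `Init` trait after the scope merge: expand types are
// initialized against a `Scope` instead of a `CubeContext`.
pub trait Init {
    fn init(self, scope: &mut Scope) -> Self;
}

pub struct Arg;

impl Init for Arg {
    // Plain data carries no launch-time state, so `init` returns `self`
    // unchanged; that is why the impls in the diffs are one-liners.
    fn init(self, _scope: &mut Scope) -> Self {
        self
    }
}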
12 changes: 4 additions & 8 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs
@@ -7,7 +7,7 @@ use burn_tensor::{
 use cubecl::{
     flex32,
     ir::{Elem, FloatKind},
-    linalg::matmul::{self},
+    linalg::matmul::{self, kernels::MatmulLaunchError},
     tensor_line_size, tf32, Feature,
 };
 use half::{bf16, f16};
@@ -195,18 +195,14 @@ where
     let cube_count = Alg::cube_count(&selection, &problem);
 
     let advanced_config = Default::default();
-    let config = match Alg::make_config(
+    let config = Alg::make_config(
         config_input,
         &problem,
         &cube_dim,
         &cube_count,
         &advanced_config,
-    ) {
-        Ok(val) => val,
-        Err(err) => {
-            panic!("Can't launch conv kernel because of an invalid config: {err}")
-        }
-    };
+    )
+    .map_err(MatmulLaunchError::InvalidConfig)?;
 
     let bias = bias.unwrap_or_else(|| {
         empty_device::<R, SP::EG>(input.client.clone(), input.device.clone(), Shape::new([1]))
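The change above swaps a panicking `match` for `?`-based propagation. A self-contained sketch of the pattern, using hypothetical stand-in types (`TileConfigError`, `LaunchError`, and `make_config` are not the real burn/cubecl names):

#[derive(Debug)]
pub struct TileConfigError(String);

#[derive(Debug)]
pub enum LaunchError {
    InvalidConfig(TileConfigError),
}

fn make_config(tile: u32) -> Result<u32, TileConfigError> {
    if tile.is_power_of_two() {
        Ok(tile)
    } else {
        Err(TileConfigError(format!("unsupported tile size {tile}")))
    }
}

fn launch(tile: u32) -> Result<(), LaunchError> {
    // `map_err` wraps the config error in the launch-error enum; `?`
    // returns it to the caller instead of panicking at the call site.
    let _config = make_config(tile).map_err(LaunchError::InvalidConfig)?;
    Ok(())
}

The caller now decides how to surface an invalid config, which, given the `Into<AutotuneError>` impl in error.rs below, presumably also lets autotune treat it as recoverable rather than aborting the process.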
38 changes: 25 additions & 13 deletions crates/burn-jit/src/kernel/conv/conv2d/im2col.rs
@@ -98,25 +98,38 @@ fn im2col_kernel<F: Float>(
 }
 
 #[cfg(not(test))]
-pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option<usize> {
-    let cube_count_per_batch = (out_h * out_w).div_ceil(burn_common::PLANE_DIM_APPROX);
+pub(crate) fn batches_per_run(
+    batch_size: usize,
+    out_h: usize,
+    out_w: usize,
+) -> Result<usize, ConvLaunchError> {
+    use cubecl::linalg::matmul::kernels::MatmulAvailabilityError;
+
+    let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::PLANE_DIM_APPROX);
     let max_cube_count = u16::MAX as usize;
     let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size);
     if max_simultaneous == 0 {
-        return None;
+        return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static(
+            cube_count_per_batch as u32,
+            1,
+            1,
+        ))
+        .into());
     }
-    Some(
-        (0..=max_simultaneous)
-            .rev()
-            .find(|per_run| batch_size % per_run == 0)
-            .expect("Logically not possible"),
-    )
+    Ok((0..=max_simultaneous)
+        .rev()
+        .find(|per_run| batch_size % per_run == 0)
+        .expect("Logically not possible"))
 }
 
 #[cfg(test)]
 #[allow(unused)]
-pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option<usize> {
-    Some(1)
+pub(crate) fn batches_per_run(
+    batch_size: usize,
+    out_h: usize,
+    out_w: usize,
+) -> Result<usize, ConvLaunchError> {
+    Ok(1)
 }
 
 fn im2col<R: JitRuntime, E: FloatElement>(
@@ -214,8 +227,7 @@ pub fn conv2d_im2col<R: JitRuntime, E: FloatElement>(
         return execute_1x1_kernel::<R, E>(input, weight, bias, options);
     }
 
-    let batches_per_run = batches_per_run(batch_size, out_h, out_w)
-        .expect("Image too large to run even one batch at once");
+    let batches_per_run = batches_per_run(batch_size, out_h, out_w)?;
     let matmul_shape = Shape::new([groups, out_c_per_group, batches_per_run * out_h * out_w]);
 
     let mut out = if batches_per_run != batch_size {
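For reference, the logic of `batches_per_run` in isolation: estimate how many cubes one batch needs, cap the number of simultaneous batches by the dispatch limit, then take the largest exact divisor of `batch_size` that fits. A runnable sketch; the plane-size constant is an assumed illustrative value, not cubecl's actual `PLANE_DIM_APPROX`:

fn batches_per_run_sketch(batch_size: usize, out_h: usize, out_w: usize) -> Option<usize> {
    const PLANE_DIM_APPROX: usize = 16; // assumed value for illustration
    let cube_count_per_batch = (out_h * out_w).div_ceil(PLANE_DIM_APPROX);
    let max_cube_count = u16::MAX as usize; // dispatch limit per dimension
    let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size);
    if max_simultaneous == 0 {
        return None; // a single batch already exceeds the cube-count limit
    }
    // Search downward from the cap; 1 divides everything, so this always
    // succeeds once max_simultaneous >= 1.
    (1..=max_simultaneous)
        .rev()
        .find(|per_run| batch_size % per_run == 0)
}

fn main() {
    // 8x8 output => 4 cubes per batch, far under the limit, so all 12
    // batches can run at once.
    assert_eq!(batches_per_run_sketch(12, 8, 8), Some(12));
}

Requiring an exact divisor keeps every run the same size, presumably so the im2col buffers and matmul shape can be reused across runs without a ragged final batch.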
11 changes: 10 additions & 1 deletion crates/burn-jit/src/kernel/conv/error.rs
@@ -1,5 +1,8 @@
 use core::fmt::Debug;
-use cubecl::{linalg::matmul::kernels::MatmulLaunchError, tune::AutotuneError};
+use cubecl::{
+    linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError},
+    tune::AutotuneError,
+};
 
 pub enum ConvLaunchError {
     Matmul(MatmulLaunchError),
@@ -30,6 +33,12 @@ impl From<MatmulLaunchError> for ConvLaunchError {
     }
 }
 
+impl From<MatmulAvailabilityError> for ConvLaunchError {
+    fn from(value: MatmulAvailabilityError) -> Self {
+        Self::Matmul(MatmulLaunchError::Unavailable(value))
+    }
+}
+
 #[allow(clippy::from_over_into)]
 impl Into<AutotuneError> for ConvLaunchError {
     fn into(self) -> AutotuneError {
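This new `From<MatmulAvailabilityError>` impl is what makes the `.into()` in `batches_per_run` and the `?` in `conv2d_im2col` compile. A self-contained sketch of the same two-level wrapping, with stand-in enums rather than the real cubecl types:

#[derive(Debug)]
pub enum AvailabilityError {
    CubeCountTooBig(u32),
}

#[derive(Debug)]
pub enum MatmulError {
    Unavailable(AvailabilityError),
}

#[derive(Debug)]
pub enum ConvError {
    Matmul(MatmulError),
}

impl From<AvailabilityError> for ConvError {
    // Lift the availability error through the intermediate matmul enum,
    // mirroring the impl in the diff above.
    fn from(value: AvailabilityError) -> Self {
        Self::Matmul(MatmulError::Unavailable(value))
    }
}

fn check_cube_count(cubes: u32) -> Result<(), ConvError> {
    if cubes as usize > u16::MAX as usize {
        // `.into()` performs the two-level wrap in one step.
        return Err(AvailabilityError::CubeCountTooBig(cubes).into());
    }
    Ok(())
}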
4 changes: 4 additions & 0 deletions crates/burn-jit/src/kernel/mod.rs
@@ -37,6 +37,10 @@ pub mod quantization;
 /// Reduction algorithms
 pub mod reduce;
 
+/// Vision algorithms
+#[cfg(feature = "vision")]
+pub mod vision;
+
 pub(crate) use clamp::*;
 pub(crate) use comparison::*;
 pub(crate) use index::*;
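The module is gated so burn-jit still builds without the new dependency; the `vision` feature added in the burn-jit Cargo.toml diff above pulls in `burn-vision` and compiles this module. A minimal sketch of the gating pattern, with hypothetical names:

#[cfg(feature = "vision")]
pub mod vision {
    // Compiled only with `cargo build --features vision`.
    pub fn kernels_available() -> bool {
        true
    }
}

#[cfg(feature = "vision")]
pub fn vision_enabled() -> bool {
    vision::kernels_available()
}

#[cfg(not(feature = "vision"))]
pub fn vision_enabled() -> bool {
    false
}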