move jit backend back into burn-vision and make tests work
wingertge committed Jan 30, 2025
1 parent f62a9ee commit 8edac2b
Showing 19 changed files with 98 additions and 94 deletions.
40 changes: 7 additions & 33 deletions Cargo.lock

Some generated files are not rendered by default.

6 changes: 1 addition & 5 deletions crates/burn-jit/Cargo.toml
@@ -13,23 +13,20 @@ version.workspace = true
 
 [features]
 autotune = []
-default = ["autotune", "std", "fusion", "vision", "cubecl/default"]
+default = ["autotune", "std", "fusion", "cubecl/default"]
 doc = ["default"]
 export_tests = [
     "burn-tensor-testgen",
     "serial_test",
     "burn-autodiff/export_tests",
     "burn-tensor/export_tests",
-    "burn-vision/export_tests",
     "burn-ndarray",
     "fusion",
-    "vision",
     "paste",
 ]
 fusion = ["burn-fusion"]
 fusion-experimental = ["fusion"]
 std = ["cubecl/std", "burn-tensor/std"]
-vision = ["burn-vision"]
 
 template = []
@@ -40,7 +37,6 @@ burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [
     "cubecl",
     "repr",
 ] }
-burn-vision = { path = "../burn-vision", version = "0.17.0", optional = true }
 cubecl = { workspace = true, features = ["linalg", "reduce"] }
 
 bytemuck = { workspace = true }
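
With the `vision` feature and the optional burn-vision dependency removed, the dependency now points the other way: burn-vision pulls in burn-jit (see its Cargo.toml below). Downstream code that previously enabled `burn-jit/vision` would import the ops from burn-vision directly; a hypothetical import, using names this commit re-exports from burn-vision's `ops` module:

use burn_vision::{ConnectedStatsOptions, Connectivity, VisionOps};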
4 changes: 0 additions & 4 deletions crates/burn-jit/src/kernel/mod.rs
@@ -37,10 +37,6 @@ pub mod quantization;
 /// Reduction algorithms
 pub mod reduce;
 
-/// Vision algorithms
-#[cfg(feature = "vision")]
-pub mod vision;
-
 pub(crate) use clamp::*;
 pub(crate) use comparison::*;
 pub(crate) use index::*;
5 changes: 3 additions & 2 deletions crates/burn-jit/src/ops/mod.rs
@@ -7,6 +7,7 @@ mod qtensor;
 mod transaction;
 
 pub(crate) mod base;
-pub(crate) use base::*;
+pub use base::*;
 
-pub(crate) mod numeric;
+/// Numeric utility functions for jit backends
+pub mod numeric;
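
Alongside `base` becoming a `pub use`, the `numeric` module goes public because burn-vision's jit backend (later in this commit) imports its helpers from outside the crate. A sketch of such a call site; the exact `zeros_device` signature is assumed from the imports and calls visible elsewhere in this diff:

use burn_jit::{ops::numeric::zeros_device, tensor::JitTensor, IntElement, JitRuntime};
use burn_tensor::Shape;

/// Hypothetical helper mirroring how burn-vision allocates intermediate label
/// tensors via the now-public API: zeros on the same client/device as `src`.
fn zeroed_like<R: JitRuntime, I: IntElement>(src: &JitTensor<R>, shape: Shape) -> JitTensor<R> {
    zeros_device::<R, I>(src.client.clone(), src.device.clone(), shape)
}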
4 changes: 0 additions & 4 deletions crates/burn-jit/src/tests/mod.rs
@@ -33,7 +33,6 @@ pub use burn_autodiff;
 pub use burn_fusion;
 pub use burn_ndarray;
 pub use burn_tensor;
-pub use burn_vision;
 pub use serial_test;
 
 #[macro_export]
@@ -44,10 +43,7 @@ macro_rules! testgen_all {
     };
     ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => {
         mod jit {
-            pub use $crate::tests::burn_vision;
-
             burn_jit::testgen_jit!([$($float),*], [$($int),*], [$($bool),*]);
-            burn_vision::testgen_all!();
 
             mod kernel {
                 use super::*;
22 changes: 18 additions & 4 deletions crates/burn-vision/Cargo.toml
@@ -13,11 +13,25 @@ version.workspace = true
 
 
 [features]
-export_tests = ["burn-tensor-testgen"]
+default = ["jit-backend"]
+export-tests = ["burn-tensor-testgen"]
+jit-backend = ["cubecl", "burn-jit"]
+
+# Test features
+cpu = ["export-tests"]
+cuda = ["jit-backend", "export-tests"]
+vulkan = ["burn-wgpu/vulkan", "wgpu"]
+wgpu = ["jit-backend", "export-tests"]
 
 [dependencies]
-burn-tensor = { path = "../burn-tensor" }
-cubecl = { workspace = true }
+burn-jit = { path = "../burn-jit", version = "0.17.0", optional = true }
+burn-tensor = { path = "../burn-tensor", version = "0.17.0" }
+burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true }
+cubecl = { workspace = true, optional = true }
 derive-new = { workspace = true }
 serde = { workspace = true }
 
-burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true }
+[dev-dependencies]
+burn-cuda = { path = "../burn-cuda", version = "0.17.0", default-features = false }
+burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" }
+burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", default-features = false }
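
The `cpu`/`cuda`/`wgpu`/`vulkan` entries are compile-time test-backend selectors matching the dev-dependencies above, so a backend-specific run looks like `cargo test -p burn-vision --features cuda`. A sketch of the cfg-gated alias such a harness typically uses; the actual layout of burn-vision's test module is an assumption:

// Hypothetical test-backend selection; feature names and dev-dependencies
// are the ones added in this Cargo.toml.
#[cfg(feature = "cpu")]
pub type TestBackend = burn_ndarray::NdArray;
#[cfg(feature = "cuda")]
pub type TestBackend = burn_cuda::Cuda;
#[cfg(all(feature = "wgpu", not(feature = "cuda")))]
pub type TestBackend = burn_wgpu::Wgpu;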
File renamed without changes.
@@ -3,13 +3,13 @@
 //! "Optimized Block-Based Algorithms to Label Connected Components on GPUs,"
 //! in IEEE Transactions on Parallel and Distributed Systems, 2019.
-use crate::{
+use burn_jit::{
     kernel,
     ops::numeric::{empty_device, zeros_device},
     tensor::JitTensor,
-    tests::burn_tensor::{DType, Shape},
     JitElement, JitRuntime,
 };
+use burn_tensor::{DType, Shape};
 use cubecl::cube;
 use cubecl::prelude::*;
@@ -4,11 +4,14 @@
 //! DASIP, 2018
 use crate::{
-    kernel::vision::connected_components::stats_from_opts, ops::numeric::zeros_device,
-    tensor::JitTensor, BoolElement, FloatElement, IntElement, JitBackend, JitRuntime,
+    backends::jit::connected_components::stats_from_opts, ConnectedStatsOptions,
+    ConnectedStatsPrimitive, Connectivity,
 };
+use burn_jit::{
+    ops::numeric::zeros_device, tensor::JitTensor, BoolElement, FloatElement, IntElement,
+    JitBackend, JitRuntime,
+};
 use burn_tensor::Shape;
-use burn_vision::{ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity};
 use cubecl::{prelude::*, Feature};
 
 const BLOCK_H: u32 = 4;
@@ -50,16 +53,19 @@ fn end_distance(pixels: u32, tx: u32) -> u32 {
     u32::find_first_set(u32::bitwise_not(pixels >> (tx + 1)))
 }
 
+#[cube]
+#[expect(unconditional_panic, reason = "clippy thinks PLANE_DIM is always 2")]
+fn ballot_dyn(y: u32, pred: bool) -> u32 {
+    let index = y % (PLANE_DIM / 32);
+    plane_ballot(pred)[index]
+}
+
 #[cube(launch)]
 fn strip_labeling<BT: CubePrimitive>(
     img: &Tensor<BT>,
     labels: &Tensor<Atomic<u32>>,
     #[comptime] connectivity: Connectivity,
 ) {
-    if UNIT_POS_PLANE >= 32 {
-        terminate!();
-    }
-
     let mut shared_pixels = SharedMemory::<u32>::new(BLOCK_H);
 
     let batch = ABSOLUTE_POS_Z;
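
Two things worth unpacking here. `ballot_dyn` generalizes the old hard-coded `plane_ballot(pred)[0]`: on hardware where the plane (warp/subgroup) is wider than 32 lanes, `plane_ballot` yields more than one 32-bit word, and the row index now picks the right word instead of terminating the extra lanes. And `start_distance`/`end_distance` are bit-run tricks on the 32-pixel ballot mask; a plain-Rust sketch, assuming cubecl's `find_first_set` is 1-based like `ffs` and ignoring the `tx = 31` edge case (where the GPU relies on hardware shift semantics):

/// Length of the run of set bits starting at `tx`, counting `tx` itself;
/// mirrors `end_distance` above: drop bits <= tx, invert, find the first zero.
fn end_distance(pixels: u32, tx: u32) -> u32 {
    (!(pixels >> (tx + 1))).trailing_zeros() + 1
}

/// Consecutive set bits strictly below `tx`; 0 means `tx` starts a run.
/// `start_distance`'s body is outside this hunk; this is one plausible version.
fn start_distance(pixels: u32, tx: u32) -> u32 {
    if tx == 0 { 0 } else { (!(pixels << (32 - tx))).leading_zeros() }
}

fn main() {
    let pixels = 0b0111_0110u32; // pixel runs at bits 1..=2 and 4..=6
    assert_eq!(start_distance(pixels, 5), 1); // the run containing bit 5 started at bit 4
    assert_eq!(end_distance(pixels, 5), 2); // bits 5 and 6 remain in the run
}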
@@ -95,7 +101,7 @@
 
     let p_y = bool::cast_from(img[img_index]);
 
-    let pixels_y = plane_ballot(p_y)[0] & mask;
+    let pixels_y = ballot_dyn(UNIT_POS_Y, p_y) & mask;
     let mut s_dist_y = start_distance(pixels_y, UNIT_POS_X);
 
     if p_y && s_dist_y == 0 {
@@ -213,8 +219,8 @@ fn strip_merge<BT: CubePrimitive>(
     let p = bool::cast_from(img[img_index]);
     let p_up = bool::cast_from(img[img_index_up]);
 
-    let pixels = plane_ballot(p)[0] & mask;
-    let pixels_up = plane_ballot(p_up)[0] & mask;
+    let pixels = ballot_dyn(UNIT_POS_Z, p) & mask;
+    let pixels_up = ballot_dyn(UNIT_POS_Z, p_up) & mask;
 
     match connectivity {
         Connectivity::Four => {
@@ -309,7 +315,7 @@ fn relabeling<BT: CubePrimitive>(img: &Tensor<BT>, labels: &mut Tensor<u32>) {
     let labels_index = batch * labels.stride(0) + y * labels_step + x;
 
     let p = bool::cast_from(img[img_index]);
-    let pixels = plane_ballot(p)[0] & mask;
+    let pixels = ballot_dyn(UNIT_POS_Y, p) & mask;
     let s_dist = start_distance(pixels, UNIT_POS_X);
     let mut label = 0u32;
@@ -358,7 +364,7 @@ fn analysis<BT: CubePrimitive>(
     let labels_index = batch * labels.stride(0) + y * labels_step + x;
 
     let p = bool::cast_from(img[img_index]);
-    let pixels = plane_ballot(p)[0] & mask;
+    let pixels = ballot_dyn(UNIT_POS_Y, p) & mask;
     let s_dist = start_distance(pixels, UNIT_POS_X);
     let count = end_distance(pixels, UNIT_POS_X);
     let max_x = x + count - 1;
@@ -429,7 +435,8 @@ pub fn hardware_accelerated<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(
 
     // Assume 32 wide warp. Currently, larger warps are handled by just exiting everything past 32.
     // This isn't ideal but we require CUBE_DIM_X == warp_size, and we can't query the actual warp
-    // size at compile time.
+    // size at compile time. `REQUIRE_FULL_SUBGROUPS` or subgroup size controls are not supported
+    // in wgpu.
     let warp_size = 32;
     let cube_dim = CubeDim::new_2d(warp_size, BLOCK_H);
     let cube_count = CubeCount::Static(1, (rows as u32).div_ceil(cube_dim.y), batches as u32);
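
The launch geometry pins one 32-lane warp per row of a `BLOCK_H`-row strip: `cube_dim = (32, BLOCK_H)` and one cube per strip. A worked example for the 9x14 test image used later in this commit (the tuple shapes stand in for the `CubeDim::new_2d`/`CubeCount::Static` calls above):

fn main() {
    let (batches, rows) = (1u32, 9u32);
    let (warp_size, block_h) = (32u32, 4u32);
    // One cube per strip of BLOCK_H rows: ceil(9 / 4) = 3 strips.
    let cube_count = (1, rows.div_ceil(block_h), batches);
    assert_eq!(cube_count, (1, 3, 1)); // x = 1, y = strips, z = batches
    let _cube_dim = (warp_size, block_h); // 32 x 4 = 128 threads per cube
}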
@@ -1,16 +1,16 @@
-use crate::{
+mod bke;
+mod hardware_accelerated;
+
+use burn_jit::{
     ops::numeric::{full_device, zeros_device},
     tensor::JitTensor,
     BoolElement, FloatElement, IntElement, JitBackend, JitRuntime,
 };
 
-mod bke;
-mod hardware_accelerated;
-
 use burn_tensor::Shape;
-use burn_vision::{ConnectedStatsOptions, ConnectedStatsPrimitive};
 pub use hardware_accelerated::*;
 
+use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive};
+
 pub(crate) fn stats_from_opts<R, F, I, BT>(
     l: JitTensor<R>,
     opts: ConnectedStatsOptions,
File renamed without changes.
@@ -1,8 +1,8 @@
-use crate::{BoolElement, FloatElement, IntElement, JitBackend, JitRuntime};
-use burn_tensor::ops::{BoolTensor, IntTensor};
-use burn_vision::{
-    cpu_impl, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, VisionOps,
+use crate::{
+    backends::cpu, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, VisionOps,
 };
+use burn_jit::{BoolElement, FloatElement, IntElement, JitBackend, JitRuntime};
+use burn_tensor::ops::{BoolTensor, IntTensor};
 
 use super::connected_components::hardware_accelerated;
@@ -20,7 +20,7 @@ where
             connectivity,
         )
        .map(|it| it.0)
-        .unwrap_or_else(|_| cpu_impl::connected_components::<Self>(img, connectivity))
+        .unwrap_or_else(|_| cpu::connected_components::<Self>(img, connectivity))
     }
 
     fn connected_components_with_stats(
@@ -29,7 +29,7 @@
         opts: ConnectedStatsOptions,
     ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
         hardware_accelerated::<R, F, I, BT>(img.clone(), opts, connectivity).unwrap_or_else(|_| {
-            cpu_impl::connected_components_with_stats::<Self>(img, connectivity, opts)
+            cpu::connected_components_with_stats::<Self>(img, connectivity, opts)
         })
     }
 }
3 changes: 3 additions & 0 deletions crates/burn-vision/src/backends/mod.rs
@@ -0,0 +1,3 @@
+pub(crate) mod cpu;
+#[cfg(feature = "jit-backend")]
+mod jit;
4 changes: 2 additions & 2 deletions crates/burn-vision/src/lib.rs
@@ -1,8 +1,8 @@
-pub mod cpu_impl;
+pub mod backends;
 mod ops;
 mod tensor;
 
-#[cfg(feature = "export_tests")]
+#[cfg(feature = "export-tests")]
 mod tests;
 
 pub use ops::*;
7 changes: 3 additions & 4 deletions crates/burn-vision/src/ops/base.rs
@@ -1,11 +1,10 @@
+use crate::backends::cpu;
 use burn_tensor::{
     backend::Backend,
     ops::{BoolTensor, IntTensor},
     Int, Tensor,
 };
 
-use crate::cpu_impl;
-
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 pub enum Connectivity {
     Four,
@@ -80,14 +79,14 @@
 
 pub trait VisionOps<B: Backend> {
     fn connected_components(img: BoolTensor<B>, connectivity: Connectivity) -> IntTensor<B> {
-        cpu_impl::connected_components::<B>(img, connectivity)
+        cpu::connected_components::<B>(img, connectivity)
     }
 
     fn connected_components_with_stats(
         img: BoolTensor<B>,
         connectivity: Connectivity,
         opts: ConnectedStatsOptions,
     ) -> (IntTensor<B>, ConnectedStatsPrimitive<B>) {
-        cpu_impl::connected_components_with_stats(img, connectivity, opts)
+        cpu::connected_components_with_stats(img, connectivity, opts)
     }
 }
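
A minimal sketch of using the trait generically, assuming a backend `B` that implements `VisionOps<B>` (jit backends get the accelerated override added in this commit; anything else falls through to the CPU defaults above):

use burn_tensor::backend::Backend;
use burn_tensor::ops::{BoolTensor, IntTensor};
use burn_vision::{Connectivity, VisionOps};

fn label_components<B: Backend + VisionOps<B>>(img: BoolTensor<B>) -> IntTensor<B> {
    B::connected_components(img, Connectivity::Four)
}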
7 changes: 2 additions & 5 deletions crates/burn-vision/src/tests/connected_components.rs
@@ -3,11 +3,8 @@ mod tests {
     use std::collections::HashMap;
 
     use super::*;
-    use burn_tensor::{Tensor, TensorData};
-    use burn_vision::{
-        as_type, ConnectedComponents, ConnectedStats, ConnectedStatsOptions, Connectivity,
-        VisionOps,
-    };
+    use burn_tensor::TensorData;
+    use burn_vision::{as_type, ConnectedComponents, ConnectedStatsOptions, Connectivity};
 
     fn space_invader() -> [[IntType; 14]; 9] {
         as_type!(IntType: [