
Commit d790113

Improve JIT backend implementation by adding label compaction
1 parent 01ff01b commit d790113
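
In short, the compaction pass added here marks which labels still occur after label resolution, turns that mask into a dense remap table with a prefix sum, rewrites every pixel's label through the table, and re-derives the per-batch maximum label with an atomic max. Below is a minimal CPU-side sketch of the relabeling step only, assuming a precomputed `remap` table; the function name is illustrative, and the real implementation is the `compact_labels` CubeCL kernel in the diff further down.

// Illustrative sequential sketch only; the actual work happens per pixel in the
// `compact_labels` kernel. `remap[l]` maps an old (sparse) label `l` to its new
// dense label; label 0 is background and is left untouched.
fn relabel_cpu(labels: &mut [u32], remap: &[u32]) -> u32 {
    let mut max_label = 0u32;
    for l in labels.iter_mut() {
        if *l != 0 {
            let new_label = remap[*l as usize];
            *l = new_label;
            // The kernel tracks this with Atomic::max on a per-batch counter.
            max_label = max_label.max(new_label);
        }
    }
    max_label
}

The per-label stats (area, bounding boxes) are then gathered into the same dense index space by the `compact_stats` kernel.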

File tree: 11 files changed, +332 -64 lines


Cargo.lock

-15
Some generated files are not rendered by default.

Cargo.toml
+4 -4

@@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" }
+# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" }
+# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" }
 ### For local development. ###
-# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
-# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
+cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
+cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
 ### For the release. ###
 # cubecl = { version = "0.4.0", default-features = false }
 # cubecl-common = { version = "0.4.0", default-features = false }

crates/burn-jit/src/ops/base.rs
+2 -1

@@ -136,7 +136,8 @@ pub(crate) fn expand<R: JitRuntime>(tensor: JitTensor<R>, target_shape: Shape) -
     }
 }
 
-pub(crate) fn reshape<R: JitRuntime>(tensor: JitTensor<R>, shape: Shape) -> JitTensor<R> {
+/// Reshape a jit tensor to a new shape
+pub fn reshape<R: JitRuntime>(tensor: JitTensor<R>, shape: Shape) -> JitTensor<R> {
     // TODO: Not force standard layout all the time (improve performance).
     let tensor = kernel::into_contiguous(tensor);
 
crates/burn-vision/Cargo.toml
+1 -1

@@ -25,7 +25,7 @@ tch = ["burn-tch"]
 # Test features
 cpu = ["export-tests"]
 cuda = ["jit-backend", "export-tests"]
-vulkan = ["burn-wgpu/vulkan", "wgpu"]
+vulkan = ["burn-wgpu/vulkan", "jit-backend", "export-tests"]
 wgpu = ["jit-backend", "export-tests"]
 
 [dependencies]

crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs
+33 -23

@@ -4,17 +4,17 @@
 //! DASIP, 2018
 
 use crate::{
-    backends::jit::connected_components::stats_from_opts, ConnectedStatsOptions,
-    ConnectedStatsPrimitive, Connectivity,
+    backends::jit::{connected_components::stats_from_opts, prefix_sum::prefix_sum},
+    ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity,
 };
 use burn_jit::{
     kernel,
     ops::{into_data_sync, numeric::zeros_device},
     tensor::JitTensor,
     BoolElement, FloatElement, IntElement, JitBackend, JitRuntime,
 };
-use burn_tensor::ops::IntTensorOps;
-use cubecl::{calculate_cube_count_elemwise, prelude::*, Feature};
+use burn_tensor::{ops::IntTensorOps, Shape};
+use cubecl::{prelude::*, Feature};
 
 const BLOCK_H: u32 = 4;
 
@@ -380,6 +380,7 @@ fn analysis<I: Int, BT: CubePrimitive>(
         while label != u32::cast_from(labels[b_offs + label]) - 1 {
             label = u32::cast_from(labels[b_offs + label]) - 1;
         }
+        label += 1;
 
         Atomic::add(&area[b_offs + label], I::cast_from(count));
 
@@ -397,13 +398,17 @@ fn analysis<I: Int, BT: CubePrimitive>(
         label = plane_broadcast(label, UNIT_POS_X - s_dist);
 
         if p {
-            labels[labels_index] = I::cast_from(label + 1);
+            labels[labels_index] = I::cast_from(label);
         }
     }
 }
 
 #[cube(launch)]
-fn compact_labels<I: Int>(labels: &mut Tensor<I>, remap: &Tensor<I>) {
+fn compact_labels<I: Int>(
+    labels: &mut Tensor<I>,
+    remap: &Tensor<I>,
+    max_label: &Tensor<Atomic<I>>,
+) {
     let batch = ABSOLUTE_POS_Z;
     let x = ABSOLUTE_POS_X;
     let y = ABSOLUTE_POS_Y;
@@ -416,7 +421,9 @@ fn compact_labels<I: Int>(labels: &mut Tensor<I>, remap: &Tensor<I>) {
 
     let label = u32::cast_from(labels[labels_pos]);
     if label != 0 {
-        labels[labels_pos] = remap[label];
+        let new_label = remap[label];
+        labels[labels_pos] = new_label;
+        Atomic::max(&max_label[batch], new_label);
     }
 }
 
@@ -433,11 +440,9 @@ fn compact_stats<I: Int>(
     bottom: &Tensor<I>,
     bottom_new: &mut Tensor<I>,
     remap: &Tensor<I>,
-    max_label: u32,
-    #[comptime] opts: ConnectedStatsOptions,
 ) {
     let label = ABSOLUTE_POS_X;
-    if label > max_label {
+    if label >= remap.len() {
         terminate!();
     }
 
@@ -448,12 +453,12 @@ fn compact_stats<I: Int>(
     let new_label = u32::cast_from(remap[label]);
 
     area_new[new_label] = area;
-    if opts.bounds_enabled {
-        top_new[new_label] = top[label];
-        left_new[new_label] = left[label];
-        right_new[new_label] = right[label];
-        bottom_new[new_label] = bottom[label];
-    }
+    // This should be gated but there's a problem with the Eq bound only being implemented for tuples
+    // up to 12 elems, so I can't pass the opts. It's not unsafe, but potentially unnecessary work.
+    top_new[new_label] = top[label];
+    left_new[new_label] = left[label];
+    right_new[new_label] = right[label];
+    bottom_new[new_label] = bottom[label];
 }
 
 #[allow(clippy::type_complexity)]
@@ -525,7 +530,7 @@ pub fn hardware_accelerated<R: JitRuntime, F: FloatElement, I: IntElement, BT: B
         batches as u32,
     );
 
-    let stats = stats_from_opts(labels.clone(), stats_opt);
+    let mut stats = stats_from_opts(labels.clone(), stats_opt);
 
     if stats_opt == ConnectedStatsOptions::none() {
         relabeling::launch::<I, BT, R>(
@@ -553,28 +558,35 @@ pub fn hardware_accelerated<R: JitRuntime, F: FloatElement, I: IntElement, BT: B
     if stats_opt.compact_labels {
        let max_labels = into_data_sync::<R, I>(stats.max_label.clone()).convert::<u32>();
        let max_label = *max_labels.as_slice::<u32>().unwrap().iter().max().unwrap() as usize;
-       let sliced = kernel::slice::<R, I>(stats.area.clone(), &[0..batches, 0..max_label + 1]);
+       let sliced = kernel::slice::<R, I>(
+           stats.area.clone(),
+           &[0..batches, 0..(max_label + 1).next_multiple_of(4)],
+       );
        let present = JitBackend::<R, F, I, BT>::int_not_equal_elem(sliced, I::new(0));
-       let relabel = JitBackend::<R, F, I, BT>::int_prefix_sum(present);
+       let present = kernel::cast::<R, BT, I>(present);
+       let relabel = prefix_sum::<R, I>(present);
 
        let cube_dim = CubeDim::default();
        let cube_count = CubeCount::new_3d(
            (cols as u32).div_ceil(cube_dim.x),
            (rows as u32).div_ceil(cube_dim.y),
            batches as u32,
        );
-       compact_labels::launch(
+       stats.max_label =
+           zeros_device::<R, I>(client.clone(), device.clone(), Shape::new([batches]));
+       compact_labels::launch::<I, R>(
           &client,
          cube_count,
           cube_dim,
           labels.as_tensor_arg::<I>(1),
           relabel.as_tensor_arg::<I>(1),
+          stats.max_label.as_tensor_arg::<I>(1),
       );
 
       let cube_dim = CubeDim::new_1d(256);
       let cube_count =
           CubeCount::new_3d((rows * cols).div_ceil(256) as u32, 1, batches as u32);
-      compact_stats::launch(
+      compact_stats::launch::<I, R>(
          &client,
          cube_count,
          cube_dim,
@@ -589,8 +601,6 @@ pub fn hardware_accelerated<R: JitRuntime, F: FloatElement, I: IntElement, BT: B
          stats.bottom.copy().as_tensor_arg::<I>(1),
          stats.bottom.as_tensor_arg::<I>(1),
          relabel.as_tensor_arg::<I>(1),
-         ScalarArg::new(max_label as u32),
-         stats_opt,
      );
    }
 }
@@ -1,2 +1,6 @@
 mod connected_components;
 mod ops;
+
+/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops
+/// to really use it in a general case. Needs more work to use as a normal tensor method.
+mod prefix_sum;
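
The remap table consumed by `compact_labels` and `compact_stats` above comes out of this `prefix_sum` module: an inclusive scan over a "label is present" mask assigns consecutive new values to the surviving labels. A tiny sequential illustration of that idea, assuming `present[l]` is 1 whenever `area[l] != 0`; the actual kernel is the int-specialized plane-op scan the module comment refers to, not this loop.

fn main() {
    // present[l] = 1 if any pixel still carries label l, else 0 (label 0 is background).
    let present = [0u32, 1, 0, 0, 1, 1, 0, 1]; // old labels 1, 4, 5 and 7 survive
    let mut remap = [0u32; 8];
    let mut running = 0;
    for (i, &p) in present.iter().enumerate() {
        running += p; // inclusive prefix sum
        remap[i] = running;
    }
    // Gaps are squeezed out: old 1 -> 1, 4 -> 2, 5 -> 3, 7 -> 4.
    assert_eq!(remap, [0, 1, 1, 1, 2, 3, 3, 4]);
}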

crates/burn-vision/src/backends/jit/ops.rs
+4 -4

@@ -45,8 +45,8 @@ where
 impl<B: FusionBackend + VisionOps<B>> VisionOps<Self> for Fusion<B> {
     fn connected_components(img: BoolTensor<Self>, conn: Connectivity) -> IntTensor<Self> {
         let batches = img.shape[0];
-        let height = img.shape[2];
-        let width = img.shape[3];
+        let height = img.shape[1];
+        let width = img.shape[2];
         let client = img.client.clone();
 
         #[derive(derive_new::new)]
@@ -92,8 +92,8 @@ impl<B: FusionBackend + VisionOps<B>> VisionOps<Self> for Fusion<B> {
         opts: ConnectedStatsOptions,
     ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
         let batches = img.shape[0];
-        let height = img.shape[2];
-        let width = img.shape[3];
+        let height = img.shape[1];
+        let width = img.shape[2];
         let client = img.client.clone();
 
         #[derive(derive_new::new)]
