Skip to content

Commit

Permalink
Manually fuse presence and prefix sum
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Feb 2, 2025
1 parent 15c431c commit e3ec085
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
//! DASIP, 2018
use crate::{
backends::jit::{connected_components::stats_from_opts, prefix_sum::prefix_sum},
ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity,
backends::jit::connected_components::stats_from_opts, ConnectedStatsOptions,
ConnectedStatsPrimitive, Connectivity,
};
use burn_jit::{
kernel,
Expand All @@ -16,6 +16,8 @@ use burn_jit::{
use burn_tensor::{ops::IntTensorOps, Shape};
use cubecl::{prelude::*, Feature};

use super::prefix_sum::prefix_sum;

const BLOCK_H: u32 = 4;

#[cube]
Expand Down Expand Up @@ -563,9 +565,7 @@ pub fn hardware_accelerated<R: JitRuntime, F: FloatElement, I: IntElement, BT: B
stats.area.clone(),
&[0..batches, 0..(max_label + 1).next_multiple_of(4)],
);
let present = JitBackend::<R, F, I, BT>::int_not_equal_elem(sliced, I::new(0));
let present = kernel::cast::<R, BT, I>(present);
let relabel = prefix_sum::<R, I>(present);
let relabel = prefix_sum::<R, I>(sliced);

let cube_dim = CubeDim::default();
let cube_count = CubeCount::new_3d(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
mod hardware_accelerated;

/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops
/// to really use it in a general case. Needs more work to use as a normal tensor method.
mod prefix_sum;

use burn_jit::{
ops::numeric::{full_device, zeros_device},
tensor::JitTensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ fn prefix_sum_kernel<I: Int>(

if part_id < cube_count_x - 1 {
for k in 0..vec4_spt {
let mut scan = scan_in[i + scan_offs];
// Manually fuse not_equal and cast
let mut scan = Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero)));
let x = scan[0];
scan[1] += x;
let y = scan[1];
Expand All @@ -76,7 +77,9 @@ fn prefix_sum_kernel<I: Int>(
if part_id == cube_count_x - 1 {
for k in 0..vec4_spt {
if i < scan_in.shape(1) {
let mut scan = scan_in[i + scan_offs];
// Manually fuse not_equal and cast
let mut scan =
Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero)));
let x = scan[0];
scan[1] += x;
let y = scan[1];
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-vision/src/backends/jit/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
mod connected_components;
mod ops;

/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops
/// to really use it in a general case. Needs more work to use as a normal tensor method.
mod prefix_sum;

0 comments on commit e3ec085

Please sign in to comment.