Skip to content

Commit

Permalink
Improve HA4/8 algo
Browse files (browse the repository at this point in the history)
  • Loading branch information
wingertge committed Jan 28, 2025
1 parent 9e65150 commit 0484f51
Showing 1 changed file with 88 additions and 18 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -180,19 +180,24 @@ fn strip_labeling<BT: CubePrimitive>(
}

#[cube(launch)]
fn strip_merge<BT: CubePrimitive>(img: &Tensor<BT>, labels: &Tensor<Atomic<u32>>) {
let batch = ABSOLUTE_POS_Z;
let y = ABSOLUTE_POS_Y * BLOCK_H;
let x = ABSOLUTE_POS_X;
fn strip_merge<BT: CubePrimitive>(
img: &Tensor<BT>,
labels: &Tensor<Atomic<u32>>,
#[comptime] connectivity: Connectivity,
) {
let batch = CUBE_POS_Z;
let plane_start_x = CUBE_POS_X * (CUBE_DIM_X * CUBE_DIM_Z - PLANE_DIM) + UNIT_POS_Z * PLANE_DIM;
let y = (CUBE_POS_Y + 1) * BLOCK_H;
let x = plane_start_x + UNIT_POS_X;

let img_step = img.stride(2);
let labels_step = labels.stride(1);
let cols = img.shape(3);

if y < labels.shape(1) && x < labels.shape(2) && y > 0 {
if y < labels.shape(1) && x < labels.shape(2) {
let mut mask = 0xffffffffu32;
if cols - CUBE_POS_X * CUBE_DIM_X < 32 {
mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X);
if cols - plane_start_x < 32 {
mask >>= 32 - (cols - plane_start_x);
}

let img_index = batch * img.stride(0) + y * img_step + x;
Expand All @@ -207,11 +212,72 @@ fn strip_merge<BT: CubePrimitive>(img: &Tensor<BT>, labels: &Tensor<Atomic<u32>>
let pixels = plane_ballot(p)[0] & mask;
let pixels_up = plane_ballot(p_up)[0] & mask;

if p && p_up {
let s_dist = start_distance(pixels, UNIT_POS_X);
let s_dist_up = start_distance(pixels_up, UNIT_POS_X);
if s_dist == 0 || s_dist_up == 0 {
merge(labels, labels_index - s_dist, labels_index_up - s_dist_up);
match connectivity {
Connectivity::Four => {
if p && p_up {
let s_dist = start_distance(pixels, UNIT_POS_X);
let s_dist_up = start_distance(pixels_up, UNIT_POS_X);
if s_dist == 0 || s_dist_up == 0 {
merge(labels, labels_index - s_dist, labels_index_up - s_dist_up);
}
}
}
Connectivity::Eight => {
let mut last_dist_vec = SharedMemory::<u32>::new(32);
let mut last_dist_up_vec = SharedMemory::<u32>::new(32);

let s_dist = start_distance(pixels, UNIT_POS_X);
let s_dist_up = start_distance(pixels_up, UNIT_POS_X);

if UNIT_POS_PLANE == PLANE_DIM - 1 {
last_dist_vec[UNIT_POS_Z] = start_distance(pixels, 32);
last_dist_up_vec[UNIT_POS_Z] = start_distance(pixels_up, 32);
}

sync_units();

if CUBE_POS_X == 0 || UNIT_POS_Z > 0 {
let last_dist = if UNIT_POS_Z > 0 {
last_dist_vec[UNIT_POS_Z - 1]
} else {
0u32
};
let last_dist_up = if UNIT_POS_Z > 0 {
last_dist_up_vec[UNIT_POS_Z - 1]
} else {
0u32
};

let p_prev =
select(UNIT_POS_X > 0, (pixels >> (UNIT_POS_X - 1)) & 1, last_dist) != 0;
let p_up_prev = select(
UNIT_POS_X > 0,
(pixels_up >> (UNIT_POS_X - 1)) & 1,
last_dist_up,
) != 0;

if p && p_up {
let s_dist = start_distance(pixels, UNIT_POS_X);
let s_dist_up = start_distance(pixels_up, UNIT_POS_X);
if s_dist == 0 || s_dist_up == 0 {
merge(labels, labels_index - s_dist, labels_index_up - s_dist_up);
}
} else if p && p_up_prev && s_dist == 0 {
let s_dist_up_prev = select(
UNIT_POS_X == 0,
last_dist_up - 1,
start_distance(pixels_up, UNIT_POS_X - 1),
);
merge(labels, labels_index, labels_index_up - 1 - s_dist_up_prev);
} else if p_prev && p_up && s_dist_up == 0 {
let s_dist_prev = select(
UNIT_POS_X == 0,
last_dist - 1,
start_distance(pixels, UNIT_POS_X - 1),
);
merge(labels, labels_index - 1 - s_dist_prev, labels_index_up);
}
}
}
}
}
Expand All @@ -220,8 +286,9 @@ fn strip_merge<BT: CubePrimitive>(img: &Tensor<BT>, labels: &Tensor<Atomic<u32>>
#[cube(launch)]
fn relabeling<BT: CubePrimitive>(img: &Tensor<BT>, labels: &mut Tensor<u32>) {
let batch = ABSOLUTE_POS_Z;
let plane_start_x = CUBE_POS_X * CUBE_DIM_X;
let y = ABSOLUTE_POS_Y;
let x = ABSOLUTE_POS_X;
let x = plane_start_x + UNIT_POS_X;

let cols = labels.shape(2);
let rows = labels.shape(1);
Expand All @@ -230,8 +297,8 @@ fn relabeling<BT: CubePrimitive>(img: &Tensor<BT>, labels: &mut Tensor<u32>) {

if x < cols && y < rows {
let mut mask = 0xffffffffu32;
if cols - CUBE_POS_X * CUBE_DIM_X < 32 {
mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X);
if cols - plane_start_x < 32 {
mask >>= 32 - (cols - plane_start_x);
}

let img_index = batch * img.stride(0) + y * img_step + x;
Expand Down Expand Up @@ -372,18 +439,21 @@ pub fn hardware_accelerated<R: JitRuntime, F: FloatElement, I: IntElement, BT: B
connectivity,
);

let horizontal_warps = Ord::min((cols as u32).div_ceil(warp_size), 32);
let cube_dim_merge = CubeDim::new_3d(warp_size, 1, horizontal_warps);
let cube_count = CubeCount::Static(
(cols as u32).div_ceil(cube_dim.x),
(rows as u32).div_ceil(BLOCK_H).div_ceil(cube_dim.y),
Ord::max((cols as u32 + warp_size * 30 - 1) / (warp_size * 31), 1),
(rows as u32 - 1) / BLOCK_H,
batches as u32,
);

strip_merge::launch::<BT, R>(
&client,
cube_count,
cube_dim,
cube_dim_merge,
img.as_tensor_arg::<u8>(1),
labels.as_tensor_arg::<u32>(1),
connectivity,
);

let cube_count = CubeCount::Static(
Expand Down

0 comments on commit 0484f51

Please sign in to comment.