diff --git a/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs b/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs index 2c4cf663b6..f22910f442 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs @@ -27,8 +27,9 @@ fn prefix_sum_kernel( let mut broadcast = SharedMemory::::new(1); let mut reduce = SharedMemory::::new(MAX_REDUCE_SIZE); let batch = CUBE_POS_Z; - let vec4_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size()); - let nums_per_cube = CUBE_SIZE * vec4_spt; + let line_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size()); + let nums_per_cube = CUBE_SIZE * line_spt; + let v_last = comptime!(scan_in.line_size() - 1); //acquire partition index if UNIT_POS_X == 0 { @@ -39,7 +40,7 @@ fn prefix_sum_kernel( let plane_id = UNIT_POS_X / PLANE_DIM; let dev_offs = part_id * nums_per_cube; - let plane_offs = plane_id * PLANE_DIM * vec4_spt; + let plane_offs = plane_id * PLANE_DIM * line_spt; // Exit if full plane is out of bounds if dev_offs + plane_offs >= scan_in.shape(1) { @@ -55,37 +56,35 @@ fn prefix_sum_kernel( let red_offs = batch * reduction.stride(0); let scan_offs = batch * scan_in.stride(0); - let mut t_scan = Array::>::vectorized(vec4_spt, scan_in.line_size()); + let mut t_scan = Array::>::vectorized(line_spt, scan_in.line_size()); { let mut i = dev_offs + plane_offs + UNIT_POS_PLANE; if part_id < cube_count_x - 1 { - for k in 0..vec4_spt { + for k in 0..line_spt { // Manually fuse not_equal and cast let mut scan = Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero))); - let x = scan[0]; - scan[1] += x; - let y = scan[1]; - scan[2] += y; - let z = scan[2]; - scan[3] += z; + #[unroll] + for v in 1..scan_in.line_size() { + let prev = scan[v - 1]; + scan[v] += prev; + } t_scan[k] = scan; i += PLANE_DIM; } } if part_id == cube_count_x - 1 { - for k in 0..vec4_spt { + for k in 0..line_spt { if i < scan_in.shape(1) { // Manually fuse not_equal and cast let mut scan = Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero))); - let x = scan[0]; - scan[1] += x; - let y = scan[1]; - scan[2] += y; - let z = scan[2]; - scan[3] += z; + #[unroll] + for v in 1..scan_in.line_size() { + let prev = scan[v - 1]; + scan[v] += prev; + } t_scan[k] = scan; } i += PLANE_DIM; @@ -95,8 +94,8 @@ fn prefix_sum_kernel( let mut prev = zero; let plane_mask = PLANE_DIM - 1; let circular_shift = (UNIT_POS_PLANE + plane_mask) & plane_mask; - for k in 0..vec4_spt { - let t = plane_broadcast(plane_inclusive_sum(t_scan[k][3]), circular_shift); + for k in 0..line_spt { + let t = plane_broadcast(plane_inclusive_sum(t_scan[k][v_last]), circular_shift); t_scan[k] += Line::cast_from(select(UNIT_POS_PLANE != 0, t, zero) + prev); prev += plane_broadcast(t, 0); } @@ -187,19 +186,19 @@ fn prefix_sum_kernel( zero }; let prev = Line::cast_from(broadcast[0] + prev); - let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * vec4_spt; + let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt; let dev_offset = part_id * nums_per_cube; let mut i = s_offset + dev_offset; if part_id < cube_count_x - 1 { - for k in 0..vec4_spt { + for k in 0..line_spt { scan_out[i + scan_offs] = t_scan[k] + prev; i += PLANE_DIM; } } if part_id == cube_count_x - 1 { - for k in 0..vec4_spt { + for k in 0..line_spt { if i < scan_out.shape(1) { scan_out[i + scan_offs] = t_scan[k] + prev; }