Skip to content

Commit

Permalink
Make prefix sum more generic over line size
Browse files: browse the repository at this point in the history
  • Loading branch information
wingertge committed Feb 2, 2025
1 parent e3ec085 commit 11c8f1f
Showing 1 changed file with 22 additions and 23 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,8 +27,9 @@ fn prefix_sum_kernel<I: Int>(
let mut broadcast = SharedMemory::<I>::new(1);
let mut reduce = SharedMemory::<I>::new(MAX_REDUCE_SIZE);
let batch = CUBE_POS_Z;
let vec4_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size());
let nums_per_cube = CUBE_SIZE * vec4_spt;
let line_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size());
let nums_per_cube = CUBE_SIZE * line_spt;
let v_last = comptime!(scan_in.line_size() - 1);

//acquire partition index
if UNIT_POS_X == 0 {
Expand All @@ -39,7 +40,7 @@ fn prefix_sum_kernel<I: Int>(

let plane_id = UNIT_POS_X / PLANE_DIM;
let dev_offs = part_id * nums_per_cube;
let plane_offs = plane_id * PLANE_DIM * vec4_spt;
let plane_offs = plane_id * PLANE_DIM * line_spt;

// Exit if full plane is out of bounds
if dev_offs + plane_offs >= scan_in.shape(1) {
Expand All @@ -55,37 +56,35 @@ fn prefix_sum_kernel<I: Int>(
let red_offs = batch * reduction.stride(0);
let scan_offs = batch * scan_in.stride(0);

let mut t_scan = Array::<Line<I>>::vectorized(vec4_spt, scan_in.line_size());
let mut t_scan = Array::<Line<I>>::vectorized(line_spt, scan_in.line_size());
{
let mut i = dev_offs + plane_offs + UNIT_POS_PLANE;

if part_id < cube_count_x - 1 {
for k in 0..vec4_spt {
for k in 0..line_spt {
// Manually fuse not_equal and cast
let mut scan = Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero)));
let x = scan[0];
scan[1] += x;
let y = scan[1];
scan[2] += y;
let z = scan[2];
scan[3] += z;
#[unroll]
for v in 1..scan_in.line_size() {
let prev = scan[v - 1];
scan[v] += prev;
}
t_scan[k] = scan;
i += PLANE_DIM;
}
}

if part_id == cube_count_x - 1 {
for k in 0..vec4_spt {
for k in 0..line_spt {
if i < scan_in.shape(1) {
// Manually fuse not_equal and cast
let mut scan =
Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero)));
let x = scan[0];
scan[1] += x;
let y = scan[1];
scan[2] += y;
let z = scan[2];
scan[3] += z;
#[unroll]
for v in 1..scan_in.line_size() {
let prev = scan[v - 1];
scan[v] += prev;
}
t_scan[k] = scan;
}
i += PLANE_DIM;
Expand All @@ -95,8 +94,8 @@ fn prefix_sum_kernel<I: Int>(
let mut prev = zero;
let plane_mask = PLANE_DIM - 1;
let circular_shift = (UNIT_POS_PLANE + plane_mask) & plane_mask;
for k in 0..vec4_spt {
let t = plane_broadcast(plane_inclusive_sum(t_scan[k][3]), circular_shift);
for k in 0..line_spt {
let t = plane_broadcast(plane_inclusive_sum(t_scan[k][v_last]), circular_shift);
t_scan[k] += Line::cast_from(select(UNIT_POS_PLANE != 0, t, zero) + prev);
prev += plane_broadcast(t, 0);
}
Expand Down Expand Up @@ -187,19 +186,19 @@ fn prefix_sum_kernel<I: Int>(
zero
};
let prev = Line::cast_from(broadcast[0] + prev);
let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * vec4_spt;
let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt;
let dev_offset = part_id * nums_per_cube;
let mut i = s_offset + dev_offset;

if part_id < cube_count_x - 1 {
for k in 0..vec4_spt {
for k in 0..line_spt {
scan_out[i + scan_offs] = t_scan[k] + prev;
i += PLANE_DIM;
}
}

if part_id == cube_count_x - 1 {
for k in 0..vec4_spt {
for k in 0..line_spt {
if i < scan_out.shape(1) {
scan_out[i + scan_offs] = t_scan[k] + prev;
}
Expand Down

0 comments on commit 11c8f1f

Please sign in to comment.