
Commit 48fb31f

[Refactor]: optimize poly_norm backward kernel pointer handling (#1018)
Refactor the backward kernel to compute per-row base pointers within the loop instead of incrementing pointers between iterations. This improves code clarity and maintainability while maintaining the same performance.

Hardware Type: NVIDIA A100-SXM4-80GB

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 3b79375 commit 48fb31f
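
To illustrate the pattern the commit description refers to, here is a minimal Triton-style sketch, not the actual liger_kernel code: the kernel names, `x_ptr`, `out_ptr`, `stride`, `n_cols`, and `BLOCK_SIZE` are placeholders. It contrasts incrementing pointers every iteration with computing each row's base pointer from `row_idx` inside the loop.

```python
import triton
import triton.language as tl


@triton.jit
def _rows_doubled_old_style(x_ptr, out_ptr, stride, row_start, row_end, n_cols, BLOCK_SIZE: tl.constexpr):
    # Old pattern: shift the pointers to the first row, then bump them every iteration.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x_ptr += row_start * stride
    out_ptr += row_start * stride
    for _ in range(row_start, row_end):
        row = tl.load(x_ptr + col_offsets, mask=mask, other=0.0)
        tl.store(out_ptr + col_offsets, row * 2.0, mask=mask)
        x_ptr += stride
        out_ptr += stride


@triton.jit
def _rows_doubled_new_style(x_ptr, out_ptr, stride, row_start, row_end, n_cols, BLOCK_SIZE: tl.constexpr):
    # Refactored pattern: derive each row's base pointer from row_idx inside the loop.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    for row_idx in range(row_start, row_end):
        x_base = x_ptr + row_idx * stride
        out_base = out_ptr + row_idx * stride
        row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
        tl.store(out_base + col_offsets, row * 2.0, mask=mask)
```

Both variants touch the same addresses; the refactored form keeps each iteration self-contained and avoids mutating the pointer arguments across iterations.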

File tree

1 file changed: +11 -18 lines changed


src/liger_kernel/ops/poly_norm.py

Lines changed: 11 additions & 18 deletions
@@ -141,20 +141,19 @@ def _poly_norm_backward_kernel(
     w1 = tl.load(W_ptr + 1).to(tl.float32)
     w2 = tl.load(W_ptr + 2).to(tl.float32)

-    dY_ptr += row_start * dY_row_stride
-    dX_ptr += row_start * dX_row_stride
-    X_ptr += row_start * X_row_stride
-    RSTD_ptr += row_start * RSTD_row_stride
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        x_base = X_ptr + row_idx * X_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

-    for _ in range(row_start, row_end):
-        # Load input and gradient
-        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
-        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0).to(tl.float32)

         # Load cached rstd values
-        rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
-        rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
-        rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
+        rstd_3 = tl.load(rstd_base + 0).to(tl.float32)
+        rstd_2 = tl.load(rstd_base + 1).to(tl.float32)
+        rstd_1 = tl.load(rstd_base + 2).to(tl.float32)

         # Compute powers
         X_pow3 = X_row * X_row * X_row
@@ -191,13 +190,7 @@ def _poly_norm_backward_kernel(
         dX_row = grad_x_3 + grad_x_2 + grad_x_1

         # Store gradient
-        tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
-
-        # Update pointers
-        dY_ptr += dY_row_stride
-        dX_ptr += dX_row_stride
-        X_ptr += X_row_stride
-        RSTD_ptr += RSTD_row_stride
+        tl.store(dx_base + col_offsets, dX_row, mask=mask)

     # Store accumulated gradients (scalars)
     tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
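
Since the refactor only changes how each row's offset is derived, not which addresses are touched, here is a small plain-Python sanity check of that equivalence (the stride and row values are illustrative only):

```python
# Illustrative check: cumulative pointer increments and per-row base offsets
# visit exactly the same row addresses.
row_start, row_end, row_stride = 3, 7, 256

incremental = []
offset = row_start * row_stride  # old style: shift once, then bump per iteration
for _ in range(row_start, row_end):
    incremental.append(offset)
    offset += row_stride

recomputed = [row_idx * row_stride for row_idx in range(row_start, row_end)]  # new style

assert incremental == recomputed
```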
