-
Notifications
You must be signed in to change notification settings - Fork 468
[Refactor]: optimize poly_norm backward kernel pointer handling #1018
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Refactor the backward kernel to compute base pointers within the loop instead of incrementing pointers. This improves code clarity and maintainability while maintaining the same performance.
|
Benchmark: |
src/liger_kernel/ops/poly_norm.py
Outdated
| dy_base = dY_ptr + row_idx * dY_row_stride | ||
| x_base = X_ptr + row_idx * X_row_stride | ||
| dx_base = dX_ptr + row_idx * dX_row_stride | ||
| r_base = RSTD_ptr + row_idx * RSTD_row_stride |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: name it rstd_base for consistency
| r_base = RSTD_ptr + row_idx * RSTD_row_stride | |
| rstd_base = RSTD_ptr + row_idx * RSTD_row_stride |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this out — that was my mistake.
Tcc0403
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!

Refactor the backward kernel to compute base pointers within the loop instead of incrementing pointers. This improves code clarity and maintainability while maintaining the same performance.
Hardware Type: NVIDIA A100-SXM4-80GB
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergence