-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
xe: sdpa: Improve performance of quantization with better alignment and prefetching #2322
base: main
Are you sure you want to change the base?
Conversation
make test |
src/gpu/intel/ocl/micro_sdpa.cl
Outdated
#if VAL_SCALES == QUANTIZE_2D | ||
/* Prefetch V scales. */ | ||
cooperative_prefetch_2d_maybe_rem(V_scales, d / VAL_GROUP_SIZE, k - k0, | ||
d / VAL_GROUP_SIZE, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'll want a constant here to allow the compiler to unroll loops inside cooperative_prefetch_2d_maybe_rem
:
d / VAL_GROUP_SIZE, | |
D_MAX / VAL_GROUP_SIZE, |
9af6de9
to
4746108
Compare
make test |
/* n_sg */ sg_per_wg, | ||
/* sg_size */ SUBGROUP_SIZE, | ||
/* cache */ LSC_LDCC_L1C_L3C); | ||
//return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it improve performance to have the first K tile prefetch here (before loading Q)? IIRC in my earlier testing it was better to delay the first K tile prefetch until after issuing the Q load.
cooperative_prefetch_2d_k( | ||
/* ptr */ K, | ||
/* r */ k, | ||
/* c */ d, // faster than D_MAX |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not so much that it's faster than D_MAX but rather we need to avoid out-of-bounds prefetches.
cooperative_prefetch_2d_maybe_rem( | ||
/* ptr */ K_scales, | ||
/* r */ k, | ||
/* c */ D_MAX / KEY_GROUP_SIZE, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to avoid out-of-bounds prefetches here, and similarly for the zp prefetches:
/* c */ D_MAX / KEY_GROUP_SIZE, | |
/* c */ d / KEY_GROUP_SIZE, |
cooperative_prefetch_2d_k( | ||
/* ptr */ K + (k0 + ugemm_kq_wg_tile_m) * stride_k, | ||
/* r */ k - k0 - ugemm_kq_wg_tile_m, | ||
/* c */ D_MAX, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid OOB access:
/* c */ D_MAX, | |
/* c */ d, |
/* sg_id */ sg_ij, | ||
/* n_sg */ sg_per_wg, | ||
/* sg_size */ SUBGROUP_SIZE, | ||
/* cache */ LSC_LDCC_L1C_L3C); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This tile is so small that it doesn't need cooperative prefetching (hence the earlier simpler logic). Does this change improve performance?
Description
This PR improves the performance of the micro SDPA kernel by using prefetching and setting better alignment when generating the microkernels. This change has a significant impact on certain sizes ranging from (0.89x-1.26x) over the original version.