Skip to content

Commit 9af6de9

Browse files
committed
xe: sdpa: Update prefetch functions to improve performance
1 parent 0fd3113 commit 9af6de9

File tree

1 file changed

+132
-48
lines changed

1 file changed

+132
-48
lines changed

src/gpu/intel/ocl/micro_sdpa.cl

+132 -48
Original file line number | Diff line number | Diff line change
@@ -249,6 +249,49 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
249249
/ VAL_ZP_ELEMENTS_PER_BYTE;
250250
#endif
251251

252+
#ifdef PREFETCH_K0
253+
/* Prefetch first K tile. */
254+
cooperative_prefetch_2d_k(
255+
/* ptr */ K,
256+
/* r */ k,
257+
/* c */ d, // faster than D_MAX
258+
/* rmax */ ugemm_kq_wg_tile_m,
259+
/* cmax */ PREFETCH_D_MAX,
260+
/* ld */ ldk,
261+
/* sg_id */ sg_ij,
262+
/* n_sg */ sg_per_wg,
263+
/* sg_size */ SUBGROUP_SIZE,
264+
/* cache */ LSC_LDCC_L1C_L3C);
265+
//return;
266+
267+
#if KEY_SCALES == QUANTIZE_2D
268+
cooperative_prefetch_2d_maybe_rem(
269+
/* ptr */ K_scales,
270+
/* r */ k,
271+
/* c */ D_MAX / KEY_GROUP_SIZE,
272+
/* rmax */ ugemm_kq_wg_tile_m,
273+
/* cmax */ PREFETCH_D_MAX / KEY_GROUP_SIZE,
274+
/* ld */ ldkq,
275+
/* sg_id */ sg_ij,
276+
/* n_sg */ sg_per_wg,
277+
/* sg_size */ SUBGROUP_SIZE,
278+
/* cache */ LSC_LDCC_L1C_L3C);
279+
#endif
280+
#if KEY_ZERO_POINTS == QUANTIZE_2D
281+
cooperative_prefetch_2d_maybe_rem(
282+
/* ptr */ K_zp,
283+
/* r */ k,
284+
/* c */ D_MAX / KEY_GROUP_SIZE,
285+
/* rmax */ ugemm_kq_wg_tile_m,
286+
/* cmax */ PREFETCH_D_MAX / KEY_GROUP_SIZE,
287+
/* ld */ ldkq,
288+
/* sg_id */ sg_ij,
289+
/* n_sg */ sg_per_wg,
290+
/* sg_size */ SUBGROUP_SIZE,
291+
/* cache */ LSC_LDCC_L1C_L3C);
292+
#endif
293+
#endif
294+
252295
/* Load Q tile, destined for SLM */
253296
q_tile_type Q_tile;
254297
uint q0_copy = q_tile_sg_n * sg_ij;
@@ -277,23 +320,6 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
277320
#endif
278321
scale *= 1.442695f; // log2(e)
279322

280-
#ifdef PREFETCH_K0
281-
/* Prefetch first K tile. */
282-
cooperative_prefetch_2d_k(K, k, d, ugemm_kq_wg_tile_m, PREFETCH_D_MAX, ldk,
283-
sg_ij, sg_per_wg, SUBGROUP_SIZE, LSC_LDCC_L1C_L3C);
284-
285-
#if KEY_SCALES == QUANTIZE_2D
286-
cooperative_prefetch_2d(K_scales, ugemm_kq_wg_tile_m,
287-
PREFETCH_D_MAX / KEY_GROUP_SIZE, ldkq, sg_ij, sg_per_wg,
288-
SUBGROUP_SIZE, LSC_LDCC_L1C_L3C);
289-
#endif
290-
#if KEY_ZERO_POINTS == QUANTIZE_2D
291-
cooperative_prefetch_2d(K_zp, ugemm_kq_wg_tile_m,
292-
PREFETCH_D_MAX / KEY_GROUP_SIZE, ldkq, sg_ij, sg_per_wg,
293-
SUBGROUP_SIZE, LSC_LDCC_L1C_L3C);
294-
#endif
295-
#endif
296-
297323
/* Initialize S column sums in SLM to -inf */
298324
const uint n_col_sg = DIV_UP(ugemm_kq_wg_tile_n, SUBGROUP_SIZE * sg_per_wg);
299325
const float neg_inf = -INFINITY;
@@ -411,25 +437,48 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
411437
S_max_tile, S_max_slm, ugemm_kq_wg_tile_n, sg_j0_kq, 0);
412438
intel_work_group_barrier_arrive(CLK_LOCAL_MEM_FENCE);
413439

440+
int k_chunk = min(k - k0, ugemm_kq_wg_tile_m);
414441
#ifdef PREFETCH_V
415442
/* Prefetch V tile. */
416-
cooperative_prefetch_2d_maybe_rem(V, d, k - k0, D_MAX,
417-
(ugemm_kq_wg_tile_m * PREFETCH_D_MAX) / D_MAX, ldv, sg_ij,
418-
sg_per_wg, SUBGROUP_SIZE, LSC_LDCC_L1C_L3C);
443+
cooperative_prefetch_2d_maybe_rem(
444+
/* ptr */ V,
445+
/* r */ D_MAX, // slower than d
446+
/* c */ k - k0,
447+
/* rmax */ PREFETCH_D_MAX,
448+
/* cmax */ ugemm_kq_wg_tile_m,
449+
/* ld */ ldv,
450+
/* sg_id */ sg_ij,
451+
/* n_sg */ sg_per_wg,
452+
/* sg_size */ SUBGROUP_SIZE,
453+
/* cache */ LSC_LDCC_L1C_L3C);
419454

420455
#if VAL_SCALES == QUANTIZE_2D
421456
/* Prefetch V scales. */
422-
cooperative_prefetch_2d_maybe_rem(V_scales, d / VAL_GROUP_SIZE, k - k0,
423-
d / VAL_GROUP_SIZE,
424-
(ugemm_kq_wg_tile_m * PREFETCH_D_MAX) / D_MAX, ldvq, sg_ij,
425-
sg_per_wg, SUBGROUP_SIZE, LSC_LDCC_L1C_L3C);
457+
cooperative_prefetch_2d_maybe_rem(
458+
/* ptr */ V_scales,
459+
/* r */ (D_MAX / VAL_GROUP_SIZE),
460+
/* c */ k - k0,
461+
/* rmax */ PREFETCH_D_MAX / VAL_GROUP_SIZE,
462+
/* cmax */ k_chunk,
463+
/* ld */ ldvq,
464+
/* sg_id */ sg_ij,
465+
/* n_sg */ sg_per_wg,
466+
/* sg_size */ SUBGROUP_SIZE,
467+
/* cache */ LSC_LDCC_L1C_L3C);
426468
#endif
427469
#if VAL_ZERO_POINTS == QUANTIZE_2D
428470
/* Prefetch V zero points. */
429-
cooperative_prefetch_2d_maybe_rem(V_zp, d / VAL_GROUP_SIZE, k - k0,
430-
d / VAL_GROUP_SIZE,
431-
(ugemm_kq_wg_tile_m * PREFETCH_D_MAX) / D_MAX, ldvq, sg_ij,
432-
sg_per_wg, SUBGROUP_SIZE, LSC_LDCC_L1C_L3C);
471+
cooperative_prefetch_2d_maybe_rem(
472+
/* ptr */ V_zp,
473+
/* r */ (D_MAX / VAL_GROUP_SIZE),
474+
/* c */ k - k0,
475+
/* rmax */ PREFETCH_D_MAX / VAL_GROUP_SIZE,
476+
/* cmax */ k_chunk,
477+
/* ld */ ldvq,
478+
/* sg_id */ sg_ij,
479+
/* n_sg */ sg_per_wg,
480+
/* sg_size */ SUBGROUP_SIZE,
481+
/* cache */ LSC_LDCC_L1C_L3C);
433482
#endif
434483
#endif
435484
#ifndef ALT_MAX
@@ -513,35 +562,72 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
513562
#else
514563
const uint stride_k = 1;
515564
#endif
516-
cooperative_prefetch_2d_k(K + (k0 + ugemm_kq_wg_tile_m) * stride_k,
517-
k - k0 - ugemm_kq_wg_tile_m, d, ugemm_kq_wg_tile_m,
518-
PREFETCH_D_MAX, ldk, sg_ij, sg_per_wg, SUBGROUP_SIZE,
519-
LSC_LDCC_L1C_L3C);
565+
cooperative_prefetch_2d_k(
566+
/* ptr */ K + (k0 + ugemm_kq_wg_tile_m) * stride_k,
567+
/* r */ k - k0 - ugemm_kq_wg_tile_m,
568+
/* c */ D_MAX,
569+
/* rmax */ ugemm_kq_wg_tile_m,
570+
/* cmax */ ugemm_kq_wg_tile_n,
571+
/* ld*/ ldk,
572+
/* sg_id */ sg_ij,
573+
/* n_sg */ sg_per_wg,
574+
/* sg_size */ SUBGROUP_SIZE,
575+
/* cache*/ LSC_LDCC_L1C_L3C);
520576
#if KEY_SCALES == QUANTIZE_2D
521-
cooperative_prefetch_2d(
522-
K_scales + ((k0 + ugemm_kq_wg_tile_m) * ldkq),
523-
ugemm_kq_wg_tile_m, PREFETCH_D_MAX / KEY_GROUP_SIZE, ldkq,
524-
sg_ij, sg_per_wg, SUBGROUP_SIZE, LSC_LDCC_L1C_L3C);
577+
cooperative_prefetch_2d_maybe_rem(
578+
/* ptr */ K_scales + (k0 + ugemm_kq_wg_tile_m),
579+
/* r */ k - k0 - ugemm_kq_wg_tile_m,
580+
/* c */ D_MAX / KEY_GROUP_SIZE,
581+
/* rmax */ ugemm_kq_wg_tile_m,
582+
/* cmax */ ugemm_kq_wg_tile_n / KEY_GROUP_SIZE,
583+
/* ld */ ldkq,
584+
/* sg_id */ sg_ij,
585+
/* n_sg */ sg_per_wg,
586+
/* sg_size */ SUBGROUP_SIZE,
587+
/* cache */ LSC_LDCC_L1C_L3C);
525588
#endif
526589
#if KEY_ZERO_POINTS == QUANTIZE_2D
527-
cooperative_prefetch_2d(K_zp + ((k0 + ugemm_kq_wg_tile_m) * ldkq),
528-
ugemm_kq_wg_tile_m, PREFETCH_D_MAX / KEY_GROUP_SIZE, ldkq,
529-
sg_ij, sg_per_wg, SUBGROUP_SIZE, LSC_LDCC_L1C_L3C);
590+
cooperative_prefetch_2d_maybe_rem(
591+
/* ptr */ K_zp + (k0 + ugemm_kq_wg_tile_m),
592+
/* r */ k - k0 - ugemm_kq_wg_tile_m,
593+
/* c */ D_MAX / KEY_GROUP_SIZE,
594+
/* rmax */ ugemm_kq_wg_tile_m,
595+
/* cmax */ ugemm_kq_wg_tile_n / KEY_GROUP_SIZE,
596+
/* ld */ ldkq,
597+
/* sg_id */ sg_ij,
598+
/* n_sg */ sg_per_wg,
599+
/* sg_size */ SUBGROUP_SIZE,
600+
/* cache */ LSC_LDCC_L1C_L3C);
530601
#endif
531602
}
532603
#endif
604+
533605
#if WITH_ATTN_MASK && defined(PREFETCH_MASK)
534606
/* Prefetch next mask tile. */
535607
if (!last) {
536608
#if BROADCAST_MASK_Q
537-
cooperative_prefetch_2d(msk + k0 + ugemm_kq_wg_tile_m + sg_i0_kq,
538-
ugemm_kq_sg_tile_m, 1, 0, 0, 1, SUBGROUP_SIZE,
539-
LSC_LDCC_L1UC_L3C);
609+
cooperative_prefetch_2d_maybe_rem(
610+
/* ptr */ msk + k0 + ugemm_kq_wg_tile_m,
611+
/* r */ k - k0,
612+
/* c */ 1,
613+
/* rmax */ ugemm_kq_wg_tile_m,
614+
/* cmax */ 1,
615+
/* ld */ 0,
616+
/* sg_id */ sg_ij,
617+
/* n_sg */ sg_per_wg,
618+
/* sg_size */ SUBGROUP_SIZE,
619+
/* cache */ LSC_LDCC_L1C_L3C);
540620
#else
541-
cooperative_prefetch_2d(msk + k0 + ugemm_kq_wg_tile_m + sg_i0_kq
542-
+ (sg_j0_kq + wg_j0) * q,
543-
ugemm_kq_sg_tile_m, ugemm_kq_sg_tile_n, 0, 0, 1,
544-
SUBGROUP_SIZE, LSC_LDCC_L1UC_L3C);
621+
cooperative_prefetch_2d_maybe_rem(/* ptr */ msk + k0 + ugemm_kq_sg_tile_m + (wg_j0) * MSK_S2,
622+
/* r */ k - k0 - ugemm_kq_wg_tile_m,
623+
/* c */ q - wg_j0,
624+
/* rmax */ ugemm_kq_wg_tile_m,
625+
/* cmax */ (ugemm_kq_wg_tile_n * PREFETCH_D_MAX) / D_MAX,
626+
/* ld */ MSK_S2,
627+
/* sg_id */ sg_ij,
628+
/* n_sg */ sg_per_wg,
629+
/* sg_size */ SUBGROUP_SIZE,
630+
/* cache */ LSC_LDCC_L1UC_L3C);
545631
#endif
546632
}
547633
#endif
@@ -554,8 +640,6 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
554640
intel_work_group_barrier_arrive(CLK_LOCAL_MEM_FENCE);
555641

556642
/* Accumulate A += V * S */
557-
int k_chunk = min(k - k0, ugemm_kq_wg_tile_m);
558-
559643
a_tile_type A_tile1 = ugemm_vs(
560644
V, ldv, S_slm, ugemm_kq_wg_tile_m, d, ugemm_kq_wg_tile_n,
561645
k_chunk, 0, 0, 0, sg_i_vs, sg_j_vs, (local char *)ugemm_slm

0 commit comments

Comments
 (0)