@@ -249,6 +249,49 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
249
249
/ VAL_ZP_ELEMENTS_PER_BYTE ;
250
250
#endif
251
251
252
+ #ifdef PREFETCH_K0
253
+ /* Prefetch first K tile. */
254
+ cooperative_prefetch_2d_k (
255
+ /* ptr */ K ,
256
+ /* r */ k ,
257
+ /* c */ d , // faster than D_MAX
258
+ /* rmax */ ugemm_kq_wg_tile_m ,
259
+ /* cmax */ PREFETCH_D_MAX ,
260
+ /* ld */ ldk ,
261
+ /* sg_id */ sg_ij ,
262
+ /* n_sg */ sg_per_wg ,
263
+ /* sg_size */ SUBGROUP_SIZE ,
264
+ /* cache */ LSC_LDCC_L1C_L3C );
265
+ //return;
266
+
267
+ #if KEY_SCALES == QUANTIZE_2D
268
+ cooperative_prefetch_2d_maybe_rem (
269
+ /* ptr */ K_scales ,
270
+ /* r */ k ,
271
+ /* c */ D_MAX / KEY_GROUP_SIZE ,
272
+ /* rmax */ ugemm_kq_wg_tile_m ,
273
+ /* cmax */ PREFETCH_D_MAX / KEY_GROUP_SIZE ,
274
+ /* ld */ ldkq ,
275
+ /* sg_id */ sg_ij ,
276
+ /* n_sg */ sg_per_wg ,
277
+ /* sg_size */ SUBGROUP_SIZE ,
278
+ /* cache */ LSC_LDCC_L1C_L3C );
279
+ #endif
280
+ #if KEY_ZERO_POINTS == QUANTIZE_2D
281
+ cooperative_prefetch_2d_maybe_rem (
282
+ /* ptr */ K_zp ,
283
+ /* r */ k ,
284
+ /* c */ D_MAX / KEY_GROUP_SIZE ,
285
+ /* rmax */ ugemm_kq_wg_tile_m ,
286
+ /* cmax */ PREFETCH_D_MAX / KEY_GROUP_SIZE ,
287
+ /* ld */ ldkq ,
288
+ /* sg_id */ sg_ij ,
289
+ /* n_sg */ sg_per_wg ,
290
+ /* sg_size */ SUBGROUP_SIZE ,
291
+ /* cache */ LSC_LDCC_L1C_L3C );
292
+ #endif
293
+ #endif
294
+
252
295
/* Load Q tile, destined for SLM */
253
296
q_tile_type Q_tile ;
254
297
uint q0_copy = q_tile_sg_n * sg_ij ;
@@ -277,23 +320,6 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
277
320
#endif
278
321
scale *= 1.442695f ; // log2(e)
279
322
280
- #ifdef PREFETCH_K0
281
- /* Prefetch first K tile. */
282
- cooperative_prefetch_2d_k (K , k , d , ugemm_kq_wg_tile_m , PREFETCH_D_MAX , ldk ,
283
- sg_ij , sg_per_wg , SUBGROUP_SIZE , LSC_LDCC_L1C_L3C );
284
-
285
- #if KEY_SCALES == QUANTIZE_2D
286
- cooperative_prefetch_2d (K_scales , ugemm_kq_wg_tile_m ,
287
- PREFETCH_D_MAX / KEY_GROUP_SIZE , ldkq , sg_ij , sg_per_wg ,
288
- SUBGROUP_SIZE , LSC_LDCC_L1C_L3C );
289
- #endif
290
- #if KEY_ZERO_POINTS == QUANTIZE_2D
291
- cooperative_prefetch_2d (K_zp , ugemm_kq_wg_tile_m ,
292
- PREFETCH_D_MAX / KEY_GROUP_SIZE , ldkq , sg_ij , sg_per_wg ,
293
- SUBGROUP_SIZE , LSC_LDCC_L1C_L3C );
294
- #endif
295
- #endif
296
-
297
323
/* Initialize S column sums in SLM to -inf */
298
324
const uint n_col_sg = DIV_UP (ugemm_kq_wg_tile_n , SUBGROUP_SIZE * sg_per_wg );
299
325
const float neg_inf = - INFINITY ;
@@ -411,25 +437,48 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
411
437
S_max_tile , S_max_slm , ugemm_kq_wg_tile_n , sg_j0_kq , 0 );
412
438
intel_work_group_barrier_arrive (CLK_LOCAL_MEM_FENCE );
413
439
440
+ int k_chunk = min (k - k0 , ugemm_kq_wg_tile_m );
414
441
#ifdef PREFETCH_V
415
442
/* Prefetch V tile. */
416
- cooperative_prefetch_2d_maybe_rem (V , d , k - k0 , D_MAX ,
417
- (ugemm_kq_wg_tile_m * PREFETCH_D_MAX ) / D_MAX , ldv , sg_ij ,
418
- sg_per_wg , SUBGROUP_SIZE , LSC_LDCC_L1C_L3C );
443
+ cooperative_prefetch_2d_maybe_rem (
444
+ /* ptr */ V ,
445
+ /* r */ D_MAX , // slower than d
446
+ /* c */ k - k0 ,
447
+ /* rmax */ PREFETCH_D_MAX ,
448
+ /* cmax */ ugemm_kq_wg_tile_m ,
449
+ /* ld */ ldv ,
450
+ /* sg_id */ sg_ij ,
451
+ /* n_sg */ sg_per_wg ,
452
+ /* sg_size */ SUBGROUP_SIZE ,
453
+ /* cache */ LSC_LDCC_L1C_L3C );
419
454
420
455
#if VAL_SCALES == QUANTIZE_2D
421
456
/* Prefetch V scales. */
422
- cooperative_prefetch_2d_maybe_rem (V_scales , d / VAL_GROUP_SIZE , k - k0 ,
423
- d / VAL_GROUP_SIZE ,
424
- (ugemm_kq_wg_tile_m * PREFETCH_D_MAX ) / D_MAX , ldvq , sg_ij ,
425
- sg_per_wg , SUBGROUP_SIZE , LSC_LDCC_L1C_L3C );
457
+ cooperative_prefetch_2d_maybe_rem (
458
+ /* ptr */ V_scales ,
459
+ /* r */ (D_MAX / VAL_GROUP_SIZE ),
460
+ /* c */ k - k0 ,
461
+ /* rmax */ PREFETCH_D_MAX / VAL_GROUP_SIZE ,
462
+ /* cmax */ k_chunk ,
463
+ /* ld */ ldvq ,
464
+ /* sg_id */ sg_ij ,
465
+ /* n_sg */ sg_per_wg ,
466
+ /* sg_size */ SUBGROUP_SIZE ,
467
+ /* cache */ LSC_LDCC_L1C_L3C );
426
468
#endif
427
469
#if VAL_ZERO_POINTS == QUANTIZE_2D
428
470
/* Prefetch V zero points. */
429
- cooperative_prefetch_2d_maybe_rem (V_zp , d / VAL_GROUP_SIZE , k - k0 ,
430
- d / VAL_GROUP_SIZE ,
431
- (ugemm_kq_wg_tile_m * PREFETCH_D_MAX ) / D_MAX , ldvq , sg_ij ,
432
- sg_per_wg , SUBGROUP_SIZE , LSC_LDCC_L1C_L3C );
471
+ cooperative_prefetch_2d_maybe_rem (
472
+ /* ptr */ V_zp ,
473
+ /* r */ (D_MAX / VAL_GROUP_SIZE ),
474
+ /* c */ k - k0 ,
475
+ /* rmax */ PREFETCH_D_MAX / VAL_GROUP_SIZE ,
476
+ /* cmax */ k_chunk ,
477
+ /* ld */ ldvq ,
478
+ /* sg_id */ sg_ij ,
479
+ /* n_sg */ sg_per_wg ,
480
+ /* sg_size */ SUBGROUP_SIZE ,
481
+ /* cache */ LSC_LDCC_L1C_L3C );
433
482
#endif
434
483
#endif
435
484
#ifndef ALT_MAX
@@ -513,35 +562,72 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
513
562
#else
514
563
const uint stride_k = 1 ;
515
564
#endif
516
- cooperative_prefetch_2d_k (K + (k0 + ugemm_kq_wg_tile_m ) * stride_k ,
517
- k - k0 - ugemm_kq_wg_tile_m , d , ugemm_kq_wg_tile_m ,
518
- PREFETCH_D_MAX , ldk , sg_ij , sg_per_wg , SUBGROUP_SIZE ,
519
- LSC_LDCC_L1C_L3C );
565
+ cooperative_prefetch_2d_k (
566
+ /* ptr */ K + (k0 + ugemm_kq_wg_tile_m ) * stride_k ,
567
+ /* r */ k - k0 - ugemm_kq_wg_tile_m ,
568
+ /* c */ D_MAX ,
569
+ /* rmax */ ugemm_kq_wg_tile_m ,
570
+ /* cmax */ ugemm_kq_wg_tile_n ,
571
+ /* ld*/ ldk ,
572
+ /* sg_id */ sg_ij ,
573
+ /* n_sg */ sg_per_wg ,
574
+ /* sg_size */ SUBGROUP_SIZE ,
575
+ /* cache*/ LSC_LDCC_L1C_L3C );
520
576
#if KEY_SCALES == QUANTIZE_2D
521
- cooperative_prefetch_2d (
522
- K_scales + ((k0 + ugemm_kq_wg_tile_m ) * ldkq ),
523
- ugemm_kq_wg_tile_m , PREFETCH_D_MAX / KEY_GROUP_SIZE , ldkq ,
524
- sg_ij , sg_per_wg , SUBGROUP_SIZE , LSC_LDCC_L1C_L3C );
577
+ cooperative_prefetch_2d_maybe_rem (
578
+ /* ptr */ K_scales + (k0 + ugemm_kq_wg_tile_m ),
579
+ /* r */ k - k0 - ugemm_kq_wg_tile_m ,
580
+ /* c */ D_MAX / KEY_GROUP_SIZE ,
581
+ /* rmax */ ugemm_kq_wg_tile_m ,
582
+ /* cmax */ ugemm_kq_wg_tile_n / KEY_GROUP_SIZE ,
583
+ /* ld */ ldkq ,
584
+ /* sg_id */ sg_ij ,
585
+ /* n_sg */ sg_per_wg ,
586
+ /* sg_size */ SUBGROUP_SIZE ,
587
+ /* cache */ LSC_LDCC_L1C_L3C );
525
588
#endif
526
589
#if KEY_ZERO_POINTS == QUANTIZE_2D
527
- cooperative_prefetch_2d (K_zp + ((k0 + ugemm_kq_wg_tile_m ) * ldkq ),
528
- ugemm_kq_wg_tile_m , PREFETCH_D_MAX / KEY_GROUP_SIZE , ldkq ,
529
- sg_ij , sg_per_wg , SUBGROUP_SIZE , LSC_LDCC_L1C_L3C );
590
+ cooperative_prefetch_2d_maybe_rem (
591
+ /* ptr */ K_zp + (k0 + ugemm_kq_wg_tile_m ),
592
+ /* r */ k - k0 - ugemm_kq_wg_tile_m ,
593
+ /* c */ D_MAX / KEY_GROUP_SIZE ,
594
+ /* rmax */ ugemm_kq_wg_tile_m ,
595
+ /* cmax */ ugemm_kq_wg_tile_n / KEY_GROUP_SIZE ,
596
+ /* ld */ ldkq ,
597
+ /* sg_id */ sg_ij ,
598
+ /* n_sg */ sg_per_wg ,
599
+ /* sg_size */ SUBGROUP_SIZE ,
600
+ /* cache */ LSC_LDCC_L1C_L3C );
530
601
#endif
531
602
}
532
603
#endif
604
+
533
605
#if WITH_ATTN_MASK && defined(PREFETCH_MASK )
534
606
/* Prefetch next mask tile. */
535
607
if (!last ) {
536
608
#if BROADCAST_MASK_Q
537
- cooperative_prefetch_2d (msk + k0 + ugemm_kq_wg_tile_m + sg_i0_kq ,
538
- ugemm_kq_sg_tile_m , 1 , 0 , 0 , 1 , SUBGROUP_SIZE ,
539
- LSC_LDCC_L1UC_L3C );
609
+ cooperative_prefetch_2d_maybe_rem (
610
+ /* ptr */ msk + k0 + ugemm_kq_wg_tile_m ,
611
+ /* r */ k - k0 ,
612
+ /* c */ 1 ,
613
+ /* rmax */ ugemm_kq_wg_tile_m ,
614
+ /* cmax */ 1 ,
615
+ /* ld */ 0 ,
616
+ /* sg_id */ sg_ij ,
617
+ /* n_sg */ sg_per_wg ,
618
+ /* sg_size */ SUBGROUP_SIZE ,
619
+ /* cache */ LSC_LDCC_L1C_L3C );
540
620
#else
541
- cooperative_prefetch_2d (msk + k0 + ugemm_kq_wg_tile_m + sg_i0_kq
542
- + (sg_j0_kq + wg_j0 ) * q ,
543
- ugemm_kq_sg_tile_m , ugemm_kq_sg_tile_n , 0 , 0 , 1 ,
544
- SUBGROUP_SIZE , LSC_LDCC_L1UC_L3C );
621
+ cooperative_prefetch_2d_maybe_rem (/* ptr */ msk + k0 + ugemm_kq_sg_tile_m + (wg_j0 ) * MSK_S2 ,
622
+ /* r */ k - k0 - ugemm_kq_wg_tile_m ,
623
+ /* c */ q - wg_j0 ,
624
+ /* rmax */ ugemm_kq_wg_tile_m ,
625
+ /* cmax */ (ugemm_kq_wg_tile_n * PREFETCH_D_MAX ) / D_MAX ,
626
+ /* ld */ MSK_S2 ,
627
+ /* sg_id */ sg_ij ,
628
+ /* n_sg */ sg_per_wg ,
629
+ /* sg_size */ SUBGROUP_SIZE ,
630
+ /* cache */ LSC_LDCC_L1UC_L3C );
545
631
#endif
546
632
}
547
633
#endif
@@ -554,8 +640,6 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
554
640
intel_work_group_barrier_arrive (CLK_LOCAL_MEM_FENCE );
555
641
556
642
/* Accumulate A += V * S */
557
- int k_chunk = min (k - k0 , ugemm_kq_wg_tile_m );
558
-
559
643
a_tile_type A_tile1 = ugemm_vs (
560
644
V , ldv , S_slm , ugemm_kq_wg_tile_m , d , ugemm_kq_wg_tile_n ,
561
645
k_chunk , 0 , 0 , 0 , sg_i_vs , sg_j_vs , (local char * )ugemm_slm
0 commit comments