Skip to content

Commit 77bb921

Browse files
committed
Fix illegal memory accesses in multistage Mma's for k=0
1 parent 637b159 commit 77bb921

11 files changed

+98
-83
lines changed

examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h

+6-6
Original file line number | Diff line number | Diff line change
@@ -377,8 +377,8 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
377377
CUTLASS_PRAGMA_UNROLL
378378
for (int stage = 0; stage < kNumStagesConcurrentLoad;
379379
++stage, --gemm_k_iterations) {
380-
iterator_A.clear_mask(gemm_k_iterations == 0);
381-
iterator_B.clear_mask(gemm_k_iterations == 0);
380+
iterator_A.clear_mask(gemm_k_iterations <= 0);
381+
iterator_B.clear_mask(gemm_k_iterations <= 0);
382382

383383
iterator_A.set_iteration_index(0);
384384
smem_iterator_A_.set_iteration_index(0);
@@ -559,8 +559,8 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
559559
++this->warp_tile_iterator_A_;
560560
++this->warp_tile_iterator_B_;
561561

562-
iterator_A.clear_mask(gemm_k_iterations == 0);
563-
iterator_B.clear_mask(gemm_k_iterations == 0);
562+
iterator_A.clear_mask(gemm_k_iterations <= 0);
563+
iterator_B.clear_mask(gemm_k_iterations <= 0);
564564

565565
int smem_write_stage_idx = Base::kStages - 1;
566566
int smem_read_stage_idx = 0;
@@ -725,8 +725,8 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
725725
}
726726

727727
--gemm_k_iterations;
728-
iterator_A.clear_mask(gemm_k_iterations == 0);
729-
iterator_B.clear_mask(gemm_k_iterations == 0);
728+
iterator_A.clear_mask(gemm_k_iterations <= 0);
729+
iterator_B.clear_mask(gemm_k_iterations <= 0);
730730
}
731731

732732
// Do any conversions feeding the first stage at the end of the loop so

examples/45_dual_gemm/threadblock/dual_mma_multistage.h

+9-9
Original file line number | Diff line number | Diff line change
@@ -363,9 +363,9 @@ class DualMmaMultistage :
363363
for (int stage = 0; stage < Base::kStages - 1;
364364
++stage, --gemm_k_iterations) {
365365

366-
iterator_A.clear_mask(gemm_k_iterations == 0);
367-
iterator_B0.clear_mask(gemm_k_iterations == 0);
368-
iterator_B1.clear_mask(gemm_k_iterations == 0);
366+
iterator_A.clear_mask(gemm_k_iterations <= 0);
367+
iterator_B0.clear_mask(gemm_k_iterations <= 0);
368+
iterator_B1.clear_mask(gemm_k_iterations <= 0);
369369

370370
iterator_A.set_iteration_index(0);
371371
this->smem_iterator_A_.set_iteration_index(0);
@@ -555,9 +555,9 @@ class DualMmaMultistage :
555555
++this->warp_tile_iterator_B0_;
556556
++this->warp_tile_iterator_B1_;
557557

558-
iterator_A.clear_mask(gemm_k_iterations == 0);
559-
iterator_B0.clear_mask(gemm_k_iterations == 0);
560-
iterator_B1.clear_mask(gemm_k_iterations == 0);
558+
iterator_A.clear_mask(gemm_k_iterations <= 0);
559+
iterator_B0.clear_mask(gemm_k_iterations <= 0);
560+
iterator_B1.clear_mask(gemm_k_iterations <= 0);
561561

562562
int smem_write_stage_idx = Base::kStages - 1;
563563
int smem_read_stage_idx = 0;
@@ -730,9 +730,9 @@ class DualMmaMultistage :
730730
}
731731

732732
--gemm_k_iterations;
733-
iterator_A.clear_mask(gemm_k_iterations == 0);
734-
iterator_B0.clear_mask(gemm_k_iterations == 0);
735-
iterator_B1.clear_mask(gemm_k_iterations == 0);
733+
iterator_A.clear_mask(gemm_k_iterations <= 0);
734+
iterator_B0.clear_mask(gemm_k_iterations <= 0);
735+
iterator_B1.clear_mask(gemm_k_iterations <= 0);
736736
}
737737

738738
// Do any conversions feeding the first stage at the end of the loop so

include/cutlass/gemm/threadblock/ell_mma_multistage.h

+6-6
Original file line number | Diff line number | Diff line change
@@ -332,8 +332,8 @@ class EllMmaMultistage :
332332
for (int stage = 0; stage < Base::kStages - 1;
333333
++stage, --gemm_k_iterations) {
334334

335-
iterator_A.clear_mask(gemm_k_iterations == 0);
336-
iterator_B.clear_mask(gemm_k_iterations == 0);
335+
iterator_A.clear_mask(gemm_k_iterations <= 0);
336+
iterator_B.clear_mask(gemm_k_iterations <= 0);
337337

338338
iterator_A.set_iteration_index(0);
339339
this->smem_iterator_A_.set_iteration_index(0);
@@ -456,8 +456,8 @@ class EllMmaMultistage :
456456
++this->warp_tile_iterator_A_;
457457
++this->warp_tile_iterator_B_;
458458

459-
iterator_A.clear_mask(gemm_k_iterations == 0);
460-
iterator_B.clear_mask(gemm_k_iterations == 0);
459+
iterator_A.clear_mask(gemm_k_iterations <= 0);
460+
iterator_B.clear_mask(gemm_k_iterations <= 0);
461461

462462
if (is_A_sparse){
463463
iterator_A.ell_add_mask(ell_iterator.get_blocksize());
@@ -608,8 +608,8 @@ class EllMmaMultistage :
608608
}
609609

610610
--gemm_k_iterations;
611-
iterator_A.clear_mask(gemm_k_iterations == 0);
612-
iterator_B.clear_mask(gemm_k_iterations == 0);
611+
iterator_A.clear_mask(gemm_k_iterations <= 0);
612+
iterator_B.clear_mask(gemm_k_iterations <= 0);
613613
}
614614

615615
// Do any conversions feeding the first stage at the end of the loop so

include/cutlass/gemm/threadblock/mma_blas3_multistage.h

+6-6
Original file line number | Diff line number | Diff line change
@@ -339,8 +339,8 @@ class MmaBlas3Multistage :
339339
for (int stage = 0; stage < Base::kStages - 1;
340340
++stage, --gemm_k_iterations) {
341341

342-
iterator_A.clear_mask(gemm_k_iterations == 0);
343-
iterator_B.clear_mask(gemm_k_iterations == 0);
342+
iterator_A.clear_mask(gemm_k_iterations <= 0);
343+
iterator_B.clear_mask(gemm_k_iterations <= 0);
344344

345345
iterator_A.set_iteration_index(0);
346346
this->smem_iterator_A_.set_iteration_index(0);
@@ -519,8 +519,8 @@ class MmaBlas3Multistage :
519519
++this->warp_tile_iterator_A_;
520520
++this->warp_tile_iterator_B_;
521521

522-
iterator_A.clear_mask(gemm_k_iterations == 0);
523-
iterator_B.clear_mask(gemm_k_iterations == 0);
522+
iterator_A.clear_mask(gemm_k_iterations <= 0);
523+
iterator_B.clear_mask(gemm_k_iterations <= 0);
524524

525525
int smem_write_stage_idx = Base::kStages - 1;
526526
int smem_read_stage_idx = 0;
@@ -661,8 +661,8 @@ class MmaBlas3Multistage :
661661
}
662662

663663
--gemm_k_iterations;
664-
iterator_A.clear_mask(gemm_k_iterations == 0);
665-
iterator_B.clear_mask(gemm_k_iterations == 0);
664+
iterator_A.clear_mask(gemm_k_iterations <= 0);
665+
iterator_B.clear_mask(gemm_k_iterations <= 0);
666666
}
667667

668668
// Do any conversions feeding the first stage at the end of the loop so

include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h

+9-9
Original file line number | Diff line number | Diff line change
@@ -572,9 +572,9 @@ class MmaLayernormMainloopFusionMultistage :
572572
for (int stage = 0; stage < Base::kStages - 1;
573573
++stage, --gemm_k_iterations) {
574574

575-
iterator_A.clear_mask(gemm_k_iterations == 0);
576-
iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
577-
iterator_B.clear_mask(gemm_k_iterations == 0);
575+
iterator_A.clear_mask(gemm_k_iterations <= 0);
576+
iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
577+
iterator_B.clear_mask(gemm_k_iterations <= 0);
578578

579579
iterator_A.set_iteration_index(0);
580580
this->smem_iterator_A_.set_iteration_index(0);
@@ -692,9 +692,9 @@ class MmaLayernormMainloopFusionMultistage :
692692
++this->warp_tile_iterator_A_gamma_beta_;
693693
++this->warp_tile_iterator_B_;
694694

695-
iterator_A.clear_mask(gemm_k_iterations == 0);
696-
iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
697-
iterator_B.clear_mask(gemm_k_iterations == 0);
695+
iterator_A.clear_mask(gemm_k_iterations <= 0);
696+
iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
697+
iterator_B.clear_mask(gemm_k_iterations <= 0);
698698

699699
int smem_write_stage_idx = Base::kStages - 1;
700700
int smem_read_stage_idx = 0;
@@ -824,9 +824,9 @@ class MmaLayernormMainloopFusionMultistage :
824824
}
825825

826826
--gemm_k_iterations;
827-
iterator_A.clear_mask(gemm_k_iterations == 0);
828-
iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
829-
iterator_B.clear_mask(gemm_k_iterations == 0);
827+
iterator_A.clear_mask(gemm_k_iterations <= 0);
828+
iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
829+
iterator_B.clear_mask(gemm_k_iterations <= 0);
830830
}
831831

832832
// Do any conversions feeding the first stage at the end of the loop so

include/cutlass/gemm/threadblock/mma_multistage.h

+6-6
Original file line number | Diff line number | Diff line change
@@ -370,8 +370,8 @@ class MmaMultistage :
370370
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
371371

372372
// Disable global fetching if done with global fetch iterations
373-
iterator_A.clear_mask(gemm_k_iterations == 0);
374-
iterator_B.clear_mask(gemm_k_iterations == 0);
373+
iterator_A.clear_mask(gemm_k_iterations <= 0);
374+
iterator_B.clear_mask(gemm_k_iterations <= 0);
375375

376376
iterator_A.set_iteration_index(0);
377377
this->smem_iterator_A_.set_iteration_index(0);
@@ -588,8 +588,8 @@ class MmaMultistage :
588588

589589
// Disable global fetching when done with global fetch iterations
590590
--gemm_k_iterations;
591-
iterator_A.clear_mask(gemm_k_iterations == 0);
592-
iterator_B.clear_mask(gemm_k_iterations == 0);
591+
iterator_A.clear_mask(gemm_k_iterations <= 0);
592+
iterator_B.clear_mask(gemm_k_iterations <= 0);
593593
}
594594

595595
// The last warp-tile also converts the shared memory fragments used by
@@ -620,8 +620,8 @@ class MmaMultistage :
620620
PipeState pipe_state;
621621

622622
// Disable global fetching if done with global fetch iterations
623-
iterator_A.clear_mask(gemm_k_iterations == 0);
624-
iterator_B.clear_mask(gemm_k_iterations == 0);
623+
iterator_A.clear_mask(gemm_k_iterations <= 0);
624+
iterator_B.clear_mask(gemm_k_iterations <= 0);
625625

626626
// Load first warp-tile's A fragment from shared memory
627627
this->warp_tile_iterator_A_.set_kgroup_index(0);

include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h

+12-12
Original file line number | Diff line number | Diff line change
@@ -374,10 +374,10 @@ class MmaPlanarComplexMultistage :
374374
for (int stage = 0; stage < Base::kStages - 1;
375375
++stage, --gemm_k_iterations) {
376376

377-
iterator_A_real.clear_mask(gemm_k_iterations == 0);
378-
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
379-
iterator_B_real.clear_mask(gemm_k_iterations == 0);
380-
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
377+
iterator_A_real.clear_mask(gemm_k_iterations <= 0);
378+
iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
379+
iterator_B_real.clear_mask(gemm_k_iterations <= 0);
380+
iterator_B_imag.clear_mask(gemm_k_iterations <= 0);
381381

382382
iterator_A_real.set_iteration_index(0);
383383
iterator_A_imag.set_iteration_index(0);
@@ -503,10 +503,10 @@ class MmaPlanarComplexMultistage :
503503
++this->warp_tile_iterator_A_;
504504
++this->warp_tile_iterator_B_;
505505

506-
iterator_A_real.clear_mask(gemm_k_iterations == 0);
507-
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
508-
iterator_B_real.clear_mask(gemm_k_iterations == 0);
509-
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
506+
iterator_A_real.clear_mask(gemm_k_iterations <= 0);
507+
iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
508+
iterator_B_real.clear_mask(gemm_k_iterations <= 0);
509+
iterator_B_imag.clear_mask(gemm_k_iterations <= 0);
510510

511511
// Start issuing the first group of the next stage outside of the mainloop
512512
copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag);
@@ -611,10 +611,10 @@ class MmaPlanarComplexMultistage :
611611
}
612612

613613
--gemm_k_iterations;
614-
iterator_A_real.clear_mask(gemm_k_iterations == 0);
615-
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
616-
iterator_B_real.clear_mask(gemm_k_iterations == 0);
617-
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
614+
iterator_A_real.clear_mask(gemm_k_iterations <= 0);
615+
iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
616+
iterator_B_real.clear_mask(gemm_k_iterations <= 0);
617+
iterator_B_imag.clear_mask(gemm_k_iterations <= 0);
618618
}
619619

620620
warp_mma_planar_complex(

include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h

+6-6
Original file line number | Diff line number | Diff line change
@@ -486,8 +486,8 @@ class MmaSoftmaxMainloopFusionMultistage :
486486
for (int stage = 0; stage < Base::kStages - 1;
487487
++stage, --gemm_k_iterations) {
488488

489-
iterator_A.clear_mask(gemm_k_iterations == 0);
490-
iterator_B.clear_mask(gemm_k_iterations == 0);
489+
iterator_A.clear_mask(gemm_k_iterations <= 0);
490+
iterator_B.clear_mask(gemm_k_iterations <= 0);
491491

492492
iterator_A.set_iteration_index(0);
493493
this->smem_iterator_A_.set_iteration_index(0);
@@ -581,8 +581,8 @@ class MmaSoftmaxMainloopFusionMultistage :
581581
++this->warp_tile_iterator_A_;
582582
++this->warp_tile_iterator_B_;
583583

584-
iterator_A.clear_mask(gemm_k_iterations == 0);
585-
iterator_B.clear_mask(gemm_k_iterations == 0);
584+
iterator_A.clear_mask(gemm_k_iterations <= 0);
585+
iterator_B.clear_mask(gemm_k_iterations <= 0);
586586

587587
// Start issuing the first group of the next stage outside of the mainloop
588588
copy_tiles_and_advance(iterator_A, iterator_B);
@@ -708,8 +708,8 @@ class MmaSoftmaxMainloopFusionMultistage :
708708
}
709709

710710
--gemm_k_iterations;
711-
iterator_A.clear_mask(gemm_k_iterations == 0);
712-
iterator_B.clear_mask(gemm_k_iterations == 0);
711+
iterator_A.clear_mask(gemm_k_iterations <= 0);
712+
iterator_B.clear_mask(gemm_k_iterations <= 0);
713713
}
714714

715715
// Do any conversions feeding the first stage at the end of the loop so

include/cutlass/gemm/threadblock/mma_sparse_multistage.h

+9-9
Original file line number | Diff line number | Diff line change
@@ -381,9 +381,9 @@ class SparseMmaMultistage :
381381
for (int stage = 0; stage < Base::kStages - 1;
382382
++stage, --gemm_k_iterations) {
383383

384-
iterator_A.clear_mask(gemm_k_iterations == 0);
385-
iterator_B.clear_mask(gemm_k_iterations == 0);
386-
iterator_E.clear_mask(gemm_k_iterations == 0);
384+
iterator_A.clear_mask(gemm_k_iterations <= 0);
385+
iterator_B.clear_mask(gemm_k_iterations <= 0);
386+
iterator_E.clear_mask(gemm_k_iterations <= 0);
387387

388388
iterator_A.set_iteration_index(0);
389389
this->smem_iterator_A_.set_iteration_index(0);
@@ -499,9 +499,9 @@ class SparseMmaMultistage :
499499
++this->warp_tile_iterator_B_;
500500
++this->warp_tile_iterator_E_;
501501

502-
iterator_A.clear_mask(gemm_k_iterations == 0);
503-
iterator_B.clear_mask(gemm_k_iterations == 0);
504-
iterator_E.clear_mask(gemm_k_iterations == 0);
502+
iterator_A.clear_mask(gemm_k_iterations <= 0);
503+
iterator_B.clear_mask(gemm_k_iterations <= 0);
504+
iterator_E.clear_mask(gemm_k_iterations <= 0);
505505

506506
int smem_write_stage_idx = Base::kStages - 1;
507507
int smem_read_stage_idx = 0;
@@ -634,9 +634,9 @@ class SparseMmaMultistage :
634634
}
635635

636636
--gemm_k_iterations;
637-
iterator_A.clear_mask(gemm_k_iterations == 0);
638-
iterator_B.clear_mask(gemm_k_iterations == 0);
639-
iterator_E.clear_mask(gemm_k_iterations == 0);
637+
iterator_A.clear_mask(gemm_k_iterations <= 0);
638+
iterator_B.clear_mask(gemm_k_iterations <= 0);
639+
iterator_E.clear_mask(gemm_k_iterations <= 0);
640640
}
641641

642642
// Do any conversions feeding the first stage at the end of the loop so

include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h

+6-6
Original file line number | Diff line number | Diff line change
@@ -310,8 +310,8 @@ class MmaWithReductionMultistage :
310310
for (int stage = 0; stage < Base::kStages - 1;
311311
++stage, --gemm_k_iterations) {
312312

313-
iterator_A.clear_mask(gemm_k_iterations == 0);
314-
iterator_B.clear_mask(gemm_k_iterations == 0);
313+
iterator_A.clear_mask(gemm_k_iterations <= 0);
314+
iterator_B.clear_mask(gemm_k_iterations <= 0);
315315

316316
iterator_A.set_iteration_index(0);
317317
this->smem_iterator_A_.set_iteration_index(0);
@@ -403,8 +403,8 @@ class MmaWithReductionMultistage :
403403
++this->warp_tile_iterator_A_;
404404
++this->warp_tile_iterator_B_;
405405

406-
iterator_A.clear_mask(gemm_k_iterations == 0);
407-
iterator_B.clear_mask(gemm_k_iterations == 0);
406+
iterator_A.clear_mask(gemm_k_iterations <= 0);
407+
iterator_B.clear_mask(gemm_k_iterations <= 0);
408408

409409
int smem_write_stage_idx = Base::kStages - 1;
410410
int smem_read_stage_idx = 0;
@@ -513,8 +513,8 @@ class MmaWithReductionMultistage :
513513
}
514514

515515
--gemm_k_iterations;
516-
iterator_A.clear_mask(gemm_k_iterations == 0);
517-
iterator_B.clear_mask(gemm_k_iterations == 0);
516+
iterator_A.clear_mask(gemm_k_iterations <= 0);
517+
iterator_B.clear_mask(gemm_k_iterations <= 0);
518518
}
519519

520520
// Do any conversions feeding the first stage at the end of the loop so

0 commit comments

Comments
 (0)