@@ -374,10 +374,10 @@ class MmaPlanarComplexMultistage :
374
374
for (int stage = 0 ; stage < Base::kStages - 1 ;
375
375
++stage, --gemm_k_iterations) {
376
376
377
- iterator_A_real.clear_mask (gemm_k_iterations = = 0 );
378
- iterator_A_imag.clear_mask (gemm_k_iterations = = 0 );
379
- iterator_B_real.clear_mask (gemm_k_iterations = = 0 );
380
- iterator_B_imag.clear_mask (gemm_k_iterations = = 0 );
377
+ iterator_A_real.clear_mask (gemm_k_iterations < = 0 );
378
+ iterator_A_imag.clear_mask (gemm_k_iterations < = 0 );
379
+ iterator_B_real.clear_mask (gemm_k_iterations < = 0 );
380
+ iterator_B_imag.clear_mask (gemm_k_iterations < = 0 );
381
381
382
382
iterator_A_real.set_iteration_index (0 );
383
383
iterator_A_imag.set_iteration_index (0 );
@@ -503,10 +503,10 @@ class MmaPlanarComplexMultistage :
503
503
++this ->warp_tile_iterator_A_ ;
504
504
++this ->warp_tile_iterator_B_ ;
505
505
506
- iterator_A_real.clear_mask (gemm_k_iterations = = 0 );
507
- iterator_A_imag.clear_mask (gemm_k_iterations = = 0 );
508
- iterator_B_real.clear_mask (gemm_k_iterations = = 0 );
509
- iterator_B_imag.clear_mask (gemm_k_iterations = = 0 );
506
+ iterator_A_real.clear_mask (gemm_k_iterations < = 0 );
507
+ iterator_A_imag.clear_mask (gemm_k_iterations < = 0 );
508
+ iterator_B_real.clear_mask (gemm_k_iterations < = 0 );
509
+ iterator_B_imag.clear_mask (gemm_k_iterations < = 0 );
510
510
511
511
// Start issuing the first group of the next stage outside of the mainloop
512
512
copy_tiles_and_advance (iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag);
@@ -611,10 +611,10 @@ class MmaPlanarComplexMultistage :
611
611
}
612
612
613
613
--gemm_k_iterations;
614
- iterator_A_real.clear_mask (gemm_k_iterations = = 0 );
615
- iterator_A_imag.clear_mask (gemm_k_iterations = = 0 );
616
- iterator_B_real.clear_mask (gemm_k_iterations = = 0 );
617
- iterator_B_imag.clear_mask (gemm_k_iterations = = 0 );
614
+ iterator_A_real.clear_mask (gemm_k_iterations < = 0 );
615
+ iterator_A_imag.clear_mask (gemm_k_iterations < = 0 );
616
+ iterator_B_real.clear_mask (gemm_k_iterations < = 0 );
617
+ iterator_B_imag.clear_mask (gemm_k_iterations < = 0 );
618
618
}
619
619
620
620
warp_mma_planar_complex (
0 commit comments