Skip to content

Commit

Permalink
Fix illegal memory accesses in multistage Mma classes when k = 0
Browse files (browse the repository at this point in the history)
  • Loading branch information
dfyz authored and Ivan Komarov committed Jun 18, 2024
1 parent 637b159 commit f3dbff5
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,8 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kNumStagesConcurrentLoad;
++stage, --gemm_k_iterations) {
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -559,8 +559,8 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -725,8 +725,8 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
18 changes: 9 additions & 9 deletions examples/45_dual_gemm/threadblock/dual_mma_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,9 @@ class DualMmaMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B0.clear_mask(gemm_k_iterations == 0);
- iterator_B1.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B0.clear_mask(gemm_k_iterations <= 0);
+ iterator_B1.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -555,9 +555,9 @@ class DualMmaMultistage :
++this->warp_tile_iterator_B0_;
++this->warp_tile_iterator_B1_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B0.clear_mask(gemm_k_iterations == 0);
- iterator_B1.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B0.clear_mask(gemm_k_iterations <= 0);
+ iterator_B1.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -730,9 +730,9 @@ class DualMmaMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B0.clear_mask(gemm_k_iterations == 0);
- iterator_B1.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B0.clear_mask(gemm_k_iterations <= 0);
+ iterator_B1.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
12 changes: 6 additions & 6 deletions include/cutlass/gemm/threadblock/ell_mma_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ class EllMmaMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -456,8 +456,8 @@ class EllMmaMultistage :
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

if (is_A_sparse){
iterator_A.ell_add_mask(ell_iterator.get_blocksize());
Expand Down Expand Up @@ -608,8 +608,8 @@ class EllMmaMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
12 changes: 6 additions & 6 deletions include/cutlass/gemm/threadblock/mma_blas3_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ class MmaBlas3Multistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -519,8 +519,8 @@ class MmaBlas3Multistage :
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -661,8 +661,8 @@ class MmaBlas3Multistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,9 @@ class MmaLayernormMainloopFusionMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -692,9 +692,9 @@ class MmaLayernormMainloopFusionMultistage :
++this->warp_tile_iterator_A_gamma_beta_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -824,9 +824,9 @@ class MmaLayernormMainloopFusionMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
12 changes: 6 additions & 6 deletions include/cutlass/gemm/threadblock/mma_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,8 @@ class MmaMultistage :
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {

// Disable global fetching if done with global fetch iterations
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -588,8 +588,8 @@ class MmaMultistage :

// Disable global fetching when done with global fetch iterations
--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// The last warp-tile also converts the shared memory fragments used by
Expand Down Expand Up @@ -620,8 +620,8 @@ class MmaMultistage :
PipeState pipe_state;

// Disable global fetching if done with global fetch iterations
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

// Load first warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index(0);
Expand Down
24 changes: 12 additions & 12 deletions include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,10 @@ class MmaPlanarComplexMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A_real.clear_mask(gemm_k_iterations == 0);
- iterator_A_imag.clear_mask(gemm_k_iterations == 0);
- iterator_B_real.clear_mask(gemm_k_iterations == 0);
- iterator_B_imag.clear_mask(gemm_k_iterations == 0);
+ iterator_A_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_imag.clear_mask(gemm_k_iterations <= 0);

iterator_A_real.set_iteration_index(0);
iterator_A_imag.set_iteration_index(0);
Expand Down Expand Up @@ -503,10 +503,10 @@ class MmaPlanarComplexMultistage :
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A_real.clear_mask(gemm_k_iterations == 0);
- iterator_A_imag.clear_mask(gemm_k_iterations == 0);
- iterator_B_real.clear_mask(gemm_k_iterations == 0);
- iterator_B_imag.clear_mask(gemm_k_iterations == 0);
+ iterator_A_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_imag.clear_mask(gemm_k_iterations <= 0);

// Start issuing the first group of the next stage outside of the mainloop
copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag);
Expand Down Expand Up @@ -611,10 +611,10 @@ class MmaPlanarComplexMultistage :
}

--gemm_k_iterations;
- iterator_A_real.clear_mask(gemm_k_iterations == 0);
- iterator_A_imag.clear_mask(gemm_k_iterations == 0);
- iterator_B_real.clear_mask(gemm_k_iterations == 0);
- iterator_B_imag.clear_mask(gemm_k_iterations == 0);
+ iterator_A_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_imag.clear_mask(gemm_k_iterations <= 0);
}

warp_mma_planar_complex(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,8 @@ class MmaSoftmaxMainloopFusionMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -581,8 +581,8 @@ class MmaSoftmaxMainloopFusionMultistage :
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

// Start issuing the first group of the next stage outside of the mainloop
copy_tiles_and_advance(iterator_A, iterator_B);
Expand Down Expand Up @@ -708,8 +708,8 @@ class MmaSoftmaxMainloopFusionMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
18 changes: 9 additions & 9 deletions include/cutlass/gemm/threadblock/mma_sparse_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,9 @@ class SparseMmaMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
- iterator_E.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
+ iterator_E.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -499,9 +499,9 @@ class SparseMmaMultistage :
++this->warp_tile_iterator_B_;
++this->warp_tile_iterator_E_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
- iterator_E.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
+ iterator_E.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -634,9 +634,9 @@ class SparseMmaMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
- iterator_E.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
+ iterator_E.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
12 changes: 6 additions & 6 deletions include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,8 @@ class MmaWithReductionMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -403,8 +403,8 @@ class MmaWithReductionMultistage :
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -513,8 +513,8 @@ class MmaWithReductionMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
Loading

0 comments on commit f3dbff5

Please sign in to comment.