My question is just as in the title. To elaborate: in the middle of a pipelined TMA/MMA mainloop, is it safe to break out of the loop early once a stopping condition is met? As an illustrative example, I attach a relevant code block below. I apologize that the code block is long, but it is necessary to provide the full context. The early-stopping condition is checked and acted on around these lines:

```cpp
still_going = true;
compute(stage_idx, i, &still_going); // GEMM, also sets the value of still_going
if (!still_going) {
  break;
}
```

The full code block (long):

```cpp
using PipelineTmaAsync = cutlass::PipelineTmaAsync<NUM_STAGES>;
using PipelineState = cutlass::PipelineState<NUM_STAGES>;
using BarrierType = typename PipelineTmaAsync::ProducerBarrierType;
static constexpr auto num_consumers = cute::thr_size(TiledMma{});
auto pipeline_params = typename PipelineTmaAsync::Params{};
pipeline_params.transaction_bytes = tma_size_bytes;
pipeline_params.role = PipelineTmaAsync::ThreadCategory::ProducerConsumer;
pipeline_params.is_leader = threadIdx.x == 0;
pipeline_params.num_consumers = num_consumers;
auto pipeline = PipelineTmaAsync{shared_storage.pipeline, pipeline_params, ClusterShape{}};
auto smem_pipe_read = PipelineState{};
auto smem_pipe_write = cutlass::make_producer_start_state<PipelineTmaAsync>();
const auto num_blocks_tma_prologue = cute::min(num_blocks, NUM_STAGES);
const auto num_blocks_mma_prologue = cute::min(1, num_blocks_tma_prologue);
const auto num_blocks_mma_mainloop = num_blocks - num_blocks_mma_prologue;
/*********************************************************************
 * `still_going` is cleared when an early-stopping condition is met. *
 *********************************************************************/
bool still_going = false;
int block_idx = 0;
// TMA Prologue
CUTE_UNROLL
for (int i = 0; i < num_blocks_tma_prologue; ++i) {
pipeline.producer_acquire(smem_pipe_write);
auto stage_idx = smem_pipe_write.index();
auto tma_mbar = pipeline.producer_get_barrier(smem_pipe_write);
fetch_data(tma_mbar, i, stage_idx); // involves a TMA load
pipeline.producer_commit(smem_pipe_write, tma_size_bytes);
++smem_pipe_write;
}
block_idx += num_blocks_tma_prologue;
// MMA Prologue
CUTE_NO_UNROLL
for (int i = 0; i < num_blocks_mma_prologue; ++i) {
pipeline.consumer_wait(smem_pipe_read);
auto stage_idx = smem_pipe_read.index();
still_going = true;
compute(stage_idx, i, &still_going); // GEMM, also sets the value of still_going
if (!still_going) {
break;
}
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
// Main loop: MMA and TMA.
CUTE_NO_UNROLL
for (int i = 0; i < num_blocks_mma_mainloop; ++i) {
pipeline.consumer_wait(smem_pipe_read);
auto stage_idx = smem_pipe_read.index();
still_going = true;
compute(stage_idx, i, &still_going); // GEMM, also sets the value of still_going
if (!still_going) {
break;
}
// Issue the next TMA load into the next write stage, if any blocks remain.
if (block_idx < num_blocks) {
pipeline.producer_acquire(smem_pipe_write);
auto stage_idx = smem_pipe_write.index();
auto tma_mbar = pipeline.producer_get_barrier(smem_pipe_write);
fetch_data(tma_mbar, block_idx, stage_idx); // involves a TMA load
pipeline.producer_commit(smem_pipe_write, tma_size_bytes);
++smem_pipe_write;
++block_idx;
}
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
// Wait on all GMMAs
cute::warpgroup_wait<0>();
cute::warpgroup_fence_operand(rO);
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_sync();
} else {
__syncthreads();
}
```

Truly appreciate your help!
Replies:
Yes, this is fine in general. Ideally we would model this as some kind of while loop around an updating k-tile counter, etc. You just have to be really careful to make sure the pipeline states for the producers and consumers agree if you terminate early, in case this is a persistent kernel or you are fusing with another collective later in the lifetime of the kernel.
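To make that concrete, here is a minimal sketch (reusing the names from the code above; `drain` is a hypothetical helper, not a CUTLASS API) of one way to reconcile the two pipeline states on early exit: release the current stage before breaking, then drain every stage the producer committed but the consumer never consumed, so `smem_pipe_read` catches up with `smem_pipe_write` before the next tile or collective.

```cpp
// Hypothetical drain helper: release stages that were produced but never
// consumed, so the consumer state catches up with the producer state.
auto drain = [&](PipelineState& read, const PipelineState& write) {
  while (read.count() < write.count()) {
    pipeline.consumer_wait(read);     // stage was committed by the producer
    pipeline.consumer_release(read);  // hand it back without computing
    ++read;
  }
};

// Main loop with an early exit that keeps the states consistent.
CUTE_NO_UNROLL
for (int i = 0; i < num_blocks_mma_mainloop; ++i) {
  pipeline.consumer_wait(smem_pipe_read);
  still_going = true;
  compute(smem_pipe_read.index(), i, &still_going);
  pipeline.consumer_release(smem_pipe_read);  // release *before* breaking
  ++smem_pipe_read;
  if (!still_going) {
    break;
  }
  // ... producer side (producer_acquire / fetch_data / producer_commit) as before ...
}
drain(smem_pipe_read, smem_pipe_write);  // matters for persistent kernels / fused collectives
```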
[…] `still_going` for each other. Is that happening inside `compute()`?
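If it is not, one option is to reduce the per-thread predicate inside `compute()` so every thread leaves with the same value and takes the same branch at `if (!still_going) break;`. A minimal sketch, assuming `compute()` is called by all threads in the block (the GEMM body and the per-thread stopping test are placeholders):

```cpp
// Sketch only: make still_going uniform across the block so all threads
// agree on whether to break. __syncthreads_or returns a nonzero value to
// every thread iff the predicate is nonzero for any thread in the block.
__device__ void compute(int stage_idx, int i, bool* still_going) {
  // ... GEMM on the smem stage indexed by stage_idx ...
  bool local_keep_going = true;  // placeholder per-thread stopping test
  *still_going = (__syncthreads_or(local_keep_going ? 1 : 0) != 0);
}
```

Whether you want "any thread wants to continue" (`__syncthreads_or`) or "all threads want to continue" (`__syncthreads_and`) semantics depends on the stopping condition.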