Skip to content

Commit

Permalink
Adding support for scheduling the epilogue computation when smem_epilogue parameter is true (#3581)
Browse files Browse the repository at this point in the history

Refactoring some code and adding some support for smem_epilogue
TODO: add support for smem_epilogue when the output of the mma op is not
cast down to half precision.
  • Loading branch information
protonu authored Dec 18, 2024
1 parent 7cf313a commit c31b919
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
29 changes: 19 additions & 10 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,27 +520,36 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
d_smem->setMemoryType(MemoryType::Shared);

// Set LoadStoreOp
// TODO: extend support when mma is not cast to half
NVF_ERROR(
dc->dtype() == DataType::Half,
"We support smem_epilogue on hopper only when the output of mma is cast to half");

d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
d->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);

// Block Schedule and Parallelize
// Apply the common transforms to dc, d_smem, d
// After these transforms we schedule the inner two non-reduction loops
// (instruction tile) of dc and propagate it back to the outputs of mma.
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);

// Apply mma common transformation
for (auto tv : tvs_to_schedule) {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
}

// Schedule register cache; Output from epilogue
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);
}
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);

Expand Down
4 changes: 0 additions & 4 deletions tests/cpp/test_matmul_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3379,10 +3379,6 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
// TODO: Remove this test once the architecture agnostic can be
// run on hopper.
TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) {
if (use_smem_epilogue) {
GTEST_SKIP()
<< "TODO: We don't support smem epilogue in the Hopper matmul scheduler right now";
}
const auto& [A, B] =
matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
const auto& C = matmulAtInput2D(
Expand Down

0 comments on commit c31b919

Please sign in to comment.