diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index 0060e626fe6..e6b927f351c 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -274,6 +274,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
   // Generates the kernel function declaration
   void genDeclaration(const std::string& kernel_name) {
     code_ << "__global__ void ";
+    // TODO Fix hardcoded values
+    code_ << "__launch_bounds__(384, 1) ";
     if (kernel_->hasManaged("cluster_dims")) {
       auto cluster_dims = kernel_->getManaged<std::tuple<int64_t, int64_t, int64_t>>(
@@ -3325,10 +3327,16 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
     // Use a custom synchronization method if enabled
     if (getNvFuserEnv("USE_BLOCK_SYNC_ATOMIC")) {
       indent() << "block_sync::sync();\n";
-    } else if (isAligned()) {
-      indent() << "__syncthreads();\n";
     } else {
-      indent() << "__barrier_sync(0);\n";
+      ArgumentBuilder sync_call_template_parms;
+      sync_call_template_parms.arg(isAligned());
+
+      ArgumentBuilder sync_call_args;
+      sync_call_args.arg(genComputeBlockDim());
+
+      auto sync_call =
+          genCall("block_sync::sync", sync_call_template_parms, sync_call_args);
+      indent() << sync_call << ";\n";
     }
   }
@@ -3505,6 +3513,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
     indent() << "NVFUSER_UPDATE_MAGIC_ZERO;\n";
   }
 
+  void handle(const kir::Return* ret) final {
+    indent() << "return;\n";
+  }
+
  private:
   std::stringstream code_;
   const kir::Kernel* kernel_;
diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp
index bee61c46873..1242ddc794b 100644
--- a/csrc/device_lower/pass/allocation.cpp
+++ b/csrc/device_lower/pass/allocation.cpp
@@ -601,7 +601,7 @@ class AllocationInserter : public kir::ExprMutator {
       // generic-async proxy fence and wgmma fence before each mma
       // instruction. For this case, we need to insert these fences
      // after the initialization of the accumulator, so that the
-      // inilization is visible to the async proxy.
+      // initialization is visible to the async proxy.
       // When all inputs are guarded by mbarrier, we will insert these
       // fences before each mma instruction, so there is no need to
       // insert them after the initialization of the accumulator here.
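For readers following along, here is a compilable sketch of what these two codegen changes produce. The `block_sync::sync` body and the value `genComputeBlockDim()` yields are assumptions for illustration (nvFuser's actual helper lives in its runtime headers); the hardcoded `__launch_bounds__(384, 1)` comes straight from the TODO above.

```cpp
// Minimal sketch, not nvFuser's runtime: a stand-in for the templated
// block_sync::sync<IsAligned>(block_dim) helper the generator now calls in
// place of the old __syncthreads()/__barrier_sync(0) split.
#include <cuda_runtime.h>

namespace block_sync {
template <bool Aligned>
__device__ __forceinline__ void sync(dim3 group_dim) {
  if constexpr (Aligned) {
    __syncthreads(); // every thread of the CTA reaches this barrier
  } else {
    // Subset barrier: only the threads of `group_dim` participate, e.g. the
    // compute warps after the load warpgroup has returned.
    unsigned n = group_dim.x * group_dim.y * group_dim.z;
    asm volatile("bar.sync 0, %0;" ::"r"(n));
  }
}
} // namespace block_sync

// Declaration as now generated: launch bounds hardcoded to 384 threads with
// one CTA per SM (the TODO above), i.e. 256 compute threads plus one
// 128-thread load warpgroup in the Hopper matmul configuration further down.
__global__ void __launch_bounds__(384, 1) sketch_kernel(float* out) {
  block_sync::sync<true>(blockDim);
  out[threadIdx.x] = 0.f;
}
```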
diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp
index 15ed808b936..3b0f5dd7e34 100644
--- a/csrc/device_lower/pass/circular_buffer.cpp
+++ b/csrc/device_lower/pass/circular_buffer.cpp
@@ -1394,11 +1394,24 @@ class CircularBufferInserter : private kir::ExprMutator {
             warp_specialize_on),
         circular_buffer_loop->fusion()->oneVal()))));
 
+    kir::MaxNReg* dec_reg_load_warp = IrBuilder::create<kir::MaxNReg>(
+        IrBuilder::create<Val>(24, DataType::Index),
+        /*increase_registers=*/false);
+    warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp);
+
     // Load loop:
     ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
         circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
     warp_dispatch_ite->thenBody().push_back(load_loop);
 
+    kir::Return* ret = IrBuilder::create<kir::Return>();
+    warp_dispatch_ite->thenBody().push_back(ret);
+
+    kir::MaxNReg* inc_reg_load_warp = IrBuilder::create<kir::MaxNReg>(
+        IrBuilder::create<Val>(240, DataType::Index),
+        /*increase_registers=*/true);
+    warp_dispatch_ite->elseBody().push_back(inc_reg_load_warp);
+
     // Prefetch:
     auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
     warp_dispatch_ite->elseBody().push_back(prefetch_loop);
diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp
index f4121021f3b..32b8ba3938b 100644
--- a/csrc/device_lower/pass/index.cpp
+++ b/csrc/device_lower/pass/index.cpp
@@ -2583,6 +2583,16 @@ void IndexLowering::handle(const kir::WgMmaFence* fence) {
   pushBack(const_cast<kir::WgMmaFence*>(fence)); // NOLINT
 }
 
+void IndexLowering::handle(const kir::MaxNReg* maxnreg) {
+  // TODO(kir): remove the need for const_cast
+  pushBack(const_cast<kir::MaxNReg*>(maxnreg)); // NOLINT
+}
+
+void IndexLowering::handle(const kir::Return* ret) {
+  // TODO(kir): remove the need for const_cast
+  pushBack(const_cast<kir::Return*>(ret)); // NOLINT
+}
+
 void IndexLowering::handle(const kir::AsyncCommit* commit) {
   // TODO(kir): remove the need for const_cast
   pushBack(const_cast<kir::AsyncCommit*>(commit)); // NOLINT
diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h
index 8d206159128..03e281dc868 100644
--- a/csrc/device_lower/pass/index.h
+++ b/csrc/device_lower/pass/index.h
@@ -75,6 +75,8 @@ class IndexLowering : private OptOutConstDispatch {
   void handle(const kir::GridSync*) final;
   void handle(const kir::FenceAsyncProxy*) final;
   void handle(const kir::WgMmaFence*) final;
+  void handle(const kir::MaxNReg*) final;
+  void handle(const kir::Return*) final;
   void handle(const kir::MBarrierInit*) final;
   void handle(const kir::MBarrierInvalidate*) final;
   void handle(const kir::MBarrierArrive*) final;
diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp
index 31afc58a775..3a99ecda938 100644
--- a/csrc/device_lower/pass/inline_ptx.cpp
+++ b/csrc/device_lower/pass/inline_ptx.cpp
@@ -272,6 +272,19 @@ class LowerToInlinePtx : public kir::ExprMutator {
             std::vector<Val*>{},
             kir::Asm::Options{/*volatile=*/true}));
   }
+
+  void handle(kir::MaxNReg* maxnreg) final {
+    std::string ptx = (maxnreg->increaseRegisters())
+        ? "setmaxnreg.inc.sync.aligned.u32"
+        : "setmaxnreg.dec.sync.aligned.u32";
+    registerReplace(
+        maxnreg,
+        IrBuilder::create<kir::Asm>(
+            ptx,
+            std::vector<Val*>{},
+            std::vector<Val*>{maxnreg->numberOfRegisters()},
+            kir::Asm::Options{/*volatile=*/true}));
+  }
 };
 
 std::vector<Expr*> lowerToInlinePtx(const std::vector<Expr*>& exprs) {
diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp
index 4e2f55323be..5ced19589e5 100644
--- a/csrc/device_lower/pass/insert_syncs.cpp
+++ b/csrc/device_lower/pass/insert_syncs.cpp
@@ -393,11 +393,11 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
     if (auto mma = dynamic_cast<MmaOp*>(expr)) {
       if (mma->isHopper()) {
         auto scope = scope_.empty() ? nullptr : scope_.back();
+        auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>();
+        registerInsertBefore(expr, wgmma_fence, scope);
         if (!lower_utils::allMmaInputsGuardedByMBarrier(mma)) {
           // Makes sure that writes to operands in the generic proxy are visible
           // to the async proxy
-          auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>();
-          registerInsertBefore(expr, wgmma_fence, scope);
           auto fence_async = IrBuilder::create<kir::FenceAsyncProxy>();
           registerInsertBefore(expr, fence_async, scope);
         }
@@ -782,7 +782,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
   }
 };
 
-// Insert wait expressions for WAR harzard for async operations such as wgmma
+// Insert wait expressions for WAR hazard for async operations such as wgmma
 // and tma store. To do so, we find the structure like the following example:
 //   for 1
 //     for 2
@@ -1007,6 +1007,21 @@ class WarAsyncWaitInserter : private kir::ExprMutator {
     // Process the expressions in the for loop
     kir::ExprMutator::handle(for_loop);
 
+    // NOTE Warp Specialization requires a WAR wgmma sync before launching the
+    // next tma load
+    if (for_loop->circularBufferLoopStage() ==
+        CircularBufferLoopStage::ComputeWarp) {
+      for (Expr* expr : for_loop->body().exprs()) {
+        if (expr->isA<MmaOp>()) {
+          auto sync_exprs = lower_utils::getSyncExprs(AsyncOpType::WgMma, 0);
+          while (!sync_exprs.empty()) {
+            registerInsertBefore(expr, sync_exprs.back(), &for_loop->body());
+            sync_exprs.pop_back();
+          }
+        }
+      }
+    }
+
     // Insert async wait at the end of this for loop
     if (within_iter_loop_) {
       std::unordered_map<AsyncOpType, int64_t> types_and_pending_ops_to_protect;
diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp
index 35b825d5348..4a56b814d72 100644
--- a/csrc/device_lower/utils.cpp
+++ b/csrc/device_lower/utils.cpp
@@ -1995,6 +1995,11 @@ std::vector<Expr*> getSyncExprs(AsyncOpType async_type, int64_t keep_stages) {
   sync_exprs.push_back(commit);
   auto wait = IrBuilder::create<kir::AsyncWait>(async_type, keep_stages);
   sync_exprs.push_back(wait);
+  // TODO Do not apply for warp specialization
+  if (async_type == AsyncOpType::WgMma) {
+    auto sync = IrBuilder::create<kir::BlockSync>(true);
+    sync_exprs.push_back(sync);
+  }
   return sync_exprs;
 }
diff --git a/csrc/dispatch.h b/csrc/dispatch.h
index 4fe0f86cc5f..7f7fce03e4e 100644
--- a/csrc/dispatch.h
+++ b/csrc/dispatch.h
@@ -120,6 +120,8 @@ class Val;
   f(GridSync);          \
   f(FenceAsyncProxy);   \
   f(WgMmaFence);        \
+  f(MaxNReg);           \
+  f(Return);            \
   f(MBarrierInit);      \
   f(MBarrierInvalidate); \
   f(MBarrierArrive);    \
diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp
index fc464eac315..b04297583ca 100644
--- a/csrc/kernel_ir.cpp
+++ b/csrc/kernel_ir.cpp
@@ -485,6 +485,47 @@ std::string WgMmaFence::toInlineString(int indent_size) const {
 
 NVFUSER_DEFINE_CLONE_AND_CREATE(WgMmaFence)
 
+MaxNReg::MaxNReg(
+    IrBuilderPasskey passkey,
+    Val* number_of_registers,
+    bool increase_registers)
+    : Expr(passkey) {
+  NVF_ERROR(passkey.ir_container_ != nullptr);
+  NVF_ERROR(
+      passkey.ir_container_->isA<kir::Kernel>(),
+      "IR type only valid for Kernel container.");
+  addInput(number_of_registers);
+  addDataAttribute(increase_registers);
+}
+
+std::string MaxNReg::toString(int indent_size) const {
+  return (increaseRegisters()) ? "setmaxnreg.inc.sync.aligned.u32"
+                               : "setmaxnreg.dec.sync.aligned.u32";
+}
+
+std::string MaxNReg::toInlineString(int indent_size) const {
+  NVF_CHECK(false, "MaxNReg can not be printed inline");
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(MaxNReg)
+
+Return::Return(IrBuilderPasskey passkey) : Expr(passkey) {
+  NVF_ERROR(passkey.ir_container_ != nullptr);
+  NVF_ERROR(
+      passkey.ir_container_->isA<kir::Kernel>(),
+      "IR type only valid for Kernel container.");
+}
+
+std::string Return::toString(int indent_size) const {
+  return "return";
+}
+
+std::string Return::toInlineString(int indent_size) const {
+  NVF_CHECK(false, "Return can not be printed inline");
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(Return)
+
 MBarrierInit::MBarrierInit(
     IrBuilderPasskey passkey,
     Val* mbarrier,
diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h
index 60421db1995..cf3e50ed8e6 100644
--- a/csrc/kernel_ir.h
+++ b/csrc/kernel_ir.h
@@ -41,6 +41,8 @@ class BlockSync;
 class GridSync;
 class FenceAsyncProxy;
 class WgMmaFence;
+class MaxNReg;
+class Return;
 class MBarrierInit;
 class MBarrierInvalidate;
 class MBarrierArrive;
@@ -469,6 +471,50 @@ class WgMmaFence final : public Expr {
   std::string toInlineString(int indent_size = 0) const override;
 };
 
+// PTX: setmaxnreg.inc.sync.aligned.u32 and setmaxnreg.dec.sync.aligned.u32
+class MaxNReg final : public Expr {
+ public:
+  using Expr::Expr;
+
+  explicit MaxNReg(
+      IrBuilderPasskey passkey,
+      Val* number_of_registers,
+      bool increase_registers);
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  const char* getOpString() const override {
+    return (increaseRegisters()) ? "IncMaxNReg" : "DecMaxNReg";
+  }
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+
+  bool increaseRegisters() const {
+    return attribute<bool>(0);
+  }
+
+  Val* numberOfRegisters() const {
+    return input(0);
+  }
+};
+
+class Return final : public Expr {
+ public:
+  using Expr::Expr;
+
+  explicit Return(IrBuilderPasskey passkey);
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  const char* getOpString() const override {
+    return "Return";
+  }
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+};
+
 class MBarrierInit final : public Expr {
  public:
   using Expr::Expr;
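A usage sketch for the two new kernel IR nodes, assuming it runs inside a lowering pass where the ambient IR container is already the `kir::Kernel` (otherwise the constructors trip their `NVF_ERROR` checks). The PTX instruction `MaxNReg` models requires sm_90a, takes an immediate operand that is a multiple of 8 in [24, 256], and must be executed uniformly by all warps of a warpgroup.

```cpp
// Hypothetical helper, not code from this PR.
void emitRegisterHints() {
  // Load warps shrink their budget: input(0) = 24, attribute(0) = false.
  kir::MaxNReg* dec = IrBuilder::create<kir::MaxNReg>(
      IrBuilder::create<Val>(24, DataType::Index),
      /*increase_registers=*/false);
  // Compute warps grow theirs: input(0) = 240, attribute(0) = true.
  kir::MaxNReg* inc = IrBuilder::create<kir::MaxNReg>(
      IrBuilder::create<Val>(240, DataType::Index),
      /*increase_registers=*/true);
  // Early-exit marker for the load warps once their loop is done.
  kir::Return* ret = IrBuilder::create<kir::Return>();

  NVF_ERROR(!dec->increaseRegisters() && inc->increaseRegisters());
  NVF_ERROR(dec->toString() == "setmaxnreg.dec.sync.aligned.u32");
  NVF_ERROR(ret->toString() == "return");
}
```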
diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp
index cedc7d262d5..ac9a476a508 100644
--- a/csrc/scheduler/hopper_multi_matmul.cpp
+++ b/csrc/scheduler/hopper_multi_matmul.cpp
@@ -29,24 +29,18 @@
 namespace nvfuser {
 
-void HopperMultipleMatmulScheduler::transformLikeMmaOutput(
-    TensorView* tv,
-    bool is_mma_result) {
+void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
   // TODO Add constraints
 
-  auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr {
-    return (is_mma_result) ? idx - 1 : idx;
-  };
-
   // Original: [..., Mo, No, Mi, Ni]
-  tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro));
-  tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro));
+  tv->split(-2, getM(params_->mma_macro));
+  tv->split(-1, getN(params_->mma_macro));
   // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
-  tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}});
+  tv->reorder({{-3, -2}});
   // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
-  tv->merge(apply_k_dim_offset(-4));
+  tv->merge(-4);
   // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-  tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy);
+  tv->axis(-3)->parallelize(ParallelType::TIDy);
   // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
 }
 
@@ -424,7 +418,20 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
       splitk_sums_.push_back(splitk_sum);
     }
 
-    transformLikeMmaOutput(mma_result, /*is_mma_result=*/true);
+    // Original: [..., Mo, No, Mi, Ni, Ki]
+    mma_result->split(-3, getM(params_->mma_macro));
+    mma_result->split(-2, getN(params_->mma_macro));
+    mma_result->split(-1, getK(params_->mma_macro));
+    // After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
+    mma_result->reorder({{-5, -4}});
+    // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
+    mma_result->reorder({{-2, -4}});
+    // After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
+    mma_result->merge(-6);
+    // After Merge: [..., Mo, No, Mio * Nio, Kio, Mii, Nii, Kii]
+    mma_result->axis(-5)->parallelize(ParallelType::TIDy);
+    // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Kio, Mii, Nii, Kii]
+
     auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
         mma_result->getLoopDomain());
     mma_result->setAllocationDomain(s.as<IterDomain*>(), true);
@@ -459,7 +466,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
       // op.
       blockTileTensors({d});
       parallelizeBlocks({d});
-      transformLikeMmaOutput(d, /*is_mma_result=*/false);
+      transformLikeMmaOutput(d);
 
       auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
           d->getLoopDomain());
@@ -536,7 +543,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
       blockTileTensors(tvs_to_schedule);
       parallelizeBlocks(tvs_to_schedule);
       for (auto tv : tvs_to_schedule) {
-        transformLikeMmaOutput(tv, /*is_mma_result=*/false);
+        transformLikeMmaOutput(tv);
       }
 
       auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
@@ -570,7 +577,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
   for (TensorView* splitk_sum : splitk_sums_) {
     // Always use serial grid reduction for split-K sum
     splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
-    transformLikeMmaOutput(splitk_sum, /*is_mma_result=*/false);
+    transformLikeMmaOutput(splitk_sum);
     auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
         splitk_sum->getLoopDomain());
     splitk_sum->setLoopDomain(s.as<IterDomain*>());
@@ -626,7 +633,8 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() {
           /*prefetch_distance=*/
           params_->circular_buffer_options.smem_circular_buffer_stage -
               params_->circular_buffer_options
-                  .smem_circular_buffer_prefetch_gap);
+                  .smem_circular_buffer_prefetch_gap,
+          WarpSpecialized(ParallelType::TIDy));
     }
     for (TensorView* bcw_smem : bcw_smems_) {
       bcw_smem->circularBuffer(
@@ -634,7 +642,8 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() {
          /*prefetch_distance=*/
           params_->circular_buffer_options.smem_circular_buffer_stage -
               params_->circular_buffer_options
-                  .smem_circular_buffer_prefetch_gap);
+                  .smem_circular_buffer_prefetch_gap,
+          WarpSpecialized(ParallelType::TIDy));
     }
   }
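Since `transformLikeMmaOutput` no longer takes the `is_mma_result` offset, the K-aware transform lives only in `scheduleMmaResults`. Below is a dependency-free trace of its axis bookkeeping (plain C++ on axis labels, not the TensorView API) that reproduces the domain comments above; the assert confirms TIDy lands on the merged `Mio * Nio` axis at position -5, with `Kio, Mii, Nii, Kii` trailing.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy re-enactment of scheduleMmaResults' split/reorder/merge on axis labels.
// Negative positions count from the end, as in nvFuser.
using Dom = std::vector<std::string>;

int pos(const Dom& d, int neg) {
  return static_cast<int>(d.size()) + neg;
}

void split(Dom& d, int neg, const std::string& base) {
  int i = pos(d, neg);
  d[i] = base + "o";                       // outer
  d.insert(d.begin() + i + 1, base + "i"); // inner
}

void reorder(Dom& d, int old_pos, int new_pos) { // move one axis
  std::string ax = d[pos(d, old_pos)];
  d.erase(d.begin() + pos(d, old_pos));
  d.insert(d.begin() + (static_cast<int>(d.size()) + 1 + new_pos), ax);
}

void merge(Dom& d, int neg) { // merge axis at neg with its right neighbor
  int i = pos(d, neg);
  d[i] += "*" + d[i + 1];
  d.erase(d.begin() + i + 1);
}

int main() {
  Dom d = {"Mo", "No", "Mi", "Ni", "Ki"};
  split(d, -3, "Mi");    // [Mo, No, Mio, Mii, Ni, Ki]
  split(d, -2, "Ni");    // [Mo, No, Mio, Mii, Nio, Nii, Ki]
  split(d, -1, "Ki");    // [Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
  reorder(d, -5, -4);    // [Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
  reorder(d, -2, -4);    // [Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
  merge(d, -6);          // [Mo, No, Mio*Nio, Kio, Mii, Nii, Kii]
  // axis(-5) is the merged Mio*Nio -> ParallelType::TIDy
  assert(d == Dom({"Mo", "No", "Mio*Nio", "Kio", "Mii", "Nii", "Kii"}));
  return 0;
}
```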
diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h
index 295b55ee96e..864bbb0f3fc 100644
--- a/csrc/scheduler/hopper_multi_matmul.h
+++ b/csrc/scheduler/hopper_multi_matmul.h
@@ -187,7 +187,7 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
   // Schedule a block-tiled TensorView like mma output.
   // Why? WGMMA has a unique output format. TensorViews after the mma-result in
   // registers must respect this format for correctness.
-  void transformLikeMmaOutput(TensorView* tv, bool is_mma_result);
+  void transformLikeMmaOutput(TensorView* tv);
 
  private:
  std::vector<ValGroup> canonical_dim_ordering_;
diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp
index 9e9395c5e18..a0dbeb87777 100644
--- a/tests/cpp/test_matmul.cpp
+++ b/tests/cpp/test_matmul.cpp
@@ -4027,8 +4027,8 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) {
   auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);
 
   MatMulTileOptions gemm_tile;
-  gemm_tile.cta_tile = GemmTile(128, 256, 16);
-  gemm_tile.warp_tile = GemmTile(64, 256, 16);
+  gemm_tile.cta_tile = GemmTile(128, 256, 64);
+  gemm_tile.warp_tile = GemmTile(64, 256, 64);
 
   MatmulParams mparams;
   mparams.supported_vec_size = {8, 8, 8};
@@ -4084,8 +4084,8 @@ TEST_F(HopperMatmulTest, HSH_TN_UseScheduler) {
   auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze().t()).to(at::kHalf);
 
   MatMulTileOptions gemm_tile;
-  gemm_tile.cta_tile = GemmTile(128, 256, 16);
-  gemm_tile.warp_tile = GemmTile(64, 256, 16);
+  gemm_tile.cta_tile = GemmTile(128, 256, 64);
+  gemm_tile.warp_tile = GemmTile(64, 256, 64);
 
   MatmulParams mparams;
   mparams.supported_vec_size = {8, 8, 8};
@@ -4147,8 +4147,8 @@ TEST_F(HopperMatmulTest, HSH_NN_UseScheduler) {
   auto out_ref =
       at::matmul(a_ref.squeeze().t(), b_ref.squeeze().t()).to(at::kHalf);
 
   MatMulTileOptions gemm_tile;
-  gemm_tile.cta_tile = GemmTile(128, 256, 16);
-  gemm_tile.warp_tile = GemmTile(64, 256, 16);
+  gemm_tile.cta_tile = GemmTile(128, 256, 64);
+  gemm_tile.warp_tile = GemmTile(64, 256, 64);
 
   MatmulParams mparams;
   mparams.supported_vec_size = {8, 8, 8};
@@ -4209,8 +4209,8 @@ TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) {
   auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf);
 
   MatMulTileOptions gemm_tile;
-  gemm_tile.cta_tile = GemmTile(128, 256, 16);
-  gemm_tile.warp_tile = GemmTile(64, 256, 16);
+  gemm_tile.cta_tile = GemmTile(128, 256, 64);
+  gemm_tile.warp_tile = GemmTile(64, 256, 64);
 
   MatmulParams mparams;
   mparams.supported_vec_size = {8, 8, 8};
diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp
index c860b908a92..83b4434b49a 100644
--- a/tests/cpp/test_matmul_scheduler.cpp
+++ b/tests/cpp/test_matmul_scheduler.cpp
@@ -3273,13 +3273,13 @@ class HopperMatmulSchedulerTest
     // TODO cta tile is a multiple of mma macro for hopper.
     // Default cta_tile configuration is 2-CTA.
     gemm_tile.cta_tile =
-        GemmTile(2 * getM(mma_macro), getN(mma_macro), getK(mma_macro));
+        GemmTile(2 * getM(mma_macro), getN(mma_macro), 2 * getK(mma_macro));
 
     // TODO warp tile is (macroM, macroN, macroK) for hopper.
     gemm_tile.warp_tile =
-        GemmTile(getM(mma_macro), getN(mma_macro), getK(mma_macro));
+        GemmTile(getM(mma_macro), getN(mma_macro), 2 * getK(mma_macro));
 
-    mparams.supported_vec_size = {8, 8, 4};
+    mparams.supported_vec_size = {8, 8, 8};
 
     mparams.mma_macro = mma_macro;
@@ -3467,7 +3467,7 @@ INSTANTIATE_TEST_SUITE_P(
         testing::Bool(), // b_k_inner
         testing::Values(512), // M
         testing::Values(256), // N
-        testing::Values(64), // K
+        testing::Values(128), // K
         testing::Values(MmaMacro::Hopper_64_128_16), // mma_macros
         testing::Values(1, 2) // SplitK Factor
         ),
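One plausible reading of the hardcoded numbers scattered through this PR (my arithmetic, not code from the diff): with `warp_tile` M = 64 against `cta_tile` M = 128, the scheduler produces two 128-thread compute warpgroups on TIDy, and warp specialization adds one 128-thread load warpgroup, which matches `__launch_bounds__(384, 1)`; the `setmaxnreg` values 24/240 then just fit under Hopper's 64 K registers per SM. The larger K tiles (16 → 64 in `test_matmul.cpp`, `2 * getK(mma_macro)` in the scheduler test) give each main-loop iteration several MMA macros' worth of K to amortize the extra WAR synchronization.

```cpp
// Back-of-the-envelope check (assumption: Hopper's 65536 32-bit regs per SM).
constexpr int kLoadThreads = 128;        // one TMA load warpgroup
constexpr int kComputeThreads = 2 * 128; // cta_tile.m / warp_tile.m warpgroups
static_assert(kLoadThreads + kComputeThreads == 384,
              "matches __launch_bounds__(384, 1)");

constexpr int kRegsPerLoadThread = 24;     // setmaxnreg.dec value
constexpr int kRegsPerComputeThread = 240; // setmaxnreg.inc value
constexpr int kTotalRegs = kRegsPerLoadThread * kLoadThreads +
    kRegsPerComputeThread * kComputeThreads;
static_assert(kTotalRegs == 64512 && kTotalRegs <= 64 * 1024,
              "fits the register file with one resident CTA");
```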