Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix matmul incorrect results when k dim for CTA tile is a multiple of 16 #3616

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
18 changes: 15 additions & 3 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
// Generates the kernel function declaration
void genDeclaration(const std::string& kernel_name) {
code_ << "__global__ void ";
// TODO Fix hardcoded values
code_ << "__launch_bounds__(384, 1) ";
if (kernel_->hasManaged("cluster_dims")) {
auto cluster_dims =
kernel_->getManaged<std::tuple<int64_t, int64_t, int64_t>>(
Expand Down Expand Up @@ -3325,10 +3327,16 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
// Use a custom synchronization method if enabled
if (getNvFuserEnv("USE_BLOCK_SYNC_ATOMIC")) {
indent() << "block_sync::sync();\n";
} else if (isAligned()) {
indent() << "__syncthreads();\n";
} else {
indent() << "__barrier_sync(0);\n";
ArgumentBuilder sync_call_template_parms;
sync_call_template_parms.arg(isAligned());

ArgumentBuilder sync_call_args;
sync_call_args.arg(genComputeBlockDim());

auto sync_call =
genCall("block_sync::sync", sync_call_template_parms, sync_call_args);
indent() << sync_call << ";\n";
}
}

Expand Down Expand Up @@ -3505,6 +3513,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
indent() << "NVFUSER_UPDATE_MAGIC_ZERO;\n";
}

void handle(const kir::Return* ret) final {
indent() << "return;\n";
}

private:
std::stringstream code_;
const kir::Kernel* kernel_;
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ class AllocationInserter : public kir::ExprMutator {
// generic-async proxy fence and wgmma fence before each mma
// instruction. For this case, we need to insert these fences
// after the initialization of the accumulator, so that the
// inilization is visible to the async proxy.
// initialization is visible to the async proxy.
// When all inputs are guarded by mbarrier, we will insert these
// fences before each mma instruction, so there is no need to
// insert them after the initialization of the accumulator here.
Expand Down
13 changes: 13 additions & 0 deletions csrc/device_lower/pass/circular_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1394,11 +1394,24 @@ class CircularBufferInserter : private kir::ExprMutator {
warp_specialize_on),
circular_buffer_loop->fusion()->oneVal()))));

kir::MaxNReg* dec_reg_load_warp = IrBuilder::create<kir::MaxNReg>(
IrBuilder::create<Val>(24, DataType::Index),
/*increase_registers=*/false);
warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp);

// Load loop:
ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
warp_dispatch_ite->thenBody().push_back(load_loop);

kir::Return* ret = IrBuilder::create<kir::Return>();
warp_dispatch_ite->thenBody().push_back(ret);

kir::MaxNReg* inc_reg_load_warp = IrBuilder::create<kir::MaxNReg>(
IrBuilder::create<Val>(240, DataType::Index),
/*increase_registers*/ true);
warp_dispatch_ite->elseBody().push_back(inc_reg_load_warp);

// Prefetch:
auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
warp_dispatch_ite->elseBody().push_back(prefetch_loop);
Expand Down
10 changes: 10 additions & 0 deletions csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2583,6 +2583,16 @@ void IndexLowering::handle(const kir::WgMmaFence* fence) {
pushBack(const_cast<kir::WgMmaFence*>(fence)); // NOLINT
}

void IndexLowering::handle(const kir::MaxNReg* maxnreg) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::MaxNReg*>(maxnreg)); // NOLINT
}

void IndexLowering::handle(const kir::Return* ret) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::Return*>(ret)); // NOLINT
}

void IndexLowering::handle(const kir::AsyncCommit* commit) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::AsyncCommit*>(commit)); // NOLINT
Expand Down
2 changes: 2 additions & 0 deletions csrc/device_lower/pass/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class IndexLowering : private OptOutConstDispatch {
void handle(const kir::GridSync*) final;
void handle(const kir::FenceAsyncProxy*) final;
void handle(const kir::WgMmaFence*) final;
void handle(const kir::MaxNReg*) final;
void handle(const kir::Return*) final;
void handle(const kir::MBarrierInit*) final;
void handle(const kir::MBarrierInvalidate*) final;
void handle(const kir::MBarrierArrive*) final;
Expand Down
13 changes: 13 additions & 0 deletions csrc/device_lower/pass/inline_ptx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,19 @@ class LowerToInlinePtx : public kir::ExprMutator {
std::vector<Val*>{},
kir::Asm::Options{/*volatile=*/true}));
}

void handle(kir::MaxNReg* maxnreg) final {
std::string ptx = (maxnreg->increaseRegisters())
? "setmaxnreg.inc.sync.aligned.u32"
: "setmaxnreg.dec.sync.aligned.u32";
registerReplace(
maxnreg,
IrBuilder::create<kir::Asm>(
ptx,
std::vector<Val*>{},
std::vector<Val*>{maxnreg->numberOfRegisters()},
kir::Asm::Options{/*volatile=*/true}));
}
};

std::vector<Expr*> lowerToInlinePtx(const std::vector<Expr*>& exprs) {
Expand Down
21 changes: 18 additions & 3 deletions csrc/device_lower/pass/insert_syncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,11 +393,11 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
if (auto mma = dynamic_cast<MmaOp*>(expr)) {
if (mma->isHopper()) {
auto scope = scope_.empty() ? nullptr : scope_.back();
auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>();
registerInsertBefore(expr, wgmma_fence, scope);
if (!lower_utils::allMmaInputsGuardedByMBarrier(mma)) {
// Makes sure that writes to operands in the generic proxy are visible
// to the async proxy
auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>();
registerInsertBefore(expr, wgmma_fence, scope);
auto fence_async = IrBuilder::create<kir::FenceAsyncProxy>();
registerInsertBefore(expr, fence_async, scope);
}
Expand Down Expand Up @@ -782,7 +782,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
}
};

// Insert wait expressions for WAR harzard for async operations such as wgmma
// Insert wait expressions for WAR hazard for async operations such as wgmma
// and tma store. To do so, we find the structure like the following example:
// for 1
// for 2
Expand Down Expand Up @@ -1007,6 +1007,21 @@ class WarAsyncWaitInserter : private kir::ExprMutator {
// Process the expressions in the for loop
kir::ExprMutator::handle(for_loop);

// NOTE Warp Specialization require WAR wgmma sync before launching next tma
// load
if (for_loop->circularBufferLoopStage() ==
CircularBufferLoopStage::ComputeWarp) {
for (Expr* expr : for_loop->body().exprs()) {
if (expr->isA<kir::MBarrierArrive>()) {
auto sync_exprs = lower_utils::getSyncExprs(AsyncOpType::WgMma, 0);
while (!sync_exprs.empty()) {
registerInsertBefore(expr, sync_exprs.back(), &for_loop->body());
sync_exprs.pop_back();
}
}
}
}

// Insert async wait at the end of this for loop
if (within_iter_loop_) {
std::unordered_map<AsyncOpType, int64_t> types_and_pending_ops_to_protect;
Expand Down
5 changes: 5 additions & 0 deletions csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1995,6 +1995,11 @@ std::vector<Expr*> getSyncExprs(AsyncOpType async_type, int64_t keep_stages) {
sync_exprs.push_back(commit);
auto wait = IrBuilder::create<kir::AsyncWait>(async_type, keep_stages);
sync_exprs.push_back(wait);
// TODO Do not apply for warp specialization
if (async_type == AsyncOpType::WgMma) {
auto sync = IrBuilder::create<kir::BlockSync>(true);
sync_exprs.push_back(sync);
}
return sync_exprs;
}

Expand Down
2 changes: 2 additions & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ class Val;
f(GridSync); \
f(FenceAsyncProxy); \
f(WgMmaFence); \
f(MaxNReg); \
f(Return); \
f(MBarrierInit); \
f(MBarrierInvalidate); \
f(MBarrierArrive); \
Expand Down
41 changes: 41 additions & 0 deletions csrc/kernel_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,47 @@ std::string WgMmaFence::toInlineString(int indent_size) const {

NVFUSER_DEFINE_CLONE_AND_CREATE(WgMmaFence)

MaxNReg::MaxNReg(
IrBuilderPasskey passkey,
Val* number_of_registers,
bool increase_registers)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
addInput(number_of_registers);
addDataAttribute(increase_registers);
}

std::string MaxNReg::toString(int indent_size) const {
return (increaseRegisters()) ? "setmaxnreg.inc.sync.aligned.u32"
: "setmaxnreg.dec.sync.aligned.u32";
}

std::string MaxNReg::toInlineString(int indent_size) const {
NVF_CHECK(false, "MaxNReg can not be printed inline");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MaxNReg)

Return::Return(IrBuilderPasskey passkey) : Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
}

std::string Return::toString(int indent_size) const {
return "return";
}

std::string Return::toInlineString(int indent_size) const {
NVF_CHECK(false, "Return can not be printed inline");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Return)

MBarrierInit::MBarrierInit(
IrBuilderPasskey passkey,
Val* mbarrier,
Expand Down
46 changes: 46 additions & 0 deletions csrc/kernel_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class BlockSync;
class GridSync;
class FenceAsyncProxy;
class WgMmaFence;
class MaxNReg;
class Return;
class MBarrierInit;
class MBarrierInvalidate;
class MBarrierArrive;
Expand Down Expand Up @@ -469,6 +471,50 @@ class WgMmaFence final : public Expr {
std::string toInlineString(int indent_size = 0) const override;
};

// PTX: setmaxnreg.inc.sync.aligned.u32 and setmaxnreg.dec.sync.aligned.u32
class MaxNReg final : public Expr {
public:
using Expr::Expr;

explicit MaxNReg(
IrBuilderPasskey passkey,
Val* number_of_registers,
bool increase_registers);

NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return (increaseRegisters()) ? "IncMaxNReg" : "DecMaxNReg";
}

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;

bool increaseRegisters() const {
return attribute<bool>(0);
}

Val* numberOfRegisters() const {
return input(0);
}
};

class Return final : public Expr {
public:
using Expr::Expr;

explicit Return(IrBuilderPasskey passkey);

NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return "Return";
}

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
};

class MBarrierInit final : public Expr {
public:
using Expr::Expr;
Expand Down
45 changes: 27 additions & 18 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,18 @@

namespace nvfuser {

void HopperMultipleMatmulScheduler::transformLikeMmaOutput(
TensorView* tv,
bool is_mma_result) {
void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
// TODO Add constraints

auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr {
return (is_mma_result) ? idx - 1 : idx;
};

// Original: [..., Mo, No, Mi, Ni]
tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro));
tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro));
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}});
tv->reorder({{-3, -2}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
tv->merge(apply_k_dim_offset(-4));
tv->merge(-4);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy);
tv->axis(-3)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
}

Expand Down Expand Up @@ -424,7 +418,20 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
splitk_sums_.push_back(splitk_sum);
}

transformLikeMmaOutput(mma_result, /*is_mma_result=*/true);
// Original: [..., Mo, No, Mi, Ni, Ki]
mma_result->split(-3, getM(params_->mma_macro));
mma_result->split(-2, getN(params_->mma_macro));
mma_result->split(-1, getK(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
mma_result->reorder({{-5, -4}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
mma_result->reorder({{-2, -4}});
// After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
mma_result->merge(-6);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
mma_result->axis(-5)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
mma_result->getLoopDomain());
mma_result->setAllocationDomain(s.as<IterDomain*>(), true);
Expand Down Expand Up @@ -459,7 +466,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// op.
blockTileTensors({d});
parallelizeBlocks({d});
transformLikeMmaOutput(d, /*is_mma_result=*/false);
transformLikeMmaOutput(d);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
Expand Down Expand Up @@ -536,7 +543,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);
for (auto tv : tvs_to_schedule) {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
transformLikeMmaOutput(tv);
}

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
Expand Down Expand Up @@ -570,7 +577,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
for (TensorView* splitk_sum : splitk_sums_) {
// Always use serial grid reduction for split-K sum
splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
transformLikeMmaOutput(splitk_sum, /*is_mma_result=*/false);
transformLikeMmaOutput(splitk_sum);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
splitk_sum->getLoopDomain());
splitk_sum->setLoopDomain(s.as<IterDomain*>());
Expand Down Expand Up @@ -626,15 +633,17 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() {
/*prefetch_distance=*/
params_->circular_buffer_options.smem_circular_buffer_stage -
params_->circular_buffer_options
.smem_circular_buffer_prefetch_gap);
.smem_circular_buffer_prefetch_gap,
WarpSpecialized(ParallelType::TIDy));
}
for (TensorView* bcw_smem : bcw_smems_) {
bcw_smem->circularBuffer(
params_->circular_buffer_options.smem_circular_buffer_stage,
/*prefetch_distance=*/
params_->circular_buffer_options.smem_circular_buffer_stage -
params_->circular_buffer_options
.smem_circular_buffer_prefetch_gap);
.smem_circular_buffer_prefetch_gap,
WarpSpecialized(ParallelType::TIDy));
}
}

Expand Down
Loading
Loading