Skip to content

Commit 99f492a

Browse files
committed
create MaxNReg and Return kir nodes
1 parent fca0086 commit 99f492a

File tree

7 files changed

+118
-0
lines changed

7 files changed

+118
-0
lines changed

csrc/codegen.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -3505,6 +3505,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
35053505
indent() << "NVFUSER_UPDATE_MAGIC_ZERO;\n";
35063506
}
35073507

3508+
void handle(const kir::Return* ret) final {
3509+
indent() << "return;\n";
3510+
}
3511+
35083512
private:
35093513
std::stringstream code_;
35103514
const kir::Kernel* kernel_;

csrc/device_lower/pass/index.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -2583,6 +2583,16 @@ void IndexLowering::handle(const kir::WgMmaFence* fence) {
25832583
pushBack(const_cast<kir::WgMmaFence*>(fence)); // NOLINT
25842584
}
25852585

2586+
void IndexLowering::handle(const kir::MaxNReg* maxnreg) {
2587+
// TODO(kir): remove the need for const_cast
2588+
pushBack(const_cast<kir::MaxNReg*>(maxnreg)); // NOLINT
2589+
}
2590+
2591+
void IndexLowering::handle(const kir::Return* ret) {
2592+
// TODO(kir): remove the need for const_cast
2593+
pushBack(const_cast<kir::Return*>(ret)); // NOLINT
2594+
}
2595+
25862596
void IndexLowering::handle(const kir::AsyncCommit* commit) {
25872597
// TODO(kir): remove the need for const_cast
25882598
pushBack(const_cast<kir::AsyncCommit*>(commit)); // NOLINT

csrc/device_lower/pass/index.h

+2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ class IndexLowering : private OptOutConstDispatch {
7575
void handle(const kir::GridSync*) final;
7676
void handle(const kir::FenceAsyncProxy*) final;
7777
void handle(const kir::WgMmaFence*) final;
78+
void handle(const kir::MaxNReg*) final;
79+
void handle(const kir::Return*) final;
7880
void handle(const kir::MBarrierInit*) final;
7981
void handle(const kir::MBarrierInvalidate*) final;
8082
void handle(const kir::MBarrierArrive*) final;

csrc/device_lower/pass/inline_ptx.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,19 @@ class LowerToInlinePtx : public kir::ExprMutator {
272272
std::vector<Val*>{},
273273
kir::Asm::Options{/*volatile=*/true}));
274274
}
275+
276+
void handle(kir::MaxNReg* maxnreg) final {
277+
std::string ptx = (maxnreg->increaseRegisters())
278+
? "setmaxnreg.inc.sync.aligned.u32"
279+
: "setmaxnreg.dec.sync.aligned.u32";
280+
registerReplace(
281+
maxnreg,
282+
IrBuilder::create<kir::Asm>(
283+
ptx,
284+
std::vector<Val*>{},
285+
std::vector<Val*>{maxnreg->numberOfRegisters()},
286+
kir::Asm::Options{/*volatile=*/true}));
287+
}
275288
};
276289

277290
std::vector<Expr*> lowerToInlinePtx(const std::vector<Expr*>& exprs) {

csrc/dispatch.h

+2
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ class Val;
120120
f(GridSync); \
121121
f(FenceAsyncProxy); \
122122
f(WgMmaFence); \
123+
f(MaxNReg); \
124+
f(Return); \
123125
f(MBarrierInit); \
124126
f(MBarrierInvalidate); \
125127
f(MBarrierArrive); \

csrc/kernel_ir.cpp

+41
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,47 @@ std::string WgMmaFence::toInlineString(int indent_size) const {
485485

486486
NVFUSER_DEFINE_CLONE_AND_CREATE(WgMmaFence)
487487

488+
MaxNReg::MaxNReg(
489+
IrBuilderPasskey passkey,
490+
Val* number_of_registers,
491+
bool increase_registers)
492+
: Expr(passkey) {
493+
NVF_ERROR(passkey.ir_container_ != nullptr);
494+
NVF_ERROR(
495+
passkey.ir_container_->isA<kir::Kernel>(),
496+
"IR type only valid for Kernel container.");
497+
addInput(number_of_registers);
498+
addDataAttribute(increase_registers);
499+
}
500+
501+
std::string MaxNReg::toString(int indent_size) const {
502+
return (increaseRegisters()) ? "setmaxnreg.inc.sync.aligned.u32"
503+
: "setmaxnreg.dec.sync.aligned.u32";
504+
}
505+
506+
std::string MaxNReg::toInlineString(int indent_size) const {
507+
NVF_CHECK(false, "MaxNReg can not be printed inline");
508+
}
509+
510+
NVFUSER_DEFINE_CLONE_AND_CREATE(MaxNReg)
511+
512+
Return::Return(IrBuilderPasskey passkey) : Expr(passkey) {
513+
NVF_ERROR(passkey.ir_container_ != nullptr);
514+
NVF_ERROR(
515+
passkey.ir_container_->isA<kir::Kernel>(),
516+
"IR type only valid for Kernel container.");
517+
}
518+
519+
std::string Return::toString(int indent_size) const {
520+
return "return";
521+
}
522+
523+
std::string Return::toInlineString(int indent_size) const {
524+
NVF_CHECK(false, "Return can not be printed inline");
525+
}
526+
527+
NVFUSER_DEFINE_CLONE_AND_CREATE(Return)
528+
488529
MBarrierInit::MBarrierInit(
489530
IrBuilderPasskey passkey,
490531
Val* mbarrier,

csrc/kernel_ir.h

+46
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class BlockSync;
4141
class GridSync;
4242
class FenceAsyncProxy;
4343
class WgMmaFence;
44+
class MaxNReg;
45+
class Return;
4446
class MBarrierInit;
4547
class MBarrierInvalidate;
4648
class MBarrierArrive;
@@ -469,6 +471,50 @@ class WgMmaFence final : public Expr {
469471
std::string toInlineString(int indent_size = 0) const override;
470472
};
471473

474+
// PTX: setmaxnreg.inc.sync.aligned.u32 and setmaxnreg.dec.sync.aligned.u32
475+
class MaxNReg final : public Expr {
476+
public:
477+
using Expr::Expr;
478+
479+
explicit MaxNReg(
480+
IrBuilderPasskey passkey,
481+
Val* number_of_registers,
482+
bool increase_registers);
483+
484+
NVFUSER_DECLARE_CLONE_AND_CREATE
485+
486+
const char* getOpString() const override {
487+
return (increaseRegisters()) ? "IncMaxNReg" : "DecMaxNReg";
488+
}
489+
490+
std::string toString(int indent_size = 0) const override;
491+
std::string toInlineString(int indent_size = 0) const override;
492+
493+
bool increaseRegisters() const {
494+
return attribute<bool>(0);
495+
}
496+
497+
Val* numberOfRegisters() const {
498+
return input(0);
499+
}
500+
};
501+
502+
class Return final : public Expr {
503+
public:
504+
using Expr::Expr;
505+
506+
explicit Return(IrBuilderPasskey passkey);
507+
508+
NVFUSER_DECLARE_CLONE_AND_CREATE
509+
510+
const char* getOpString() const override {
511+
return "Return";
512+
}
513+
514+
std::string toString(int indent_size = 0) const override;
515+
std::string toInlineString(int indent_size = 0) const override;
516+
};
517+
472518
class MBarrierInit final : public Expr {
473519
public:
474520
using Expr::Expr;

0 commit comments

Comments
 (0)