Skip to content

Commit

Permalink
create MaxNReg and Return kir nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Dec 20, 2024
1 parent 128479b commit 952cb0a
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 0 deletions.
10 changes: 10 additions & 0 deletions csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2583,6 +2583,16 @@ void IndexLowering::handle(const kir::WgMmaFence* fence) {
pushBack(const_cast<kir::WgMmaFence*>(fence)); // NOLINT
}

void IndexLowering::handle(const kir::MaxNReg* maxnreg) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::MaxNReg*>(maxnreg)); // NOLINT
}

void IndexLowering::handle(const kir::Return* ret) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::Return*>(ret)); // NOLINT
}

void IndexLowering::handle(const kir::AsyncCommit* commit) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::AsyncCommit*>(commit)); // NOLINT
Expand Down
2 changes: 2 additions & 0 deletions csrc/device_lower/pass/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class IndexLowering : private OptOutConstDispatch {
void handle(const kir::GridSync*) final;
void handle(const kir::FenceAsyncProxy*) final;
void handle(const kir::WgMmaFence*) final;
void handle(const kir::MaxNReg*) final;
void handle(const kir::Return*) final;
void handle(const kir::MBarrierInit*) final;
void handle(const kir::MBarrierInvalidate*) final;
void handle(const kir::MBarrierArrive*) final;
Expand Down
13 changes: 13 additions & 0 deletions csrc/device_lower/pass/inline_ptx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,19 @@ class LowerToInlinePtx : public kir::ExprMutator {
std::vector<Val*>{},
kir::Asm::Options{/*volatile=*/true}));
}

void handle(kir::MaxNReg* maxnreg) final {
std::string ptx = (maxnreg->increaseRegisters())
? "setmaxnreg.inc.sync.aligned.u32"
: "setmaxnreg.dec.sync.aligned.u32";
registerReplace(
maxnreg,
IrBuilder::create<kir::Asm>(
ptx,
std::vector<Val*>{},
std::vector<Val*>{maxnreg->numberOfRegisters()},
kir::Asm::Options{/*volatile=*/true}));
}
};

std::vector<Expr*> lowerToInlinePtx(const std::vector<Expr*>& exprs) {
Expand Down
2 changes: 2 additions & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ class Val;
f(GridSync); \
f(FenceAsyncProxy); \
f(WgMmaFence); \
f(MaxNReg); \
f(Return); \
f(MBarrierInit); \
f(MBarrierInvalidate); \
f(MBarrierArrive); \
Expand Down
41 changes: 41 additions & 0 deletions csrc/kernel_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,47 @@ std::string WgMmaFence::toInlineString(int indent_size) const {

NVFUSER_DEFINE_CLONE_AND_CREATE(WgMmaFence)

MaxNReg::MaxNReg(
IrBuilderPasskey passkey,
Val* number_of_registers,
bool increase_registers)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
addInput(number_of_registers);
addDataAttribute(increase_registers);
}

std::string MaxNReg::toString(int indent_size) const {
return (increaseRegisters()) ? "setmaxnreg.inc.sync.aligned.u32"
: "setmaxnreg.dec.sync.aligned.u32";
}

std::string MaxNReg::toInlineString(int indent_size) const {
NVF_CHECK(false, "MaxNReg can not be printed inline");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MaxNReg)

Return::Return(IrBuilderPasskey passkey) : Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
}

std::string Return::toString(int indent_size) const {
return "return";
}

std::string Return::toInlineString(int indent_size) const {
NVF_CHECK(false, "Return can not be printed inline");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Return)

MBarrierInit::MBarrierInit(
IrBuilderPasskey passkey,
Val* mbarrier,
Expand Down
46 changes: 46 additions & 0 deletions csrc/kernel_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class BlockSync;
class GridSync;
class FenceAsyncProxy;
class WgMmaFence;
class MaxNReg;
class Return;
class MBarrierInit;
class MBarrierInvalidate;
class MBarrierArrive;
Expand Down Expand Up @@ -469,6 +471,50 @@ class WgMmaFence final : public Expr {
std::string toInlineString(int indent_size = 0) const override;
};

// PTX: setmaxnreg.inc.sync.aligned.u32 and setmaxnreg.dec.sync.aligned.u32
class MaxNReg final : public Expr {
public:
using Expr::Expr;

explicit MaxNReg(
IrBuilderPasskey passkey,
Val* number_of_registers,
bool increase_registers);

NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return (increaseRegisters()) ? "IncMaxNReg" : "DecMaxNReg";
}

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;

bool increaseRegisters() const {
return attribute<bool>(0);
}

Val* numberOfRegisters() const {
return input(0);
}
};

class Return final : public Expr {
public:
using Expr::Expr;

explicit Return(IrBuilderPasskey passkey);

NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return "Return";
}

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
};

class MBarrierInit final : public Expr {
public:
using Expr::Expr;
Expand Down

0 comments on commit 952cb0a

Please sign in to comment.