Skip to content

Commit

Permalink
Show operator name with unary and binary nodes in graphviz dot outputs (
Browse files Browse the repository at this point in the history
#3630)

I think this is a bit more convenient.
  • Loading branch information
naoyam authored Dec 21, 2024
1 parent 3268d79 commit 516d590
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
12 changes: 9 additions & 3 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ class NVF_API UnaryOp : public Expr {
return "UnaryOp";
}

std::string getGraphvizLabel() const override;

std::vector<PolymorphicValue> evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const override;
Expand Down Expand Up @@ -358,6 +360,8 @@ class NVF_API BinaryOp : public Expr {
return "BinaryOp";
}

std::string getGraphvizLabel() const override;

std::vector<PolymorphicValue> evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const override;
Expand Down Expand Up @@ -405,6 +409,8 @@ class TernaryOp : public Expr {
return "TernaryOp";
}

std::string getGraphvizLabel() const override;

std::vector<PolymorphicValue> evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const override;
Expand Down Expand Up @@ -1445,15 +1451,15 @@ class NVF_API MmaOp : public Expr {
return attribute<MmaMacro>(ATTR_POS_MACRO);
}

int m() const {
int64_t m() const {
return getM(macro());
}

int n() const {
int64_t n() const {
return getN(macro());
}

int k() const {
int64_t k() const {
return getK(macro());
}

Expand Down
18 changes: 18 additions & 0 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,12 @@ std::string UnaryOp::toInlineString(int indent_size) const {
return ss.str();
}

std::string UnaryOp::getGraphvizLabel() const {
std::stringstream ss;
ss << getOpString() << "(" << getUnaryOpType() << ")";
return ss.str();
}

NVFUSER_DEFINE_CLONE_AND_CREATE(UnaryOp)

BinaryOp::BinaryOp(
Expand Down Expand Up @@ -724,6 +730,12 @@ std::string BinaryOp::toInlineString(int indent_size) const {
return ss.str();
}

std::string BinaryOp::getGraphvizLabel() const {
std::stringstream ss;
ss << getOpString() << "(" << getBinaryOpType() << ")";
return ss.str();
}

NVFUSER_DEFINE_CLONE_AND_CREATE(BinaryOp)

TernaryOp::TernaryOp(
Expand Down Expand Up @@ -825,6 +837,12 @@ std::string TernaryOp::toInlineString(int indent_size) const {
return ss.str();
}

std::string TernaryOp::getGraphvizLabel() const {
std::stringstream ss;
ss << getOpString() << "(" << getTernaryOpType() << ")";
return ss.str();
}

NVFUSER_DEFINE_CLONE_AND_CREATE(TernaryOp)

ArrayConstruct::ArrayConstruct(
Expand Down

0 comments on commit 516d590

Please sign in to comment.