diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 7352e7fc96e..df9b0bf50c9 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -320,6 +320,8 @@ class NVF_API UnaryOp : public Expr { return "UnaryOp"; } + std::string getGraphvizLabel() const override; + std::vector evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const override; @@ -358,6 +360,8 @@ class NVF_API BinaryOp : public Expr { return "BinaryOp"; } + std::string getGraphvizLabel() const override; + std::vector evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const override; @@ -405,6 +409,8 @@ class TernaryOp : public Expr { return "TernaryOp"; } + std::string getGraphvizLabel() const override; + std::vector evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const override; @@ -1445,15 +1451,15 @@ class NVF_API MmaOp : public Expr { return attribute(ATTR_POS_MACRO); } - int m() const { + int64_t m() const { return getM(macro()); } - int n() const { + int64_t n() const { return getN(macro()); } - int k() const { + int64_t k() const { return getK(macro()); } diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 33b5ce0b345..fd021239f93 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -557,6 +557,12 @@ std::string UnaryOp::toInlineString(int indent_size) const { return ss.str(); } +std::string UnaryOp::getGraphvizLabel() const { + std::stringstream ss; + ss << getOpString() << "(" << getUnaryOpType() << ")"; + return ss.str(); +} + NVFUSER_DEFINE_CLONE_AND_CREATE(UnaryOp) BinaryOp::BinaryOp( @@ -724,6 +730,12 @@ std::string BinaryOp::toInlineString(int indent_size) const { return ss.str(); } +std::string BinaryOp::getGraphvizLabel() const { + std::stringstream ss; + ss << getOpString() << "(" << getBinaryOpType() << ")"; + return ss.str(); +} + NVFUSER_DEFINE_CLONE_AND_CREATE(BinaryOp) TernaryOp::TernaryOp( @@ -825,6 +837,12 @@ std::string TernaryOp::toInlineString(int indent_size) const { return ss.str(); } +std::string TernaryOp::getGraphvizLabel() const { + std::stringstream ss; + ss << getOpString() << "(" << getTernaryOpType() << ")"; + return ss.str(); +} + NVFUSER_DEFINE_CLONE_AND_CREATE(TernaryOp) ArrayConstruct::ArrayConstruct(