From 9afb35a73e933e2d5d7e03ac0a9192a881ff5282 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Wed, 13 Apr 2022 17:57:03 +0000 Subject: [PATCH 1/5] [Remat][Bugfix] Fix TensorAnalyzer to handle nested tuples. --- src/pass/liveness_analysis.cc | 37 ++++++++++++++++++++++++++++-- src/pass/rematerialization.cc | 43 +++++++++++++++++++++++++++++++---- 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/src/pass/liveness_analysis.cc b/src/pass/liveness_analysis.cc index 18badbef..3fe8482b 100644 --- a/src/pass/liveness_analysis.cc +++ b/src/pass/liveness_analysis.cc @@ -351,10 +351,42 @@ Var LivenessAnalyzer::CreateTensorVar(const Type& type) { return VarCreator(this).Run(type); } -/*! \brief Calculate the byte compact size of the given type. If the type is a tuple, +/*! + * \brief Calculate the byte compact size of the given type. If the type is a tuple, * then the size of each tensor in the tuple will be returned. Note that size 0 means - * a tensor with dynamic shape. + * a tensor with dynamic shape. If the input type contains nested tuples, the nested + * tuples are flattened and a flat vector of sizes is returned at the end. The order + * of the fields is preserved. */ +std::vector CalcBytesCompactSizes(const Type& type) { + std::vector ty_stack; + std::vector ttypes; + ty_stack.push_back(type); + while (!ty_stack.empty()) { + auto ty = ty_stack.back(); + ty_stack.pop_back(); + if (auto tuple_ty_node = ty.as()) { + // If the current type corresponds to a tuple, process the type of each field later + for (auto field : tuple_ty_node->fields) + ty_stack.push_back(field); + } else if (auto tensor_ty_node = ty.as()) { + // Tensor types are added to the final list and sizes will be calculated + ttypes.push_back(tensor_ty_node); + } else { + // Other types are not supported + LOG(FATAL) << "Unsupported type: " << ty->GetTypeKey(); + throw; + } + } + + std::vector sizes; + for (auto ttype : ttypes) { + sizes.push_back(common::shape_utils::BytesCompactTensor(ttype)); + } + return sizes; +} + +/* std::vector CalcBytesCompactSizes(const Type& type) { std::vector ttypes; std::vector sizes; @@ -376,6 +408,7 @@ std::vector CalcBytesCompactSizes(const Type& type) { } return sizes; } +*/ /*! \brief Dump liveness analysis result statistics. */ void DumpLivenessStat(const MapVSet& live_in) { diff --git a/src/pass/rematerialization.cc b/src/pass/rematerialization.cc index bc93b1aa..a3054949 100644 --- a/src/pass/rematerialization.cc +++ b/src/pass/rematerialization.cc @@ -32,7 +32,7 @@ constexpr float kMegaBytes = 1048576; constexpr float kGigaBytes = 1073741824; // Whether to display verbose logging. -#define SHOW_VERBOSE_LOG 0 +#define SHOW_VERBOSE_LOG 1 // Whether to update tensor index when rematerialization. If defined, then // the rematerialized tensors are less likely to be freed and rematerialized again. @@ -864,8 +864,39 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor { ~TensorAnalyzer() { } + /*! + * \brief Get a list of liveness vars for the current let var. This function uses a DFS to handle + * nested tuples. The returned array contains a flat list of vars. The relative order of tuple + * fields is preserved. + */ + tvm::Array GetLivenessVars(const Var& curr_let) { + tvm::Array result_vars; + tvm::Array var_stack; + var_stack.push_back(curr_let); + while (!var_stack.empty()) { + auto v = var_stack.back(); + var_stack.pop_back(); + auto liveness_vars = analyzer_->GetTensorVars(v); + CHECK_GT(liveness_vars.size(), 0U); + if (liveness_vars.size() > 1) { + // If the current let var corresponds to a tuple, the tuple fields should be processed later + for (auto next_v : liveness_vars) + var_stack.push_back(next_v); + } else if (let_var_set_.count(liveness_vars[0])) { + // If this "liveness var" points to a real var rather than an actual liveness var + // it should also be processed later + var_stack.push_back(liveness_vars[0]); + } else { + // Otherwise, add this var to the result + result_vars.push_back(liveness_vars[0]); + } + } + return result_vars; + } + /*! \brief Visit each let statement and return the analyzed information of each tensor. */ TensorInfos Run() { + LOG(INFO) << "Function: " << ir::AsText(func_); // Analyze parameters. for (const auto& var : func_->params) { auto liveness_vars = analyzer_->GetTensorVars(var); @@ -876,26 +907,28 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor { const auto& exprs = ell_->exprs; CHECK_EQ(vars.size(), exprs.size()); - VSet let_var_set; for (auto var : vars) { - let_var_set.insert(var); + let_var_set_.insert(var); } size_t n = exprs.size(); for (int i = 0; i < n; ++i) { curr_let_ = vars[i]; + auto liveness_vars = GetLivenessVars(curr_let_); + /* auto liveness_vars = analyzer_->GetTensorVars(curr_let_); // In the case of tuple with may_share, liveness vars may point to the real tensor. for (size_t i = 0; i < liveness_vars.size(); ++i) { auto real_liveness_var = liveness_vars[i]; - while (let_var_set.find(real_liveness_var) != let_var_set.end()) { + while (let_var_set_.find(real_liveness_var) != let_var_set_.end()) { auto cand_liveness_vars = analyzer_->GetTensorVars(real_liveness_var); CHECK_EQ(cand_liveness_vars.size(), 1U); real_liveness_var = cand_liveness_vars[0]; } liveness_vars.Set(i, real_liveness_var); } + */ // Visit the expression to analyze the use count ExprVisitor::VisitExpr(exprs[i]); @@ -945,6 +978,8 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor { TensorInfos tensor_infos_; /*! \brief The profiler used in rematerialization. */ op_profiler::OpProfiler* profiler_; + /*! \brief A set of all let vars in the function. */ + VSet let_var_set_; }; TensorInfos Rematerializer::AnalyzeTensors(const Device& device, const Function& func, From 88de2933fa6544484e5082b375680724b3f3d5b5 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Thu, 14 Apr 2022 17:05:16 +0000 Subject: [PATCH 2/5] [Remat][Bugfix] Bugfix. Remove debug log printing. --- src/pass/liveness_analysis.cc | 6 +++--- src/pass/rematerialization.cc | 20 +++----------------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/src/pass/liveness_analysis.cc b/src/pass/liveness_analysis.cc index 3fe8482b..552ceaa3 100644 --- a/src/pass/liveness_analysis.cc +++ b/src/pass/liveness_analysis.cc @@ -359,7 +359,7 @@ Var LivenessAnalyzer::CreateTensorVar(const Type& type) { * of the fields is preserved. */ std::vector CalcBytesCompactSizes(const Type& type) { - std::vector ty_stack; + tvm::Array ty_stack; std::vector ttypes; ty_stack.push_back(type); while (!ty_stack.empty()) { @@ -367,8 +367,8 @@ std::vector CalcBytesCompactSizes(const Type& type) { ty_stack.pop_back(); if (auto tuple_ty_node = ty.as()) { // If the current type corresponds to a tuple, process the type of each field later - for (auto field : tuple_ty_node->fields) - ty_stack.push_back(field); + for (auto it = tuple_ty_node->fields.rbegin(); it != tuple_ty_node->fields.rend(); it ++) + ty_stack.push_back(*it); } else if (auto tensor_ty_node = ty.as()) { // Tensor types are added to the final list and sizes will be calculated ttypes.push_back(tensor_ty_node); diff --git a/src/pass/rematerialization.cc b/src/pass/rematerialization.cc index a3054949..6a5fc10a 100644 --- a/src/pass/rematerialization.cc +++ b/src/pass/rematerialization.cc @@ -32,7 +32,7 @@ constexpr float kMegaBytes = 1048576; constexpr float kGigaBytes = 1073741824; // Whether to display verbose logging. -#define SHOW_VERBOSE_LOG 1 +#define SHOW_VERBOSE_LOG 0 // Whether to update tensor index when rematerialization. If defined, then // the rematerialized tensors are less likely to be freed and rematerialized again. @@ -880,8 +880,8 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor { CHECK_GT(liveness_vars.size(), 0U); if (liveness_vars.size() > 1) { // If the current let var corresponds to a tuple, the tuple fields should be processed later - for (auto next_v : liveness_vars) - var_stack.push_back(next_v); + for (auto it = liveness_vars.rbegin(); it != liveness_vars.rend(); it ++) + var_stack.push_back(*it); } else if (let_var_set_.count(liveness_vars[0])) { // If this "liveness var" points to a real var rather than an actual liveness var // it should also be processed later @@ -915,20 +915,6 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor { for (int i = 0; i < n; ++i) { curr_let_ = vars[i]; auto liveness_vars = GetLivenessVars(curr_let_); - /* - auto liveness_vars = analyzer_->GetTensorVars(curr_let_); - - // In the case of tuple with may_share, liveness vars may point to the real tensor. - for (size_t i = 0; i < liveness_vars.size(); ++i) { - auto real_liveness_var = liveness_vars[i]; - while (let_var_set_.find(real_liveness_var) != let_var_set_.end()) { - auto cand_liveness_vars = analyzer_->GetTensorVars(real_liveness_var); - CHECK_EQ(cand_liveness_vars.size(), 1U); - real_liveness_var = cand_liveness_vars[0]; - } - liveness_vars.Set(i, real_liveness_var); - } - */ // Visit the expression to analyze the use count ExprVisitor::VisitExpr(exprs[i]); From f1870f30517f8e405252eae4e7f5d9f0c8d789e6 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Thu, 14 Apr 2022 17:15:20 +0000 Subject: [PATCH 3/5] [Remat][Bugfix] Clang format. --- src/pass/liveness_analysis.cc | 8 ++++---- src/pass/rematerialization.cc | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pass/liveness_analysis.cc b/src/pass/liveness_analysis.cc index 552ceaa3..b6032eea 100644 --- a/src/pass/liveness_analysis.cc +++ b/src/pass/liveness_analysis.cc @@ -351,12 +351,12 @@ Var LivenessAnalyzer::CreateTensorVar(const Type& type) { return VarCreator(this).Run(type); } -/*! +/*! * \brief Calculate the byte compact size of the given type. If the type is a tuple, * then the size of each tensor in the tuple will be returned. Note that size 0 means * a tensor with dynamic shape. If the input type contains nested tuples, the nested * tuples are flattened and a flat vector of sizes is returned at the end. The order - * of the fields is preserved. + * of the fields is preserved. */ std::vector CalcBytesCompactSizes(const Type& type) { tvm::Array ty_stack; @@ -367,7 +367,7 @@ std::vector CalcBytesCompactSizes(const Type& type) { ty_stack.pop_back(); if (auto tuple_ty_node = ty.as()) { // If the current type corresponds to a tuple, process the type of each field later - for (auto it = tuple_ty_node->fields.rbegin(); it != tuple_ty_node->fields.rend(); it ++) + for (auto it = tuple_ty_node->fields.rbegin(); it != tuple_ty_node->fields.rend(); it++) ty_stack.push_back(*it); } else if (auto tensor_ty_node = ty.as()) { // Tensor types are added to the final list and sizes will be calculated @@ -378,7 +378,7 @@ std::vector CalcBytesCompactSizes(const Type& type) { throw; } } - + std::vector sizes; for (auto ttype : ttypes) { sizes.push_back(common::shape_utils::BytesCompactTensor(ttype)); diff --git a/src/pass/rematerialization.cc b/src/pass/rematerialization.cc index 6a5fc10a..0be8bcd5 100644 --- a/src/pass/rematerialization.cc +++ b/src/pass/rematerialization.cc @@ -864,10 +864,10 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor { ~TensorAnalyzer() { } - /*! + /*! * \brief Get a list of liveness vars for the current let var. This function uses a DFS to handle * nested tuples. The returned array contains a flat list of vars. The relative order of tuple - * fields is preserved. + * fields is preserved. */ tvm::Array GetLivenessVars(const Var& curr_let) { tvm::Array result_vars; @@ -880,7 +880,7 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor { CHECK_GT(liveness_vars.size(), 0U); if (liveness_vars.size() > 1) { // If the current let var corresponds to a tuple, the tuple fields should be processed later - for (auto it = liveness_vars.rbegin(); it != liveness_vars.rend(); it ++) + for (auto it = liveness_vars.rbegin(); it != liveness_vars.rend(); it++) var_stack.push_back(*it); } else if (let_var_set_.count(liveness_vars[0])) { // If this "liveness var" points to a real var rather than an actual liveness var From e467e69ea8768701d334c751aa8e04a3d2a6616d Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Fri, 15 Apr 2022 18:24:45 +0000 Subject: [PATCH 4/5] [Remat][Bugfix] Add test for nested tuples. --- .../pass/test_pass_rematerialization.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/python/pass/test_pass_rematerialization.py b/tests/python/pass/test_pass_rematerialization.py index 3529db78..86ffe35f 100644 --- a/tests/python/pass/test_pass_rematerialization.py +++ b/tests/python/pass/test_pass_rematerialization.py @@ -479,6 +479,47 @@ def expected(): verify_remat(model, args, peak_size - 1, expected(), (peak_size, peak_size - 1)) +def test_nested_tuple(): + """ + A simple test program to check whether the remat pass can handle nested + tuples without crashing. No actual rematerialization is taking place. + """ + device = "cpu" + shape = (16, 16, 64, 64) # 4 MBs + + def get_mod(): + add_op = raf._ffi.op.GetOp("raf.op.add") + relu_op = raf._ffi.op.GetOp("raf.op.relu") + null = raf.ir.const(None) + + # param: 12 MBs + p_0 = raf.ir.var("p0", shape=shape) + p_1 = raf.ir.var("p1", shape=shape) + p_2 = raf.ir.var("p2", shape=shape) + + sb = ScopeBuilder() + # a_1: 4 MBs, total 16 MBs + a_1 = sb.let("a1", relay.Call(add_op, [p_0, p_1, null, null])) + # a_2: 4 MBs, total 20 MBs + a_2 = sb.let("a2", relay.Call(relu_op, [a_1])) + # a_3: 4 MBs, total 24 MBs + a_3 = sb.let("a3", relay.Call(add_op, [a_1, p_2, null, null])) + # Package the three tensors into a nested tuple and return + a_4 = sb.let("a4", relay.Tuple([a_2, a_3])) + a_5 = sb.let("a5", relay.Tuple([a_4, a_1])) + sb.ret(a_5) + func = relay.Function([p_0, p_1, p_2], sb.get()) + return tvm.IRModule.from_expr(func) + + m_p0, _ = randn(shape, device=device) + m_p1, _ = randn(shape, device=device) + m_p2, _ = randn(shape, device=device) + + # Set the memory budget to be higher than the peak + # The IR should remain unchanged after the remat pass + verify_remat(get_mod(), [m_p0, m_p1, m_p2], 28, get_mod()["main"], (24.00, 24.00)) + + def test_reshape(): device = "cpu" shape = (512, 512) # 1 MB. From 28b178eac230eb815dd70d2cf2b3c2381a466c41 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Sat, 16 Apr 2022 00:14:52 +0000 Subject: [PATCH 5/5] [Remat][Bugfix] Remove log printing. --- src/pass/rematerialization.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pass/rematerialization.cc b/src/pass/rematerialization.cc index 0be8bcd5..ab04a50a 100644 --- a/src/pass/rematerialization.cc +++ b/src/pass/rematerialization.cc @@ -896,7 +896,6 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor { /*! \brief Visit each let statement and return the analyzed information of each tensor. */ TensorInfos Run() { - LOG(INFO) << "Function: " << ir::AsText(func_); // Analyze parameters. for (const auto& var : func_->params) { auto liveness_vars = analyzer_->GetTensorVars(var);