diff --git a/src/pass/liveness_analysis.cc b/src/pass/liveness_analysis.cc
index 18badbef..b6032eea 100644
--- a/src/pass/liveness_analysis.cc
+++ b/src/pass/liveness_analysis.cc
@@ -351,10 +351,42 @@ Var LivenessAnalyzer::CreateTensorVar(const Type& type) {
   return VarCreator(this).Run(type);
 }
 
-/*! \brief Calculate the byte compact size of the given type. If the type is a tuple,
+/*!
+ * \brief Calculate the byte compact size of the given type. If the type is a tuple,
  * then the size of each tensor in the tuple will be returned. Note that size 0 means
- * a tensor with dynamic shape.
+ * a tensor with dynamic shape. If the input type contains nested tuples, the nested
+ * tuples are flattened and a flat vector of sizes is returned at the end. The order
+ * of the fields is preserved.
  */
+std::vector<int64_t> CalcBytesCompactSizes(const Type& type) {
+  tvm::Array<Type> ty_stack;
+  std::vector<const TensorTypeNode*> ttypes;
+  ty_stack.push_back(type);
+  while (!ty_stack.empty()) {
+    auto ty = ty_stack.back();
+    ty_stack.pop_back();
+    if (auto tuple_ty_node = ty.as<TupleTypeNode>()) {
+      // If the current type corresponds to a tuple, process the type of each field later
+      for (auto it = tuple_ty_node->fields.rbegin(); it != tuple_ty_node->fields.rend(); it++)
+        ty_stack.push_back(*it);
+    } else if (auto tensor_ty_node = ty.as<TensorTypeNode>()) {
+      // Tensor types are added to the final list and sizes will be calculated
+      ttypes.push_back(tensor_ty_node);
+    } else {
+      // Other types are not supported
+      LOG(FATAL) << "Unsupported type: " << ty->GetTypeKey();
+      throw;
+    }
+  }
+
+  std::vector<int64_t> sizes;
+  for (auto ttype : ttypes) {
+    sizes.push_back(common::shape_utils::BytesCompactTensor(ttype));
+  }
+  return sizes;
+}
+
+/*
 std::vector<int64_t> CalcBytesCompactSizes(const Type& type) {
   std::vector<const TensorTypeNode*> ttypes;
   std::vector<int64_t> sizes;
@@ -376,6 +408,7 @@ std::vector<int64_t> CalcBytesCompactSizes(const Type& type) {
   }
   return sizes;
 }
+*/
 
 /*! \brief Dump liveness analysis result statistics. */
 void DumpLivenessStat(const MapVSet& live_in) {
diff --git a/src/pass/rematerialization.cc b/src/pass/rematerialization.cc
index bc93b1aa..ab04a50a 100644
--- a/src/pass/rematerialization.cc
+++ b/src/pass/rematerialization.cc
@@ -864,6 +864,36 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor {
   ~TensorAnalyzer() {
   }
 
+  /*!
+   * \brief Get a list of liveness vars for the current let var. This function uses a DFS to handle
+   * nested tuples. The returned array contains a flat list of vars. The relative order of tuple
+   * fields is preserved.
+   */
+  tvm::Array<Var> GetLivenessVars(const Var& curr_let) {
+    tvm::Array<Var> result_vars;
+    tvm::Array<Var> var_stack;
+    var_stack.push_back(curr_let);
+    while (!var_stack.empty()) {
+      auto v = var_stack.back();
+      var_stack.pop_back();
+      auto liveness_vars = analyzer_->GetTensorVars(v);
+      CHECK_GT(liveness_vars.size(), 0U);
+      if (liveness_vars.size() > 1) {
+        // If the current let var corresponds to a tuple, the tuple fields should be processed later
+        for (auto it = liveness_vars.rbegin(); it != liveness_vars.rend(); it++)
+          var_stack.push_back(*it);
+      } else if (let_var_set_.count(liveness_vars[0])) {
+        // If this "liveness var" points to a real var rather than an actual liveness var
+        // it should also be processed later
+        var_stack.push_back(liveness_vars[0]);
+      } else {
+        // Otherwise, add this var to the result
+        result_vars.push_back(liveness_vars[0]);
+      }
+    }
+    return result_vars;
+  }
+
   /*! \brief Visit each let statement and return the analyzed information of each tensor. */
   TensorInfos Run() {
     // Analyze parameters.
@@ -876,26 +906,14 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor {
     const auto& exprs = ell_->exprs;
     CHECK_EQ(vars.size(), exprs.size());
 
-    VSet let_var_set;
     for (auto var : vars) {
-      let_var_set.insert(var);
+      let_var_set_.insert(var);
     }
 
     size_t n = exprs.size();
    for (int i = 0; i < n; ++i) {
       curr_let_ = vars[i];
-      auto liveness_vars = analyzer_->GetTensorVars(curr_let_);
-
-      // In the case of tuple with may_share, liveness vars may point to the real tensor.
-      for (size_t i = 0; i < liveness_vars.size(); ++i) {
-        auto real_liveness_var = liveness_vars[i];
-        while (let_var_set.find(real_liveness_var) != let_var_set.end()) {
-          auto cand_liveness_vars = analyzer_->GetTensorVars(real_liveness_var);
-          CHECK_EQ(cand_liveness_vars.size(), 1U);
-          real_liveness_var = cand_liveness_vars[0];
-        }
-        liveness_vars.Set(i, real_liveness_var);
-      }
+      auto liveness_vars = GetLivenessVars(curr_let_);
 
       // Visit the expression to analyze the use count
       ExprVisitor::VisitExpr(exprs[i]);
@@ -945,6 +963,8 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor {
   TensorInfos tensor_infos_;
   /*! \brief The profiler used in rematerialization. */
   op_profiler::OpProfiler* profiler_;
+  /*! \brief A set of all let vars in the function. */
+  VSet let_var_set_;
 };
 
 TensorInfos Rematerializer::AnalyzeTensors(const Device& device, const Function& func,
diff --git a/tests/python/pass/test_pass_rematerialization.py b/tests/python/pass/test_pass_rematerialization.py
index 3529db78..86ffe35f 100644
--- a/tests/python/pass/test_pass_rematerialization.py
+++ b/tests/python/pass/test_pass_rematerialization.py
@@ -479,6 +479,47 @@ def expected():
     verify_remat(model, args, peak_size - 1, expected(), (peak_size, peak_size - 1))
 
 
+def test_nested_tuple():
+    """
+    A simple test program to check whether the remat pass can handle nested
+    tuples without crashing. No actual rematerialization is taking place.
+    """
+    device = "cpu"
+    shape = (16, 16, 64, 64)  # 4 MBs
+
+    def get_mod():
+        add_op = raf._ffi.op.GetOp("raf.op.add")
+        relu_op = raf._ffi.op.GetOp("raf.op.relu")
+        null = raf.ir.const(None)
+
+        # param: 12 MBs
+        p_0 = raf.ir.var("p0", shape=shape)
+        p_1 = raf.ir.var("p1", shape=shape)
+        p_2 = raf.ir.var("p2", shape=shape)
+
+        sb = ScopeBuilder()
+        # a_1: 4 MBs, total 16 MBs
+        a_1 = sb.let("a1", relay.Call(add_op, [p_0, p_1, null, null]))
+        # a_2: 4 MBs, total 20 MBs
+        a_2 = sb.let("a2", relay.Call(relu_op, [a_1]))
+        # a_3: 4 MBs, total 24 MBs
+        a_3 = sb.let("a3", relay.Call(add_op, [a_1, p_2, null, null]))
+        # Package the three tensors into a nested tuple and return
+        a_4 = sb.let("a4", relay.Tuple([a_2, a_3]))
+        a_5 = sb.let("a5", relay.Tuple([a_4, a_1]))
+        sb.ret(a_5)
+        func = relay.Function([p_0, p_1, p_2], sb.get())
+        return tvm.IRModule.from_expr(func)
+
+    m_p0, _ = randn(shape, device=device)
+    m_p1, _ = randn(shape, device=device)
+    m_p2, _ = randn(shape, device=device)
+
+    # Set the memory budget to be higher than the peak
+    # The IR should remain unchanged after the remat pass
+    verify_remat(get_mod(), [m_p0, m_p1, m_p2], 28, get_mod()["main"], (24.00, 24.00))
+
+
 def test_reshape():
     device = "cpu"
     shape = (512, 512)  # 1 MB.
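
Note (illustration only, not part of the patch): both new helpers, CalcBytesCompactSizes and GetLivenessVars, rely on the same explicit-stack DFS to flatten nested tuples while preserving field order: tuple fields are pushed in reverse so that they are popped left to right. Below is a minimal Python sketch of that traversal; the flatten_nested name is hypothetical and only mirrors the C++ logic.

    def flatten_nested(value):
        """Flatten nested tuples/lists into a flat list, preserving field order.

        Sketch of the traversal used by CalcBytesCompactSizes / GetLivenessVars.
        """
        result = []
        stack = [value]
        while stack:
            item = stack.pop()
            if isinstance(item, (tuple, list)):
                # Push fields in reverse so the leftmost field is popped (visited) first.
                for field in reversed(item):
                    stack.append(field)
            else:
                result.append(item)
        return result

    # The nested tuple built in test_nested_tuple, ((a2, a3), a1), flattens to [a2, a3, a1].
    assert flatten_nested((("a2", "a3"), "a1")) == ["a2", "a3", "a1"]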