Skip to content

Commit

Permalink
[Remat][Bugfix] Enhance the rematerialization pass to handle nested tuples (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouyuan1119 authored Apr 17, 2022
1 parent 13e787c commit 3f51010
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 16 deletions.
37 changes: 35 additions & 2 deletions src/pass/liveness_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,42 @@ Var LivenessAnalyzer::CreateTensorVar(const Type& type) {
return VarCreator(this).Run(type);
}

/*! \brief Calculate the byte compact size of the given type. If the type is a tuple,
/*!
* \brief Calculate the byte compact size of the given type. If the type is a tuple,
* then the size of each tensor in the tuple will be returned. Note that size 0 means
* a tensor with dynamic shape.
* a tensor with dynamic shape. If the input type contains nested tuples, the nested
* tuples are flattened and a flat vector of sizes is returned at the end. The order
* of the fields is preserved.
*/
std::vector<int64_t> CalcBytesCompactSizes(const Type& type) {
tvm::Array<Type> ty_stack;
std::vector<const TensorTypeNode*> ttypes;
ty_stack.push_back(type);
while (!ty_stack.empty()) {
auto ty = ty_stack.back();
ty_stack.pop_back();
if (auto tuple_ty_node = ty.as<TupleTypeNode>()) {
// If the current type corresponds to a tuple, process the type of each field later
for (auto it = tuple_ty_node->fields.rbegin(); it != tuple_ty_node->fields.rend(); it++)
ty_stack.push_back(*it);
} else if (auto tensor_ty_node = ty.as<TensorTypeNode>()) {
// Tensor types are added to the final list and sizes will be calculated
ttypes.push_back(tensor_ty_node);
} else {
// Other types are not supported
LOG(FATAL) << "Unsupported type: " << ty->GetTypeKey();
throw;
}
}

std::vector<int64_t> sizes;
for (auto ttype : ttypes) {
sizes.push_back(common::shape_utils::BytesCompactTensor(ttype));
}
return sizes;
}

/*
std::vector<int64_t> CalcBytesCompactSizes(const Type& type) {
std::vector<const TensorTypeNode*> ttypes;
std::vector<int64_t> sizes;
Expand All @@ -376,6 +408,7 @@ std::vector<int64_t> CalcBytesCompactSizes(const Type& type) {
}
return sizes;
}
*/

/*! \brief Dump liveness analysis result statistics. */
void DumpLivenessStat(const MapVSet& live_in) {
Expand Down
48 changes: 34 additions & 14 deletions src/pass/rematerialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,36 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor {
~TensorAnalyzer() {
}

/*!
 * \brief Get a list of liveness vars for the current let var. This function uses a DFS to handle
 * nested tuples. The returned array contains a flat list of vars. The relative order of tuple
 * fields is preserved.
 * \param curr_let The let var to resolve.
 * \return A flat list of liveness vars in original tuple-field order.
 */
tvm::Array<Var> GetLivenessVars(const Var& curr_let) {
  tvm::Array<Var> flat_vars;  // Flattened result, in original field order.
  tvm::Array<Var> worklist;   // DFS stack of vars still to be resolved.
  worklist.push_back(curr_let);
  while (!worklist.empty()) {
    Var var = worklist.back();
    worklist.pop_back();
    auto tensor_vars = analyzer_->GetTensorVars(var);
    CHECK_GT(tensor_vars.size(), 0U);
    if (tensor_vars.size() == 1 && !let_var_set_.count(tensor_vars[0])) {
      // A genuine liveness var: append it to the result.
      flat_vars.push_back(tensor_vars[0]);
    } else if (tensor_vars.size() == 1) {
      // The single "liveness var" is actually another let var (e.g. may_share),
      // so it must be resolved in a later iteration.
      worklist.push_back(tensor_vars[0]);
    } else {
      // A tuple: push the fields in reverse so they pop in their original order.
      for (auto it = tensor_vars.rbegin(); it != tensor_vars.rend(); ++it) {
        worklist.push_back(*it);
      }
    }
  }
  return flat_vars;
}

/*! \brief Visit each let statement and return the analyzed information of each tensor. */
TensorInfos Run() {
// Analyze parameters.
Expand All @@ -876,26 +906,14 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor {
const auto& exprs = ell_->exprs;
CHECK_EQ(vars.size(), exprs.size());

VSet let_var_set;
for (auto var : vars) {
let_var_set.insert(var);
let_var_set_.insert(var);
}

size_t n = exprs.size();
for (int i = 0; i < n; ++i) {
curr_let_ = vars[i];
auto liveness_vars = analyzer_->GetTensorVars(curr_let_);

// In the case of tuple with may_share, liveness vars may point to the real tensor.
for (size_t i = 0; i < liveness_vars.size(); ++i) {
auto real_liveness_var = liveness_vars[i];
while (let_var_set.find(real_liveness_var) != let_var_set.end()) {
auto cand_liveness_vars = analyzer_->GetTensorVars(real_liveness_var);
CHECK_EQ(cand_liveness_vars.size(), 1U);
real_liveness_var = cand_liveness_vars[0];
}
liveness_vars.Set(i, real_liveness_var);
}
auto liveness_vars = GetLivenessVars(curr_let_);

// Visit the expression to analyze the use count
ExprVisitor::VisitExpr(exprs[i]);
Expand Down Expand Up @@ -945,6 +963,8 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor {
TensorInfos tensor_infos_;
/*! \brief The profiler used in rematerialization. */
op_profiler::OpProfiler* profiler_;
/*! \brief A set of all let vars in the function. */
VSet let_var_set_;
};

TensorInfos Rematerializer::AnalyzeTensors(const Device& device, const Function& func,
Expand Down
41 changes: 41 additions & 0 deletions tests/python/pass/test_pass_rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,47 @@ def expected():
verify_remat(model, args, peak_size - 1, expected(), (peak_size, peak_size - 1))


def test_nested_tuple():
    """
    A simple test program to check whether the remat pass can handle nested
    tuples without crashing. No actual rematerialization is taking place.
    """
    device = "cpu"
    # Each tensor of this shape occupies 4 MBs.
    shape = (16, 16, 64, 64)

    def build_module():
        """Build a module whose return value is a nested tuple of tensors."""
        add_op = raf._ffi.op.GetOp("raf.op.add")
        relu_op = raf._ffi.op.GetOp("raf.op.relu")
        null = raf.ir.const(None)

        # Three 4-MB parameters: 12 MBs in total.
        p_0 = raf.ir.var("p0", shape=shape)
        p_1 = raf.ir.var("p1", shape=shape)
        p_2 = raf.ir.var("p2", shape=shape)

        builder = ScopeBuilder()
        # a_1: 4 MBs, running total 16 MBs.
        a_1 = builder.let("a1", relay.Call(add_op, [p_0, p_1, null, null]))
        # a_2: 4 MBs, running total 20 MBs.
        a_2 = builder.let("a2", relay.Call(relu_op, [a_1]))
        # a_3: 4 MBs, running total 24 MBs.
        a_3 = builder.let("a3", relay.Call(add_op, [a_1, p_2, null, null]))
        # Package the three tensors into a nested tuple and return it.
        a_4 = builder.let("a4", relay.Tuple([a_2, a_3]))
        a_5 = builder.let("a5", relay.Tuple([a_4, a_1]))
        builder.ret(a_5)
        func = relay.Function([p_0, p_1, p_2], builder.get())
        return tvm.IRModule.from_expr(func)

    args = [randn(shape, device=device)[0] for _ in range(3)]

    # The memory budget (28 MBs) is above the peak (24 MBs), so the remat
    # pass should leave the IR unchanged.
    verify_remat(build_module(), args, 28, build_module()["main"], (24.00, 24.00))


def test_reshape():
device = "cpu"
shape = (512, 512) # 1 MB.
Expand Down

0 comments on commit 3f51010

Please sign in to comment.