[Remat][Bugfix] Enhance the rematerialization pass to handle nested tuples #17

Merged: 5 commits, Apr 17, 2022
Changes from 3 commits
37 changes: 35 additions & 2 deletions src/pass/liveness_analysis.cc
@@ -351,10 +351,42 @@ Var LivenessAnalyzer::CreateTensorVar(const Type& type) {
return VarCreator(this).Run(type);
}

-/*! \brief Calculate the byte compact size of the given type. If the type is a tuple,
+/*!
+ * \brief Calculate the byte compact size of the given type. If the type is a tuple,
 * then the size of each tensor in the tuple will be returned. Note that size 0 means
- * a tensor with dynamic shape.
+ * a tensor with dynamic shape. If the input type contains nested tuples, the nested
+ * tuples are flattened and a flat vector of sizes is returned at the end. The order
+ * of the fields is preserved.
*/
std::vector<int64_t> CalcBytesCompactSizes(const Type& type) {
tvm::Array<Type> ty_stack;
std::vector<const TensorTypeNode*> ttypes;
ty_stack.push_back(type);
while (!ty_stack.empty()) {
auto ty = ty_stack.back();
ty_stack.pop_back();
if (auto tuple_ty_node = ty.as<TupleTypeNode>()) {
// If the current type corresponds to a tuple, process the type of each field later
for (auto it = tuple_ty_node->fields.rbegin(); it != tuple_ty_node->fields.rend(); it++)
ty_stack.push_back(*it);
} else if (auto tensor_ty_node = ty.as<TensorTypeNode>()) {
// Tensor types are added to the final list and sizes will be calculated
ttypes.push_back(tensor_ty_node);
} else {
// Other types are not supported
LOG(FATAL) << "Unsupported type: " << ty->GetTypeKey();
throw;
}
}

std::vector<int64_t> sizes;
for (auto ttype : ttypes) {
sizes.push_back(common::shape_utils::BytesCompactTensor(ttype));
}
return sizes;
}

/*
std::vector<int64_t> CalcBytesCompactSizes(const Type& type) {
std::vector<const TensorTypeNode*> ttypes;
std::vector<int64_t> sizes;
@@ -376,6 +408,7 @@ std::vector<int64_t> CalcBytesCompactSizes(const Type& type) {
}
return sizes;
}
*/

/*! \brief Dump liveness analysis result statistics. */
void DumpLivenessStat(const MapVSet& live_in) {
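To make the flattening behavior of the new CalcBytesCompactSizes concrete, here is a minimal standalone sketch of the same traversal over a toy type representation. ToyType and FlattenSizes are invented names for illustration only; they are not part of this patch or of the TVM type API. Given a nested tuple such as ((8-byte tensor, 16-byte tensor), 4-byte tensor), the traversal yields the flat size list 8, 16, 4 with field order preserved, mirroring the doc comment above.

// Toy sketch of the flattening traversal, assuming a made-up type tree:
// a "type" is either a leaf with a known byte size or a tuple of child types.
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

struct ToyType {
  int64_t bytes;                                 // meaningful only for leaf (tensor) types
  std::vector<std::shared_ptr<ToyType>> fields;  // non-empty for tuple types
};

std::vector<int64_t> FlattenSizes(const std::shared_ptr<ToyType>& type) {
  std::vector<int64_t> sizes;
  std::vector<std::shared_ptr<ToyType>> stack{type};
  while (!stack.empty()) {
    auto ty = stack.back();
    stack.pop_back();
    if (!ty->fields.empty()) {
      // Push fields in reverse so they are popped, and emitted, in declaration order.
      for (auto it = ty->fields.rbegin(); it != ty->fields.rend(); ++it) stack.push_back(*it);
    } else {
      sizes.push_back(ty->bytes);
    }
  }
  return sizes;
}

int main() {
  auto leaf = [](int64_t b) { return std::make_shared<ToyType>(ToyType{b, {}}); };
  auto tuple = [](std::vector<std::shared_ptr<ToyType>> f) {
    return std::make_shared<ToyType>(ToyType{0, std::move(f)});
  };
  // ((8-byte, 16-byte), 4-byte) flattens to "8 16 4": nested tuples are
  // flattened and the relative order of the fields is preserved.
  auto nested = tuple({tuple({leaf(8), leaf(16)}), leaf(4)});
  for (int64_t s : FlattenSizes(nested)) std::cout << s << " ";
  std::cout << "\n";
  return 0;
}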
49 changes: 35 additions & 14 deletions src/pass/rematerialization.cc
@@ -864,8 +864,39 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor {
~TensorAnalyzer() {
}

/*!
* \brief Get a list of liveness vars for the current let var. This function uses a DFS to handle
* nested tuples. The returned array contains a flat list of vars. The relative order of tuple
* fields is preserved.
*/
tvm::Array<Var> GetLivenessVars(const Var& curr_let) {
tvm::Array<Var> result_vars;
tvm::Array<Var> var_stack;
var_stack.push_back(curr_let);
while (!var_stack.empty()) {
auto v = var_stack.back();
var_stack.pop_back();
auto liveness_vars = analyzer_->GetTensorVars(v);
CHECK_GT(liveness_vars.size(), 0U);
if (liveness_vars.size() > 1) {
// If the current var corresponds to a tuple, its fields should be processed later
for (auto it = liveness_vars.rbegin(); it != liveness_vars.rend(); it++)
var_stack.push_back(*it);
} else if (let_var_set_.count(liveness_vars[0])) {
// If this "liveness var" points to another let var rather than an actual liveness var,
// it should also be resolved further before being added to the result
var_stack.push_back(liveness_vars[0]);
} else {
// Otherwise, add this var to the result
result_vars.push_back(liveness_vars[0]);
}
}
return result_vars;
}

/*! \brief Visit each let statement and return the analyzed information of each tensor. */
TensorInfos Run() {
LOG(INFO) << "Function: " << ir::AsText(func_);
// Analyze parameters.
for (const auto& var : func_->params) {
auto liveness_vars = analyzer_->GetTensorVars(var);
@@ -876,26 +907,14 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor {
const auto& exprs = ell_->exprs;
CHECK_EQ(vars.size(), exprs.size());

-VSet let_var_set;
for (auto var : vars) {
-let_var_set.insert(var);
+let_var_set_.insert(var);
}

size_t n = exprs.size();
for (int i = 0; i < n; ++i) {
curr_let_ = vars[i];
-auto liveness_vars = analyzer_->GetTensorVars(curr_let_);
-
-// In the case of tuple with may_share, liveness vars may point to the real tensor.
-for (size_t i = 0; i < liveness_vars.size(); ++i) {
-auto real_liveness_var = liveness_vars[i];
-while (let_var_set.find(real_liveness_var) != let_var_set.end()) {
-auto cand_liveness_vars = analyzer_->GetTensorVars(real_liveness_var);
-CHECK_EQ(cand_liveness_vars.size(), 1U);
-real_liveness_var = cand_liveness_vars[0];
-}
-liveness_vars.Set(i, real_liveness_var);
-}
+auto liveness_vars = GetLivenessVars(curr_let_);

// Visit the expression to analyze the use count
ExprVisitor::VisitExpr(exprs[i]);
@@ -945,6 +964,8 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor {
TensorInfos tensor_infos_;
/*! \brief The profiler used in rematerialization. */
op_profiler::OpProfiler* profiler_;
/*! \brief A set of all let vars in the function. */
VSet let_var_set_;
};

TensorInfos Rematerializer::AnalyzeTensors(const Device& device, const Function& func,
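The doc comment on the new GetLivenessVars describes a DFS that flattens tuple fields and resolves liveness vars that still point at other let vars (the may_share case handled by the old loop). Below is a standalone toy sketch of that traversal, using plain strings as vars and a made-up table in place of analyzer_->GetTensorVars; all names here are hypothetical and not part of the patch.

// Toy sketch of the GetLivenessVars traversal. An entry with several elements
// models a tuple; an entry whose single element is another let var models the
// may_share indirection; anything else is a leaf liveness var.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using Var = std::string;

std::vector<Var> GetLivenessVarsSketch(const Var& curr_let,
                                       const std::map<Var, std::vector<Var>>& tensor_vars,
                                       const std::set<Var>& let_var_set) {
  std::vector<Var> result;
  std::vector<Var> stack{curr_let};
  while (!stack.empty()) {
    Var v = stack.back();
    stack.pop_back();
    const auto& vars = tensor_vars.at(v);
    if (vars.size() > 1) {
      // Tuple: revisit each field later, pushed in reverse to keep the original order.
      for (auto it = vars.rbegin(); it != vars.rend(); ++it) stack.push_back(*it);
    } else if (let_var_set.count(vars[0])) {
      // Indirection to another let var: resolve it before emitting anything.
      stack.push_back(vars[0]);
    } else {
      result.push_back(vars[0]);
    }
  }
  return result;
}

int main() {
  // %out is a tuple whose fields map to %a, %b, %c; %b shares storage with
  // the let var %x (may_share), which in turn owns the liveness var t1.
  std::map<Var, std::vector<Var>> tensor_vars = {
      {"%out", {"%a", "%b", "%c"}}, {"%a", {"t0"}}, {"%b", {"%x"}},
      {"%c", {"t2"}},               {"%x", {"t1"}}};
  std::set<Var> let_var_set = {"%out", "%a", "%b", "%c", "%x"};
  for (const auto& v : GetLivenessVarsSketch("%out", tensor_vars, let_var_set))
    std::cout << v << " ";  // prints: t0 t1 t2
  std::cout << "\n";
  return 0;
}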