Skip to content

Commit e30ac71

Browse files
authored
[Arith][TIR] IntSetAnalyzer, delay intersection of IntSet until use (#12821)
Follow-up from apache/tvm#11970, to improve performance. In the initial implementation, the `analyzer->int_set` would compute the intersection of all scope-based constraints when entering the scope, even if they weren't actually used. This commit delays the call to `Intersect` until required, following the same behavior as `ConstIntBound`.
1 parent 2af9b90 commit e30ac71

File tree

1 file changed

+52
-74
lines changed

1 file changed

+52
-74
lines changed

src/arith/int_set.cc

Lines changed: 52 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,13 @@ using namespace tir;
362362
// We might use better set analysis in the future to replace the intervalset.
363363
class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
364364
public:
365-
IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map, bool eval_vec = false)
366-
: analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {}
365+
IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map,
366+
const std::vector<std::pair<Var, IntSet>>* dom_constraints = nullptr,
367+
bool eval_vec = false)
368+
: analyzer_(analyzer),
369+
dom_map_(dom_map),
370+
dom_constraints_(dom_constraints),
371+
eval_vec_(eval_vec) {}
367372

368373
IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); }
369374
// evaluate and relax the set
@@ -383,18 +388,40 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
383388

384389
IntervalSet VisitExpr_(const VarNode* op) final {
385390
Var var = GetRef<Var>(op);
391+
392+
Array<IntSet> values;
393+
if (dom_constraints_) {
394+
for (const auto& constraint : *dom_constraints_) {
395+
if (var.same_as(constraint.first)) {
396+
values.push_back(constraint.second);
397+
}
398+
}
399+
}
400+
386401
auto it = dom_map_.find(var);
387402
if (it != dom_map_.end()) {
388-
IntervalSet res = ToIntervalSet((*it).second);
389-
if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
390-
return res;
391-
}
392-
// recursively evaluate mapped result
393-
// in case the domain contains variables to be relaxed.
394-
return Eval(res);
395-
} else {
403+
values.push_back((*it).second);
404+
}
405+
406+
if (values.empty()) {
396407
return IntervalSet::SinglePoint(var);
397408
}
409+
410+
IntSet intersection = [&]() {
411+
if (values.size() == 1) {
412+
return values.front();
413+
} else {
414+
return Intersect(values);
415+
}
416+
}();
417+
418+
IntervalSet res = ToIntervalSet(intersection);
419+
if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
420+
return res;
421+
}
422+
// recursively evaluate mapped result
423+
// in case the domain contains variables to be relaxed.
424+
return Eval(res);
398425
}
399426

400427
IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_<Add>(op); }
@@ -517,6 +544,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
517544
// analyzer
518545
Analyzer* analyzer_;
519546
const Map<Var, IntSet>& dom_map_;
547+
const std::vector<std::pair<Var, IntSet>>* dom_constraints_;
520548
bool eval_vec_{false};
521549
};
522550

@@ -529,7 +557,7 @@ class IntSetAnalyzer::Impl {
529557
}
530558

531559
IntSet Eval(const PrimExpr& expr) const {
532-
return IntervalSetEvaluator(analyzer_, GetCurrentBounds(), true).Eval(expr);
560+
return IntervalSetEvaluator(analyzer_, dom_map_, &dom_constraints_, true).Eval(expr);
533561
}
534562

535563
void Bind(const Var& var, const Range& range, bool allow_override) {
@@ -541,10 +569,6 @@ class IntSetAnalyzer::Impl {
541569
std::function<void()> EnterConstraint(const PrimExpr& constraint);
542570

543571
private:
544-
// Get the current variable bounds, including both global bounds and
545-
// scope-dependent bounds.
546-
Map<Var, IntSet> GetCurrentBounds() const;
547-
548572
// Utility function to split a boolean condition into the domain
549573
// bounds implied by that condition.
550574
static std::vector<std::pair<Var, IntSet>> DetectBoundInfo(const PrimExpr& cond);
@@ -556,9 +580,11 @@ class IntSetAnalyzer::Impl {
556580
// ranges)
557581
Map<Var, IntSet> dom_map_;
558582

559-
// Map of variables to implicit scope-dependent bounds (e.g. inside
560-
// the body of an if-statement)
561-
Map<Var, IntSet> constraints_;
583+
// List of implicit scope-dependent bounds (e.g. inside the body of
584+
// an if-statement). Maintained as a list of constraints, rather
585+
// than as a `Map<Var,IntSet>`, to avoid computing an Intersection
586+
// until required.
587+
std::vector<std::pair<Var, IntSet>> dom_constraints_;
562588
};
563589

564590
IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}
@@ -603,29 +629,6 @@ void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_o
603629
Update(var, Eval(expr), can_override);
604630
}
605631

606-
Map<Var, IntSet> IntSetAnalyzer::Impl::GetCurrentBounds() const {
607-
// If either constraints_ or dom_map_ is empty, return the other to
608-
// avoid constructing a new map.
609-
if (constraints_.empty()) {
610-
return dom_map_;
611-
} else if (dom_map_.empty()) {
612-
return constraints_;
613-
}
614-
615-
// If neither is empty, construct a merged domain map with
616-
// information from both sources.
617-
Map<Var, IntSet> merged = dom_map_;
618-
for (const auto& pair : constraints_) {
619-
auto it = merged.find(pair.first);
620-
if (it == merged.end()) {
621-
merged.Set(pair.first, pair.second);
622-
} else {
623-
merged.Set(pair.first, Intersect({pair.second, (*it).second}));
624-
}
625-
}
626-
return merged;
627-
}
628-
629632
std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo(
630633
const PrimExpr& constraint) {
631634
PVar<Var> x;
@@ -665,41 +668,16 @@ std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint
665668
}
666669

667670
std::function<void()> IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) {
668-
Map<Var, IntSet> cached_values;
669-
670671
auto bounds = DetectBoundInfo(constraint);
671672

672673
if (bounds.size() == 0) return nullptr;
673674

674-
// Collect the current values of each var that is changes by this
675-
// constraint.
676-
for (const auto& pair : bounds) {
677-
auto it = constraints_.find(pair.first);
678-
if (it == constraints_.end()) {
679-
cached_values.Set(pair.first, IntSet());
680-
} else {
681-
cached_values.Set(pair.first, (*it).second);
682-
}
683-
}
684-
685-
// Update all constraints
686-
for (const auto& pair : bounds) {
687-
auto it = constraints_.find(pair.first);
688-
if (it == constraints_.end()) {
689-
constraints_.Set(pair.first, pair.second);
690-
} else {
691-
constraints_.Set(pair.first, Intersect({pair.second, (*it).second}));
692-
}
693-
}
694-
695-
auto frecover = [cached_values, this]() {
696-
for (const auto& it : cached_values) {
697-
if (it.second.defined()) {
698-
constraints_.Set(it.first, it.second);
699-
} else {
700-
constraints_.erase(it.first);
701-
}
702-
}
675+
size_t old_size = dom_constraints_.size();
676+
dom_constraints_.insert(dom_constraints_.end(), bounds.begin(), bounds.end());
677+
size_t new_size = dom_constraints_.size();
678+
auto frecover = [old_size, new_size, this]() {
679+
ICHECK_EQ(dom_constraints_.size(), new_size);
680+
dom_constraints_.resize(old_size);
703681
};
704682
return frecover;
705683
}
@@ -960,13 +938,13 @@ Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>&
960938

961939
IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) {
962940
Analyzer ana;
963-
return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
941+
return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e);
964942
}
965943

966944
IntSet IntSet::Vector(PrimExpr x) {
967945
Analyzer ana;
968946
Map<Var, IntSet> dmap;
969-
return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
947+
return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
970948
}
971949

972950
IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) {

0 commit comments

Comments
 (0)