@@ -362,8 +362,13 @@ using namespace tir;
362362// We might use better set analysis in the future to replace the intervalset.
363363class IntervalSetEvaluator : public ExprFunctor <IntervalSet(const PrimExpr&)> {
364364 public:
365- IntervalSetEvaluator (Analyzer* analyzer, const Map<Var, IntSet>& dom_map, bool eval_vec = false )
366- : analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {}
365+ IntervalSetEvaluator (Analyzer* analyzer, const Map<Var, IntSet>& dom_map,
366+ const std::vector<std::pair<Var, IntSet>>* dom_constraints = nullptr ,
367+ bool eval_vec = false )
368+ : analyzer_(analyzer),
369+ dom_map_ (dom_map),
370+ dom_constraints_(dom_constraints),
371+ eval_vec_(eval_vec) {}
367372
368373 IntervalSet Eval (const PrimExpr& val) { return this ->VisitExpr (val); }
369374 // evaluate and relax the set
@@ -383,18 +388,40 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
383388
384389 IntervalSet VisitExpr_ (const VarNode* op) final {
385390 Var var = GetRef<Var>(op);
391+
392+ Array<IntSet> values;
393+ if (dom_constraints_) {
394+ for (const auto & constraint : *dom_constraints_) {
395+ if (var.same_as (constraint.first )) {
396+ values.push_back (constraint.second );
397+ }
398+ }
399+ }
400+
386401 auto it = dom_map_.find (var);
387402 if (it != dom_map_.end ()) {
388- IntervalSet res = ToIntervalSet ((*it).second );
389- if (res->min_value .same_as (var) && res->max_value .same_as (var)) {
390- return res;
391- }
392- // recursively evaluate mapped result
393- // in case the domain contains variables to be relaxed.
394- return Eval (res);
395- } else {
403+ values.push_back ((*it).second );
404+ }
405+
406+ if (values.empty ()) {
396407 return IntervalSet::SinglePoint (var);
397408 }
409+
410+ IntSet intersection = [&]() {
411+ if (values.size () == 1 ) {
412+ return values.front ();
413+ } else {
414+ return Intersect (values);
415+ }
416+ }();
417+
418+ IntervalSet res = ToIntervalSet (intersection);
419+ if (res->min_value .same_as (var) && res->max_value .same_as (var)) {
420+ return res;
421+ }
422+ // recursively evaluate mapped result
423+ // in case the domain contains variables to be relaxed.
424+ return Eval (res);
398425 }
399426
400427 IntervalSet VisitExpr_ (const AddNode* op) final { return VisitBinaryExpr_<Add>(op); }
@@ -517,6 +544,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
517544 // analyzer
518545 Analyzer* analyzer_;
519546 const Map<Var, IntSet>& dom_map_;
547+ const std::vector<std::pair<Var, IntSet>>* dom_constraints_;
520548 bool eval_vec_{false };
521549};
522550
@@ -529,7 +557,7 @@ class IntSetAnalyzer::Impl {
529557 }
530558
531559 IntSet Eval (const PrimExpr& expr) const {
532- return IntervalSetEvaluator (analyzer_, GetCurrentBounds () , true ).Eval (expr);
560+ return IntervalSetEvaluator (analyzer_, dom_map_, &dom_constraints_ , true ).Eval (expr);
533561 }
534562
535563 void Bind (const Var& var, const Range& range, bool allow_override) {
@@ -541,10 +569,6 @@ class IntSetAnalyzer::Impl {
541569 std::function<void ()> EnterConstraint (const PrimExpr& constraint);
542570
543571 private:
544- // Get the current variable bounds, including both global bounds and
545- // scope-dependent bounds.
546- Map<Var, IntSet> GetCurrentBounds () const ;
547-
548572 // Utility function to split a boolean condition into the domain
549573 // bounds implied by that condition.
550574 static std::vector<std::pair<Var, IntSet>> DetectBoundInfo (const PrimExpr& cond);
@@ -556,9 +580,11 @@ class IntSetAnalyzer::Impl {
556580 // ranges)
557581 Map<Var, IntSet> dom_map_;
558582
559- // Map of variables to implicit scope-dependent bounds (e.g. inside
560- // the body of an if-statement)
561- Map<Var, IntSet> constraints_;
583+ // List of implicit scope-dependent bounds (e.g. inside the body of
584+ // an if-statement). Maintained as a list of constraints, rather
585+ // than as a `Map<Var,IntSet>`, to avoid computing an Intersection
586+ // until required.
587+ std::vector<std::pair<Var, IntSet>> dom_constraints_;
562588};
563589
564590IntSetAnalyzer::IntSetAnalyzer (Analyzer* parent) : impl_(new Impl(parent)) {}
@@ -603,29 +629,6 @@ void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_o
603629 Update (var, Eval (expr), can_override);
604630}
605631
606- Map<Var, IntSet> IntSetAnalyzer::Impl::GetCurrentBounds () const {
607- // If either constraints_ or dom_map_ is empty, return the other to
608- // avoid constructing a new map.
609- if (constraints_.empty ()) {
610- return dom_map_;
611- } else if (dom_map_.empty ()) {
612- return constraints_;
613- }
614-
615- // If neither is empty, construct a merged domain map with
616- // information from both sources.
617- Map<Var, IntSet> merged = dom_map_;
618- for (const auto & pair : constraints_) {
619- auto it = merged.find (pair.first );
620- if (it == merged.end ()) {
621- merged.Set (pair.first , pair.second );
622- } else {
623- merged.Set (pair.first , Intersect ({pair.second , (*it).second }));
624- }
625- }
626- return merged;
627- }
628-
629632std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo (
630633 const PrimExpr& constraint) {
631634 PVar<Var> x;
@@ -665,41 +668,16 @@ std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint
665668}
666669
667670std::function<void ()> IntSetAnalyzer::Impl::EnterConstraint (const PrimExpr& constraint) {
668- Map<Var, IntSet> cached_values;
669-
670671 auto bounds = DetectBoundInfo (constraint);
671672
672673 if (bounds.size () == 0 ) return nullptr ;
673674
674- // Collect the current values of each var that is changes by this
675- // constraint.
676- for (const auto & pair : bounds) {
677- auto it = constraints_.find (pair.first );
678- if (it == constraints_.end ()) {
679- cached_values.Set (pair.first , IntSet ());
680- } else {
681- cached_values.Set (pair.first , (*it).second );
682- }
683- }
684-
685- // Update all constraints
686- for (const auto & pair : bounds) {
687- auto it = constraints_.find (pair.first );
688- if (it == constraints_.end ()) {
689- constraints_.Set (pair.first , pair.second );
690- } else {
691- constraints_.Set (pair.first , Intersect ({pair.second , (*it).second }));
692- }
693- }
694-
695- auto frecover = [cached_values, this ]() {
696- for (const auto & it : cached_values) {
697- if (it.second .defined ()) {
698- constraints_.Set (it.first , it.second );
699- } else {
700- constraints_.erase (it.first );
701- }
702- }
675+ size_t old_size = dom_constraints_.size ();
676+ dom_constraints_.insert (dom_constraints_.end (), bounds.begin (), bounds.end ());
677+ size_t new_size = dom_constraints_.size ();
678+ auto frecover = [old_size, new_size, this ]() {
679+ ICHECK_EQ (dom_constraints_.size (), new_size);
680+ dom_constraints_.resize (old_size);
703681 };
704682 return frecover;
705683}
@@ -960,13 +938,13 @@ Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>&
960938
961939IntSet EvalSet (PrimExpr e, const Map<Var, IntSet>& dom_map) {
962940 Analyzer ana;
963- return IntervalSetEvaluator (&ana, dom_map, false ).Eval (e);
941+ return IntervalSetEvaluator (&ana, dom_map, {}, false ).Eval (e);
964942}
965943
966944IntSet IntSet::Vector (PrimExpr x) {
967945 Analyzer ana;
968946 Map<Var, IntSet> dmap;
969- return IntervalSetEvaluator (&ana, dmap, true ).Eval (x);
947+ return IntervalSetEvaluator (&ana, dmap, {}, true ).Eval (x);
970948}
971949
972950IntSet EvalSet (PrimExpr e, const Map<IterVar, IntSet>& dom_map) {
0 commit comments