@@ -68,9 +68,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
68
68
69
69
private def applyInternal (p : LogicalPlan ): LogicalPlan = {
70
70
val inputPlans = p.children
71
+ val commonExprIdSet = p.expressions
72
+ .flatMap(_.collect { case r : CommonExpressionRef => r.id })
73
+ .groupBy(identity)
74
+ .transform((_, v) => v.size)
75
+ .filter(_._2 > 1 )
76
+ .keySet
71
77
val commonExprsPerChild = Array .fill(inputPlans.length)(mutable.ListBuffer .empty[(Alias , Long )])
72
78
var newPlan : LogicalPlan = p.mapExpressions { expr =>
73
- rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild)
79
+ rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild, commonExprIdSet )
74
80
}
75
81
val newChildren = inputPlans.zip(commonExprsPerChild).map { case (inputPlan, commonExprs) =>
76
82
if (commonExprs.isEmpty) {
@@ -96,16 +102,17 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
96
102
e : Expression ,
97
103
inputPlans : Seq [LogicalPlan ],
98
104
commonExprsPerChild : Array [mutable.ListBuffer [(Alias , Long )]],
105
+ commonExprIdSet : Set [CommonExpressionId ],
99
106
isNestedWith : Boolean = false ): Expression = {
100
107
if (! e.containsPattern(WITH_EXPRESSION )) return e
101
108
e match {
102
109
// Do not handle nested With in one pass. Leave it to the next rule executor batch.
103
110
case w : With if ! isNestedWith =>
104
111
// Rewrite nested With expressions first
105
112
val child = rewriteWithExprAndInputPlans(
106
- w.child, inputPlans, commonExprsPerChild, isNestedWith = true )
113
+ w.child, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true )
107
114
val defs = w.defs.map(rewriteWithExprAndInputPlans(
108
- _, inputPlans, commonExprsPerChild, isNestedWith = true ))
115
+ _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true ))
109
116
val refToExpr = mutable.HashMap .empty[CommonExpressionId , Expression ]
110
117
111
118
defs.zipWithIndex.foreach { case (CommonExpressionDef (child, id), index) =>
@@ -114,7 +121,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
114
121
" Cannot rewrite canonicalized Common expression definitions" )
115
122
}
116
123
117
- if (CollapseProject .isCheap(child)) {
124
+ if (CollapseProject .isCheap(child) || ! commonExprIdSet.contains(id) ) {
118
125
refToExpr(id) = child
119
126
} else {
120
127
val childPlanIndex = inputPlans.indexWhere(
@@ -171,7 +178,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
171
178
172
179
case c : ConditionalExpression =>
173
180
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
174
- rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith))
181
+ rewriteWithExprAndInputPlans(
182
+ _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith))
175
183
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
176
184
// Use transformUp to handle nested With.
177
185
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION )) {
@@ -185,7 +193,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
185
193
}
186
194
187
195
case other => other.mapChildren(
188
- rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith)
196
+ rewriteWithExprAndInputPlans(
197
+ _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith)
189
198
)
190
199
}
191
200
}
0 commit comments