Skip to content

Commit 492fcd8

Browse files
zml1206 authored and cloud-fan committed
[SPARK-50683][SQL] Inline the common expression in With if used once
### What changes were proposed in this pull request?
As title.

### Why are the changes needed?
Simplify plan and reduce unnecessary project.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
UT.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49310 from zml1206/with.

Authored-by: zml1206 <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 721a417 commit 492fcd8

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala

+15-6
Original file line number | Diff line number | Diff line change
@@ -68,9 +68,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
6868

6969
private def applyInternal(p: LogicalPlan): LogicalPlan = {
7070
val inputPlans = p.children
71+
val commonExprIdSet = p.expressions
72+
.flatMap(_.collect { case r: CommonExpressionRef => r.id })
73+
.groupBy(identity)
74+
.transform((_, v) => v.size)
75+
.filter(_._2 > 1)
76+
.keySet
7177
val commonExprsPerChild = Array.fill(inputPlans.length)(mutable.ListBuffer.empty[(Alias, Long)])
7278
var newPlan: LogicalPlan = p.mapExpressions { expr =>
73-
rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild)
79+
rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild, commonExprIdSet)
7480
}
7581
val newChildren = inputPlans.zip(commonExprsPerChild).map { case (inputPlan, commonExprs) =>
7682
if (commonExprs.isEmpty) {
@@ -96,16 +102,17 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
96102
e: Expression,
97103
inputPlans: Seq[LogicalPlan],
98104
commonExprsPerChild: Array[mutable.ListBuffer[(Alias, Long)]],
105+
commonExprIdSet: Set[CommonExpressionId],
99106
isNestedWith: Boolean = false): Expression = {
100107
if (!e.containsPattern(WITH_EXPRESSION)) return e
101108
e match {
102109
// Do not handle nested With in one pass. Leave it to the next rule executor batch.
103110
case w: With if !isNestedWith =>
104111
// Rewrite nested With expressions first
105112
val child = rewriteWithExprAndInputPlans(
106-
w.child, inputPlans, commonExprsPerChild, isNestedWith = true)
113+
w.child, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true)
107114
val defs = w.defs.map(rewriteWithExprAndInputPlans(
108-
_, inputPlans, commonExprsPerChild, isNestedWith = true))
115+
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true))
109116
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
110117

111118
defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
@@ -114,7 +121,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
114121
"Cannot rewrite canonicalized Common expression definitions")
115122
}
116123

117-
if (CollapseProject.isCheap(child)) {
124+
if (CollapseProject.isCheap(child) || !commonExprIdSet.contains(id)) {
118125
refToExpr(id) = child
119126
} else {
120127
val childPlanIndex = inputPlans.indexWhere(
@@ -171,7 +178,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
171178

172179
case c: ConditionalExpression =>
173180
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
174-
rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith))
181+
rewriteWithExprAndInputPlans(
182+
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith))
175183
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
176184
// Use transformUp to handle nested With.
177185
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
@@ -185,7 +193,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
185193
}
186194

187195
case other => other.mapChildren(
188-
rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith)
196+
rewriteWithExprAndInputPlans(
197+
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith)
189198
)
190199
}
191200
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala

+12-2
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,7 @@ class RewriteWithExpressionSuite extends PlanTest {
140140
val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2))
141141
val ref2 = new CommonExpressionRef(commonExprDef2)
142142
// The inner main expression references the outer expression
143-
val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
143+
val innerExpr2 = With(ref2 + ref2 + outerRef, Seq(commonExprDef2))
144144
val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
145145
comparePlans(
146146
Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
@@ -152,7 +152,8 @@ class RewriteWithExpressionSuite extends PlanTest {
152152
.select(star(), (a + a).as("_common_expr_2"))
153153
// The final Project contains the final result expression, which references both common
154154
// expressions.
155-
.select(($"_common_expr_0" + ($"_common_expr_2" + $"_common_expr_0")).as("col"))
155+
.select(($"_common_expr_0" +
156+
($"_common_expr_2" + $"_common_expr_2" + $"_common_expr_0")).as("col"))
156157
.analyze
157158
)
158159
}
@@ -490,4 +491,13 @@ class RewriteWithExpressionSuite extends PlanTest {
490491
val wrongPlan = testRelation.select(expr1.as("c1"), expr3.as("c3")).analyze
491492
intercept[AssertionError](Optimizer.execute(wrongPlan))
492493
}
494+
495+
test("SPARK-50683: inline the common expression in With if used once") {
496+
val a = testRelation.output.head
497+
val exprDef = CommonExpressionDef(a + a)
498+
val exprRef = new CommonExpressionRef(exprDef)
499+
val expr = With(exprRef + 1, Seq(exprDef))
500+
val plan = testRelation.select(expr.as("col"))
501+
comparePlans(Optimizer.execute(plan), testRelation.select((a + a + 1).as("col")))
502+
}
493503
}

0 commit comments

Comments
 (0)