ColumnarPartialProject supports UDFs in filter
WangGuangxin committed Feb 6, 2025
1 parent 546e63a commit db0fb19
Showing 2 changed files with 31 additions and 14 deletions.
ColumnarPartialProjectExec.scala

@@ -29,9 +29,9 @@ import org.apache.gluten.vectorized.{ArrowColumnarRow, ArrowWritableColumnVector}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction, NamedExpression, NaNvl, ScalaUDF}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.execution.{ExplainUtils, ProjectExec, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{ExplainUtils, FilterExec, ProjectExec, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.hive.HiveUdfUtil
 import org.apache.spark.sql.types._
@@ -51,7 +51,7 @@ import scala.collection.mutable.ListBuffer
  * @param child
  *   child plan
  */
-case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
+case class ColumnarPartialProjectExec(original: SparkPlan, child: SparkPlan)(
     replacedAliasUdf: Seq[Alias])
   extends UnaryExecNode
   with ValidatablePlan {
@@ -169,8 +169,12 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
       // e.g. udf1(col) + udf2(col), it will introduce 2 cols for a2c
       return ValidationResult.failed("Number of RowToColumn columns is more than ProjectExec")
     }
-    if (!original.projectList.forall(validateExpression(_))) {
-      return ValidationResult.failed("Contains expression not supported")
+    original match {
+      case p: ProjectExec =>
+        if (!p.projectList.forall(validateExpression(_))) {
+          return ValidationResult.failed("Contains expression not supported")
+        }
+      case _ =>
     }
     if (
       ExpressionUtils.hasComplexExpressions(original, GlutenConfig.get.fallbackExpressionsThreshold)
@@ -290,7 +294,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
   }
 }
 
-object ColumnarPartialProjectExec {
+object ColumnarPartialProjectExec extends PredicateHelper {
 
   val projectPrefix = "_SparkPartialProject"
 
@@ -355,13 +359,26 @@
     }
   }
 
-  def create(original: ProjectExec): ProjectExecTransformer = {
+  def create(original: SparkPlan): UnaryTransformSupport = {
     val replacedAliasUdf: ListBuffer[Alias] = ListBuffer()
-    val newProjectList = original.projectList.map {
-      p => replaceExpressionUDF(p, replacedAliasUdf).asInstanceOf[NamedExpression]
-    }
-    val partialProject =
-      ColumnarPartialProjectExec(original, original.child)(replacedAliasUdf.toSeq)
-    ProjectExecTransformer(newProjectList, partialProject)
+    val transformedPlan = original match {
+      case p: ProjectExec =>
+        val newProjectList = p.projectList.map {
+          p => replaceExpressionUDF(p, replacedAliasUdf).asInstanceOf[NamedExpression]
+        }
+        val partialProject =
+          ColumnarPartialProjectExec(p, p.child)(replacedAliasUdf)
+        ProjectExecTransformer(newProjectList, partialProject)
+      case f: FilterExec =>
+        val newCondition = splitConjunctivePredicates(f.condition)
+          .map(p => replaceExpressionUDF(p, replacedAliasUdf))
+          .reduceLeftOption(And)
+          .orNull
+        val partialProject =
+          ColumnarPartialProjectExec(f, f.child)(replacedAliasUdf)
+        FilterExecTransformer(newCondition, partialProject)
+    }
+
+    transformedPlan
   }
 }
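
Note on the new FilterExec branch: splitConjunctivePredicates comes from Catalyst's PredicateHelper (hence the new extends PredicateHelper on the companion object above). It breaks a condition a AND b AND c into Seq(a, b, c); after each predicate has its UDFs replaced, reduceLeftOption(And) folds the sequence back into a single conjunction, and .orNull covers the degenerate empty case. A minimal standalone sketch of that split/rewrite/recombine step (illustrative only, not part of this commit; rewrite and the demo object are hypothetical stand-ins for replaceExpressionUDF):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

object SplitConjunctsDemo extends PredicateHelper {

  // Stand-in for replaceExpressionUDF: returns the predicate unchanged here;
  // the real method substitutes UDF subtrees with aliases computed by
  // ColumnarPartialProjectExec.
  private def rewrite(p: Expression): Expression = p

  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    // (a > 1) AND (b < 2) AND (a = b)
    val cond = And(And(GreaterThan(a, Literal(1)), LessThan(b, Literal(2))), EqualTo(a, b))

    val newCondition = splitConjunctivePredicates(cond) // Seq(a > 1, b < 2, a = b)
      .map(rewrite)          // rewrite each predicate independently
      .reduceLeftOption(And) // fold back into one conjunction
      .orNull                // null only if there were no predicates
    println(newCondition.sql) // roughly: ((a > 1) AND (b < 2)) AND (a = b)
  }
}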
PartialProjectRule.scala

@@ -20,12 +20,12 @@ import org.apache.gluten.execution.ColumnarPartialProjectExec
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 
 case class PartialProjectRule(spark: SparkSession) extends Rule[SparkPlan] {
   override def apply(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
-      case plan: ProjectExec =>
+      case plan @ (_: ProjectExec | _: FilterExec) =>
         val transformer = ColumnarPartialProjectExec.create(plan)
         if (
           transformer.doValidate().ok() &&
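
Net effect: PartialProjectRule now also intercepts FilterExec, so a filter whose condition contains an unsupported UDF no longer falls back wholesale. create pulls each UDF occurrence out into an alias computed row-wise by ColumnarPartialProjectExec, and the rewritten condition runs columnar in FilterExecTransformer. A rough before/after plan shape (illustrative only, not real explain output; myUdf is hypothetical and the alias name is guessed from the _SparkPartialProject prefix in the companion object):

Before:
  FilterExec(myUdf(a) > 0)
    +- Scan

After:
  FilterExecTransformer(_SparkPartialProject0 > 0)
    +- ColumnarPartialProjectExec(myUdf(a) AS _SparkPartialProject0)
         +- Scan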
