diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
index 435fd239b364..4ffddaf77bbf 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -29,9 +29,9 @@ import org.apache.gluten.vectorized.{ArrowColumnarRow, ArrowWritableColumnVector
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction, NamedExpression, NaNvl, ScalaUDF}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.execution.{ExplainUtils, ProjectExec, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{ExplainUtils, FilterExec, ProjectExec, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.hive.HiveUdfUtil
 import org.apache.spark.sql.types._
@@ -51,7 +51,7 @@ import scala.collection.mutable.ListBuffer
  * @param child
  *   child plan
  */
-case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
+case class ColumnarPartialProjectExec(original: SparkPlan, child: SparkPlan)(
     replacedAliasUdf: Seq[Alias])
   extends UnaryExecNode
   with ValidatablePlan {
@@ -169,8 +169,12 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
       // e.g. udf1(col) + udf2(col), it will introduce 2 cols for a2c
       return ValidationResult.failed("Number of RowToColumn columns is more than ProjectExec")
     }
-    if (!original.projectList.forall(validateExpression(_))) {
-      return ValidationResult.failed("Contains expression not supported")
+    original match {
+      case p: ProjectExec =>
+        if (!p.projectList.forall(validateExpression(_))) {
+          return ValidationResult.failed("Contains expression not supported")
+        }
+      case _ =>
     }
     if (
       ExpressionUtils.hasComplexExpressions(original, GlutenConfig.get.fallbackExpressionsThreshold)
@@ -290,7 +294,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
   }
 }
 
-object ColumnarPartialProjectExec {
+object ColumnarPartialProjectExec extends PredicateHelper {
 
   val projectPrefix = "_SparkPartialProject"
 
@@ -355,13 +359,26 @@ object ColumnarPartialProjectExec {
     }
   }
 
-  def create(original: ProjectExec): ProjectExecTransformer = {
+  def create(original: SparkPlan): UnaryTransformSupport = {
     val replacedAliasUdf: ListBuffer[Alias] = ListBuffer()
-    val newProjectList = original.projectList.map {
-      p => replaceExpressionUDF(p, replacedAliasUdf).asInstanceOf[NamedExpression]
+    val transformedPlan = original match {
+      case p: ProjectExec =>
+        val newProjectList = p.projectList.map {
+          p => replaceExpressionUDF(p, replacedAliasUdf).asInstanceOf[NamedExpression]
+        }
+        val partialProject =
+          ColumnarPartialProjectExec(p, p.child)(replacedAliasUdf.toSeq)
+        ProjectExecTransformer(newProjectList, partialProject)
+      case f: FilterExec =>
+        val newCondition = splitConjunctivePredicates(f.condition)
+          .map(p => replaceExpressionUDF(p, replacedAliasUdf))
+          .reduceLeftOption(And)
+          .orNull
+        val partialProject =
+          ColumnarPartialProjectExec(f, f.child)(replacedAliasUdf.toSeq)
+        FilterExecTransformer(newCondition, partialProject)
     }
-    val partialProject =
-      ColumnarPartialProjectExec(original, original.child)(replacedAliasUdf.toSeq)
-    ProjectExecTransformer(newProjectList, partialProject)
+
+    transformedPlan
   }
 }
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala
index 73d1651e2fdb..61b3f10c1df7 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala
@@ -20,12 +20,12 @@ import org.apache.gluten.execution.ColumnarPartialProjectExec
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 
 case class PartialProjectRule(spark: SparkSession) extends Rule[SparkPlan] {
   override def apply(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
-      case plan: ProjectExec =>
+      case plan @ (_: ProjectExec | _: FilterExec) =>
         val transformer = ColumnarPartialProjectExec.create(plan)
         if (
           transformer.doValidate().ok() &&