diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
index 435fd239b364..a2121906cb44 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -29,9 +29,9 @@ import org.apache.gluten.vectorized.{ArrowColumnarRow, ArrowWritableColumnVector
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction, NamedExpression, NaNvl, ScalaUDF}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.execution.{ExplainUtils, ProjectExec, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{ExplainUtils, FilterExec, ProjectExec, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.hive.HiveUdfUtil
 import org.apache.spark.sql.types._
@@ -40,18 +40,18 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 import scala.collection.mutable.ListBuffer
 
 /**
- * By rule , the project not offload-able that is changed to
- * ProjectExecTransformer + ColumnarPartialProjectExec e.g. sum(myudf(a) + b + hash(c)), child is
- * (a, b, c) ColumnarPartialProjectExec (a, b, c, myudf(a) as _SparkPartialProject1),
- * ProjectExecTransformer(_SparkPartialProject1 + b + hash(c))
+ * By rule, a project/filter that is not offload-able is changed to
+ * ProjectExecTransformer/FilterExecTransformer + ColumnarPartialProjectExec. e.g. sum(myudf(a) + b
+ * + hash(c)), child is (a, b, c): ColumnarPartialProjectExec(a, b, c, myudf(a) as
+ * _SparkPartialProject1), ProjectExecTransformer(_SparkPartialProject1 + b + hash(c))
  *
  * @param original
- *   extract the ScalaUDF from original project list as Alias in UnsafeProjection and
+ *   extract the ScalaUDF from the original project/filter list as Alias in UnsafeProjection and
  *   AttributeReference in ColumnarPartialProjectExec output
  * @param child
  *   child plan
  */
-case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
+case class ColumnarPartialProjectExec(original: SparkPlan, child: SparkPlan)(
     replacedAliasUdf: Seq[Alias])
   extends UnaryExecNode
   with ValidatablePlan {
@@ -77,7 +77,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
   override def output: Seq[Attribute] = child.output ++ replacedAliasUdf.map(_.toAttribute)
 
   override def doCanonicalize(): ColumnarPartialProjectExec = {
-    val canonicalized = original.canonicalized.asInstanceOf[ProjectExec]
+    val canonicalized = original.canonicalized
     this.copy(
       original = canonicalized,
       child = child.canonicalized
@@ -169,8 +169,12 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
       // e.g. udf1(col) + udf2(col), it will introduce 2 cols for a2c
       return ValidationResult.failed("Number of RowToColumn columns is more than ProjectExec")
     }
-    if (!original.projectList.forall(validateExpression(_))) {
-      return ValidationResult.failed("Contains expression not supported")
+    original match {
+      case p: ProjectExec if !p.projectList.forall(validateExpression(_)) =>
+        return ValidationResult.failed("Contains expression not supported")
+      case f: FilterExec if !validateExpression(f.condition) =>
+        return ValidationResult.failed("Contains expression not supported")
+      case _ =>
     }
     if (
       ExpressionUtils.hasComplexExpressions(original, GlutenConfig.get.fallbackExpressionsThreshold)
@@ -290,7 +294,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
   }
 }
 
-object ColumnarPartialProjectExec {
+object ColumnarPartialProjectExec extends PredicateHelper {
 
   val projectPrefix = "_SparkPartialProject"
 
@@ -355,13 +359,27 @@ object ColumnarPartialProjectExec {
     }
   }
 
-  def create(original: ProjectExec): ProjectExecTransformer = {
-    val replacedAliasUdf: ListBuffer[Alias] = ListBuffer()
-    val newProjectList = original.projectList.map {
-      p => replaceExpressionUDF(p, replacedAliasUdf).asInstanceOf[NamedExpression]
+  def create(original: SparkPlan): UnaryTransformSupport = {
+    val transformedPlan = original match {
+      case p: ProjectExec =>
+        val replacedAliasUdf: ListBuffer[Alias] = ListBuffer()
+        val newProjectList = p.projectList.map {
+          p => replaceExpressionUDF(p, replacedAliasUdf).asInstanceOf[NamedExpression]
+        }
+        val partialProject =
+          ColumnarPartialProjectExec(p, p.child)(replacedAliasUdf.toSeq)
+        ProjectExecTransformer(newProjectList, partialProject)
+      case f: FilterExec =>
+        val replacedAliasUdf: ListBuffer[Alias] = ListBuffer()
+        val newCondition = splitConjunctivePredicates(f.condition)
+          .map(p => replaceExpressionUDF(p, replacedAliasUdf))
+          .reduceLeftOption(And)
+          .orNull
+        val partialProject =
+          ColumnarPartialProjectExec(f, f.child)(replacedAliasUdf.toSeq)
+        FilterExecTransformer(newCondition, partialProject)
     }
-    val partialProject =
-      ColumnarPartialProjectExec(original, original.child)(replacedAliasUdf.toSeq)
-    ProjectExecTransformer(newProjectList, partialProject)
+
+    transformedPlan
   }
 }
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala
index 73d1651e2fdb..61b3f10c1df7 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/PartialProjectRule.scala
@@ -20,12 +20,12 @@ import org.apache.gluten.execution.ColumnarPartialProjectExec
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 
 case class PartialProjectRule(spark: SparkSession) extends Rule[SparkPlan] {
   override def apply(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
-      case plan: ProjectExec =>
+      case plan @ (_: ProjectExec | _: FilterExec) =>
         val transformer = ColumnarPartialProjectExec.create(plan)
         if (
           transformer.doValidate().ok() &&
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index c8a18d6881ec..7114c9bc3203 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.exchange.GlutenEnsureRequirementsSuite
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.extension.{GlutenCollapseProjectExecTransformerSuite, GlutenSessionExtensionSuite, TestFileSourceScanExecTransformer}
 import org.apache.spark.sql.gluten.GlutenFallbackSuite
-import org.apache.spark.sql.hive.execution.GlutenHiveSQLQuerySuite
+import org.apache.spark.sql.hive.execution.{GlutenHiveSQLQuerySuite, GlutenHiveUDFSuite}
 import org.apache.spark.sql.sources._
 
 // Some settings' line length exceeds 100
@@ -1230,6 +1230,7 @@ class VeloxTestSettings extends BackendTestSettings {
   enableSuite[GlutenXPathFunctionsSuite]
   enableSuite[GlutenFallbackSuite]
   enableSuite[GlutenHiveSQLQuerySuite]
+  enableSuite[GlutenHiveUDFSuite]
   enableSuite[GlutenCollapseProjectExecTransformerSuite]
   enableSuite[GlutenSparkSessionExtensionSuite]
   enableSuite[GlutenGroupBasedDeleteFromTableSuite]
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveSQLQuerySuiteBase.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveSQLQuerySuiteBase.scala
index c8540647d3fa..80ce4645644f 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveSQLQuerySuiteBase.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveSQLQuerySuiteBase.scala
@@ -16,14 +16,13 @@
  */
 package org.apache.spark.sql.hive.execution
 
-import org.apache.gluten.execution.TransformSupport
-
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.config
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.sql.{DataFrame, GlutenSQLTestsTrait, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.hive.HiveUtils
 import org.apache.spark.sql.internal.SQLConf
 
@@ -90,7 +89,7 @@ abstract class GlutenHiveSQLQuerySuiteBase extends GlutenSQLTestsTrait {
     conf.set("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastore;create=true")
   }
 
-  def checkOperatorMatch[T <: TransformSupport](df: DataFrame)(implicit tag: ClassTag[T]): Unit = {
+  def checkOperatorMatch[T <: SparkPlan](df: DataFrame)(implicit tag: ClassTag[T]): Unit = {
     val executedPlan = getExecutedPlan(df)
     assert(executedPlan.exists(plan => plan.getClass == tag.runtimeClass))
   }
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
index 6365305140e5..0d2fda89658e 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
@@ -16,77 +16,29 @@
  */
 package org.apache.spark.sql.hive.execution
 
-import org.apache.gluten.execution.CustomerUDF
+import org.apache.gluten.execution.{ColumnarPartialProjectExec, CustomerUDF}
 
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-import org.apache.spark.internal.config
-import org.apache.spark.internal.config.UI.UI_ENABLED
-import org.apache.spark.sql.{GlutenTestsBaseTrait, QueryTest, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
-import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
-import org.apache.spark.sql.hive.client.HiveClient
-import org.apache.spark.sql.hive.test.TestHiveContext
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
-import org.apache.spark.sql.test.SQLTestUtils
-
-import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.Row
 
 import java.io.File
 
-trait GlutenTestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll {
-  override protected val enableAutoThreadAudit = false
-
-}
+class GlutenHiveUDFSuite extends GlutenHiveSQLQuerySuiteBase {
 
-object GlutenTestHive
-  extends TestHiveContext(
-    new SparkContext(
-      System.getProperty("spark.sql.test.master", "local[1]"),
-      "TestSQLContext",
-      new SparkConf()
-        .set("spark.sql.test", "")
-        .set(SQLConf.CODEGEN_FALLBACK.key, "false")
-        .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
-        .set(
-          HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
-          "org.apache.spark.sql.hive.execution.PairSerDe")
-        .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
-        // SPARK-8910
-        .set(UI_ENABLED, false)
-        .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
-        // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes
-        // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764.
-        .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false")
-        .set("spark.driver.memory", "1G")
-        .set("spark.sql.adaptive.enabled", "true")
-        .set("spark.sql.shuffle.partitions", "1")
-        .set("spark.sql.files.maxPartitionBytes", "134217728")
-        .set("spark.memory.offHeap.enabled", "true")
-        .set("spark.memory.offHeap.size", "1024MB")
-        .set("spark.plugins", "org.apache.gluten.GlutenPlugin")
-        .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
-        // Disable ConvertToLocalRelation for better test coverage. Test cases built on
-        // LocalRelation will exercise the optimization rules better by disabling it as
-        // this rule may potentially block testing of other optimization rules such as
-        // ConstantPropagation etc.
-        .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
-    ),
-    false
-  ) {}
+  override def sparkConf: SparkConf = {
+    defaultSparkConf
+      .set("spark.plugins", "org.apache.gluten.GlutenPlugin")
+      .set("spark.default.parallelism", "1")
+      .set("spark.memory.offHeap.enabled", "true")
+      .set("spark.memory.offHeap.size", "1024MB")
+  }
 
-class GlutenHiveUDFSuite
-  extends QueryTest
-  with GlutenTestHiveSingleton
-  with SQLTestUtils
-  with GlutenTestsBaseTrait {
-  override protected val spark: SparkSession = GlutenTestHive.sparkSession
-  protected val hiveContext: TestHiveContext = GlutenTestHive
-  protected val hiveClient: HiveClient =
-    spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client
+  def withTempFunction(funcName: String)(f: => Unit): Unit = {
+    try f
+    finally sql(s"DROP TEMPORARY FUNCTION IF EXISTS $funcName")
+  }
 
-  override protected def beforeAll(): Unit = {
+  override def beforeAll(): Unit = {
     super.beforeAll()
     val table = "lineitem"
     val tableDir =
@@ -97,43 +49,50 @@ class GlutenHiveUDFSuite
     tableDF.createOrReplaceTempView(table)
   }
 
-  override protected def afterAll(): Unit = {
-    try {
-      hiveContext.reset()
-    } finally {
-      super.afterAll()
-    }
-  }
-
-  override protected def shouldRun(testName: String): Boolean = {
-    false
+  override def afterAll(): Unit = {
+    super.afterAll()
   }
 
   test("customer udf") {
-    sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
-    val df = spark.sql("""select testUDF(l_comment)
-                         | from lineitem""".stripMargin)
-    df.show()
-    print(df.queryExecution.executedPlan)
-    sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
-    hiveContext.reset()
+    withTempFunction("testUDF") {
+      sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
+      val df = sql("select l_partkey, testUDF(l_comment) from lineitem")
+      df.show()
+      checkOperatorMatch[ColumnarPartialProjectExec](df)
+    }
   }
 
   test("customer udf wrapped in function") {
-    sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
-    val df = spark.sql("""select hash(testUDF(l_comment))
-                         | from lineitem""".stripMargin)
-    df.show()
-    print(df.queryExecution.executedPlan)
-    sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
-    hiveContext.reset()
+    withTempFunction("testUDF") {
+      sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
+      val df = sql("select l_partkey, hash(testUDF(l_comment)) from lineitem")
+      df.show()
+      checkOperatorMatch[ColumnarPartialProjectExec](df)
+    }
   }
 
   test("example") {
-    spark.sql("CREATE TEMPORARY FUNCTION testUDF AS 'org.apache.hadoop.hive.ql.udf.UDFSubstr';")
-    spark.sql("select testUDF('l_commen', 1, 5)").show()
-    sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
-    hiveContext.reset()
+    withTempFunction("testUDF") {
+      sql("CREATE TEMPORARY FUNCTION testUDF AS 'org.apache.hadoop.hive.ql.udf.UDFSubstr';")
+      val df = sql("select testUDF('l_commen', 1, 5)")
+      df.show()
+      // It should not be converted to ColumnarPartialProjectExec, since
+      // the UDF needs all the columns in child output.
+      assert(!getExecutedPlan(df).exists {
+        case _: ColumnarPartialProjectExec => true
+        case _ => false
+      })
+    }
   }
 
+  test("udf in filter") {
+    withTempFunction("testUDF") {
+      sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
+      val df = sql("""
+                     |select l_partkey from lineitem where hash(testUDF(l_comment)) = 1961715824
+                     |""".stripMargin)
+      checkAnswer(df, Seq(Row(1552)))
+      checkOperatorMatch[ColumnarPartialProjectExec](df)
+    }
+  }
 }
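
Usage sketch (not part of the patch): a minimal, hedged example of the behavior the new "udf in filter" test exercises, assuming a Spark build with the Gluten plugin and its Velox backend on the classpath, Hive support enabled, and a `lineitem` view registered as in GlutenHiveUDFSuite. The object name, master/appName, and local configs below are illustrative assumptions; only the plugin class, shuffle manager, and CustomerUDF class name come from the diff.

import org.apache.spark.sql.SparkSession

// Sketch: run a filter containing a Hive UDF that Velox cannot offload and inspect the plan.
// With this patch, the filter is expected to be split into FilterExecTransformer over
// ColumnarPartialProjectExec, which evaluates testUDF(l_comment) row-wise on Arrow columns.
object PartialProjectFilterDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[1]")
      .appName("PartialProjectFilterDemo")
      .config("spark.plugins", "org.apache.gluten.GlutenPlugin")
      .config("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
      .config("spark.memory.offHeap.enabled", "true")
      .config("spark.memory.offHeap.size", "1024MB")
      .enableHiveSupport()
      .getOrCreate()

    // Register the same Hive UDF used by the test suite.
    spark.sql(
      "CREATE TEMPORARY FUNCTION testUDF AS 'org.apache.gluten.execution.CustomerUDF'")

    // Assumes a `lineitem` view was registered beforehand, as in the suite's beforeAll().
    val df = spark.sql(
      "select l_partkey from lineitem where hash(testUDF(l_comment)) = 1961715824")
    df.explain() // look for ColumnarPartialProjectExec beneath FilterExecTransformer
    df.show()

    spark.stop()
  }
}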