From ad46db4ef671d8829dfffba2780ba0f6b4f4e43d Mon Sep 17 00:00:00 2001
From: Takuya Ueshin
Date: Wed, 20 Nov 2024 13:22:51 -0800
Subject: [PATCH] [SPARK-50130][SQL][FOLLOWUP] Make Encoder generation lazy

### What changes were proposed in this pull request?

Makes Encoder generation lazy.

### Why are the changes needed?

The encoder with an empty schema for a lazily analyzed plan could cause unexpected behavior.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48829 from ueshin/issues/SPARK-50130/lazy_encoder.

Authored-by: Takuya Ueshin
Signed-off-by: Takuya Ueshin
---
 .../scala/org/apache/spark/sql/Dataset.scala  | 35 +++++++------------
 .../spark/sql/DataFrameSubquerySuite.scala    | 15 +++++---
 .../scala/org/apache/spark/sql/UDFSuite.scala |  2 +-
 3 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 500a4c7c4d9bc..4766a74308a1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -95,13 +95,8 @@ private[sql] object Dataset {
   def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
     sparkSession.withActive {
       val qe = sparkSession.sessionState.executePlan(logicalPlan)
-      val encoder = if (qe.isLazyAnalysis) {
-        RowEncoder.encoderFor(new StructType())
-      } else {
-        qe.assertAnalyzed()
-        RowEncoder.encoderFor(qe.analyzed.schema)
-      }
-      new Dataset[Row](qe, encoder)
+      if (!qe.isLazyAnalysis) qe.assertAnalyzed()
+      new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
     }

   def ofRows(
@@ -111,13 +106,8 @@ private[sql] object Dataset {
     sparkSession.withActive {
       val qe = new QueryExecution(
         sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode)
-      val encoder = if (qe.isLazyAnalysis) {
-        RowEncoder.encoderFor(new StructType())
-      } else {
-        qe.assertAnalyzed()
-        RowEncoder.encoderFor(qe.analyzed.schema)
-      }
-      new Dataset[Row](qe, encoder)
+      if (!qe.isLazyAnalysis) qe.assertAnalyzed()
+      new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
     }

   /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */
@@ -129,13 +119,8 @@ private[sql] object Dataset {
       : DataFrame = sparkSession.withActive {
     val qe = new QueryExecution(
       sparkSession, logicalPlan, tracker, shuffleCleanupMode = shuffleCleanupMode)
-    val encoder = if (qe.isLazyAnalysis) {
-      RowEncoder.encoderFor(new StructType())
-    } else {
-      qe.assertAnalyzed()
-      RowEncoder.encoderFor(qe.analyzed.schema)
-    }
-    new Dataset[Row](qe, encoder)
+    if (!qe.isLazyAnalysis) qe.assertAnalyzed()
+    new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
   }
 }

@@ -229,7 +214,7 @@ private[sql] object Dataset {
 @Stable
 class Dataset[T] private[sql](
     @DeveloperApi @Unstable @transient val queryExecution: QueryExecution,
-    @DeveloperApi @Unstable @transient val encoder: Encoder[T])
+    @transient encoderGenerator: () => Encoder[T])
   extends api.Dataset[T] {
   type DS[U] = Dataset[U]

@@ -252,6 +237,10 @@ class Dataset[T] private[sql](
   // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure
   // you wrap it with `withNewExecutionId` if this actions doesn't call other action.

+  private[sql] def this(queryExecution: QueryExecution, encoder: Encoder[T]) = {
+    this(queryExecution, () => encoder)
+  }
+
   def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
     this(sparkSession.sessionState.executePlan(logicalPlan), encoder)
   }
@@ -274,6 +263,8 @@ class Dataset[T] private[sql](
     }
   }

+  @DeveloperApi @Unstable @transient lazy val encoder: Encoder[T] = encoderGenerator()
+
   /**
    * Expose the encoder as implicit so it can be used to construct new Dataset objects that have
    * the same external type.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index 5a065d7e73b1c..d656c36ce842a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -54,11 +54,18 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
   }

   test("unanalyzable expression") {
-    val exception = intercept[AnalysisException] {
-      spark.range(1).select($"id" === $"id".outer()).schema
-    }
+    val sub = spark.range(1).select($"id" === $"id".outer())
+
+    checkError(
+      intercept[AnalysisException](sub.schema),
+      condition = "UNANALYZABLE_EXPRESSION",
+      parameters = Map("expr" -> "\"outer(id)\""),
+      queryContext =
+        Array(ExpectedContext(fragment = "outer", callSitePattern = getCurrentClassCallSitePattern))
+    )
+
     checkError(
-      exception,
+      intercept[AnalysisException](sub.encoder),
       condition = "UNANALYZABLE_EXPRESSION",
       parameters = Map("expr" -> "\"outer(id)\""),
       queryContext =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index d550d0f94f236..18af2fcb0ee73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -1205,7 +1205,7 @@ class UDFSuite extends QueryTest with SharedSparkSession {
         dt
       )
       checkError(
-        intercept[AnalysisException](spark.range(1).select(f())),
+        intercept[AnalysisException](spark.range(1).select(f()).encoder),
         condition = "UNSUPPORTED_DATA_TYPE_FOR_ENCODER",
         sqlState = "0A000",
         parameters = Map("dataType" -> s"\"${dt.sql}\"")
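
Note (illustrative sketch, not part of the patch): the change replaces the eagerly computed `encoder` constructor field with a `() => Encoder[T]` thunk that a `@transient lazy val` forces on first access, so constructing a lazily analyzed Dataset no longer needs a placeholder empty-schema encoder, and analysis errors surface when `encoder` (or `schema`) is first requested. The new `private[sql] def this(queryExecution, encoder)` auxiliary constructor simply wraps an already-built encoder in a constant thunk for existing call sites. A minimal self-contained sketch of this pattern, using hypothetical `Schema`/`Enc`/`Frame` stand-ins rather than Spark's actual types:

```scala
// A minimal sketch of the lazy-encoder pattern this patch adopts.
// Schema, Enc, and Frame are hypothetical stand-ins, not Spark's types.
object LazyEncoderSketch {
  final case class Schema(fields: Seq[String])
  final case class Enc(schema: Schema)

  // `analyze` plays the role of qe.assertAnalyzed()/qe.analyzed.schema:
  // it may throw for an unanalyzable plan.
  class Frame(isLazyAnalysis: Boolean, analyze: () => Schema) {
    // Eagerly analyzed plans still fail fast at construction time.
    if (!isLazyAnalysis) analyze()

    // The thunk defers encoder creation; the lazy val memoizes it, so the
    // encoder is built (and analysis errors surface) only on first access.
    private val encoderGenerator: () => Enc = () => Enc(analyze())
    @transient lazy val encoder: Enc = encoderGenerator()
  }

  def main(args: Array[String]): Unit = {
    // Constructing a lazily analyzed frame over an invalid plan succeeds...
    val bad = new Frame(isLazyAnalysis = true,
      () => throw new IllegalStateException("UNANALYZABLE_EXPRESSION"))
    // ...and the failure is raised only when the encoder is forced.
    try bad.encoder
    catch { case e: IllegalStateException => println(s"deferred failure: ${e.getMessage}") }
  }
}
```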