Commit 54f75be

update test to be parametrized
Signed-off-by: cindyyuanjiang <[email protected]>
1 parent 5d03b81 commit 54f75be

File tree

1 file changed: +142 -130 lines changed


core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala

+142 -130
@@ -26,6 +26,8 @@ import com.nvidia.spark.rapids.tool.planparser.ops.{ExprOpRef, OpRef}
 import com.nvidia.spark.rapids.tool.qualification._
 import org.scalatest.Matchers.{be, contain, convertToAnyShouldWrapper}
 import org.scalatest.exceptions.TestFailedException
+import org.scalatest.prop.TableDrivenPropertyChecks._
+import org.scalatest.prop.TableFor2
 
 import org.apache.spark.sql.{DataFrame, SparkSession, TrampolineUtil}
 import org.apache.spark.sql.execution.ui.SQLPlanMetric
@@ -1050,141 +1052,151 @@ class SQLPlanParserSuite extends BasePlanParserSuite {
   }
 
   /**
-   * Generates a sequence of app names and their corresponding SparkSession operations.
-   * Each application performs a transformation on DataFrames and writes/reads them via Parquet.
+   * Helper function to write a DataFrame to Parquet and then read it back.
    *
-   * @param parquetOutputLoc The output directory where Parquet files will be stored.
-   * @return A sequence of tuples containing the app name and a function that takes a SparkSession
-   *         and returns a DataFrame.
+   * @param spark The SparkSession instance.
+   * @param df The input DataFrame to be written.
+   * @param path The file path to store the Parquet file.
+   * @return A DataFrame read from the Parquet file.
    */
-  private def appNamesAndSparkSessions(parquetOutputLoc: File):
-      Seq[(String, SparkSession => DataFrame)] = {
-    /**
-     * Helper function to write a DataFrame to Parquet and then read it back.
-     *
-     * @param spark The SparkSession instance.
-     * @param df The input DataFrame to be written.
-     * @param path The file path to store the Parquet file.
-     * @return A DataFrame read from the Parquet file.
-     */
-    def writeAndReadParquet(spark: SparkSession, df: DataFrame, path: String): DataFrame = {
-      df.write.parquet(path)
-      spark.read.parquet(path)
-    }
+  private def writeAndReadParquet(spark: SparkSession, df: DataFrame, path: String): DataFrame = {
+    df.write.parquet(path)
+    spark.read.parquet(path)
+  }
 
-    Seq(// MonthsBetween is supported in ProjectExec
-      ("MonthsBetweenSupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq(("2024-12-01", "2024-01-01"),
-          ("2024-12-01", "2023-12-01"),
-          ("2024-12-01", "2024-12-01")).toDF("date1", "date2")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/monthsbetweentesttext")
-        // months_between should be part of ProjectExec
-        df2.select(months_between(df2("date1"), df2("date2")))
-      }),
-      // TruncDate is supported in ProjectExec
-      ("TruncDateSupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq("2024-12-15", "2024-01-10", "2023-11-05").toDF("date")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/truncdatetesttext")
-        // trunc should be part of ProjectExec
-        df2.select(trunc(df2("date"), "month"))
-      }),
-      // TruncTimestamp is supported in ProjectExec
-      ("TruncTimestampSupportedInProject", { spark =>
-        import spark.implicits._
-        val data = Seq("2024-12-15 14:30:45",
-          "2024-01-10 08:15:00",
-          "2023-11-05 20:45:30").toDF("timestamp")
-        val df1 = data.withColumn("timestamp", to_timestamp(col("timestamp")))
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/trunctimestamptesttext")
-        // date_trunc should be part of ProjectExec
-        df2.select(date_trunc("month", df2("timestamp")))
-      }),
-      // Ceil is supported in ProjectExec
-      ("CeilSupportedInProject", { spark =>
-        import spark.implicits._
-        import org.apache.spark.sql.types.StringType
-        val df1 = Seq(9.9, 10.2, 11.6, 12.5).toDF("value")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/ceiltesttext")
-        // ceil should be part of ProjectExec
-        df2.select(df2("value").cast(StringType), ceil(df2("value")), df2("value"))
-      }),
-      // Translate is supported in ProjectExec
-      ("TranslateSupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq("", "abc", "ABC", "AaBbCc").toDF("value")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/translatetesttext")
-        // translate should be part of ProjectExec
-        df2.select(translate(df2("value"), "ABC", "123"))
-      }),
-      // Timestamp functions are supported in ProjectExec
-      ("TimestampFunctionsSupportedInProject", { spark =>
-        import spark.implicits._
-        val init_df = Seq((1230219000123123L, 1230219000123L, 1230219000.123))
-        val df1 = init_df.toDF("micro", "millis", "seconds")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/timestampfunctesttext")
-        // timestamp functions should be part of ProjectExec
-        df2.selectExpr("timestamp_micros(micro)", "timestamp_millis(millis)",
-          "timestamp_seconds(seconds)")
-      }),
-      // Flatten is supported in ProjectExec
-      ("FlattenSupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq(Seq(Seq(1, 2), Seq(3, 4))).toDF("value")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/flattentesttext")
-        // flatten should be part of ProjectExec
-        df2.select(flatten(df2("value")))
-      }),
-      // Xxhash64 is supported in ProjectExec
-      ("Xxhash64SupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq("spark", "", "abc").toDF("value")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/xxhash64testtext")
-        // xxhash64 should be part of ProjectExec
-        df2.select(xxhash64(df2("value")))
-      }),
-      // MapFromArrays is supported in ProjectExec
-      ("MapFromArraysSupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq((Array("a", "b", "c"), Array(1, 2, 3)),
-          (Array("x", "y", "z"), Array(10, 20, 30))).toDF("keys", "values")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/mapfromarraystesttext")
-        // map_from_arrays should be part of ProjectExec
-        df2.select(map_from_arrays(df2("keys"), df2("values")).as("map"))
-      })
-    )
+  /**
+   * Table-driven test cases for verifying Spark SQL expressions in ProjectExec.
+   */
+  val projectExecTestCases: TableFor2[String, (File => (SparkSession => DataFrame))] = Table(
+    ("Expression", "FileToSparkSession"),
+    // MonthsBetween is supported in ProjectExec
+    ("MonthsBetween", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq(("2024-12-01", "2024-01-01"),
+        ("2024-12-01", "2023-12-01"),
+        ("2024-12-01", "2024-12-01")).toDF("date1", "date2")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/testtext")
+      // months_between should be part of ProjectExec
+      df2.select(months_between(df2("date1"), df2("date2")))
+    }}),
+    // TruncDate is supported in ProjectExec
+    ("TruncDate", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq("2024-12-15", "2024-01-10", "2023-11-05").toDF("date")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/testtext")
+      // trunc should be part of ProjectExec
+      df2.select(trunc(df2("date"), "month"))
+    }}),
+    // TruncTimestamp is supported in ProjectExec
+    ("TruncTimestamp", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val data = Seq("2024-12-15 14:30:45",
+        "2024-01-10 08:15:00",
+        "2023-11-05 20:45:30").toDF("timestamp")
+      val df1 = data.withColumn("timestamp", to_timestamp(col("timestamp")))
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/trunctimestamptesttext")
+      // date_trunc should be part of ProjectExec
+      df2.select(date_trunc("month", df2("timestamp")))
+    }}),
+    // Ceil is supported in ProjectExec
+    ("Ceil", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      import org.apache.spark.sql.types.StringType
+      val df1 = Seq(9.9, 10.2, 11.6, 12.5).toDF("value")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/ceiltesttext")
+      // ceil should be part of ProjectExec
+      df2.select(df2("value").cast(StringType), ceil(df2("value")), df2("value"))
+    }}),
+    // Translate is supported in ProjectExec
+    ("Translate", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq("", "abc", "ABC", "AaBbCc").toDF("value")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/translatetesttext")
+      // translate should be part of ProjectExec
+      df2.select(translate(df2("value"), "ABC", "123"))
+    }}),
+    // Timestamp functions are supported in ProjectExec
+    ("TimestampFunctions", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val init_df = Seq((1230219000123123L, 1230219000123L, 1230219000.123))
+      val df1 = init_df.toDF("micro", "millis", "seconds")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/timestampfunctesttext")
+      // timestamp functions should be part of ProjectExec
+      df2.selectExpr("timestamp_micros(micro)", "timestamp_millis(millis)",
+        "timestamp_seconds(seconds)")
+    }}),
+    // Flatten is supported in ProjectExec
+    ("Flatten", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq(Seq(Seq(1, 2), Seq(3, 4))).toDF("value")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/flattentesttext")
+      // flatten should be part of ProjectExec
+      df2.select(flatten(df2("value")))
+    }}),
+    // Xxhash64 is supported in ProjectExec
+    ("Xxhash64", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq("spark", "", "abc").toDF("value")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/xxhash64testtext")
+      // xxhash64 should be part of ProjectExec
+      df2.select(xxhash64(df2("value")))
+    }}),
+    // MapFromArrays is supported in ProjectExec
+    ("MapFromArrays", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq((Array("a", "b", "c"), Array(1, 2, 3)),
+        (Array("x", "y", "z"), Array(10, 20, 30))).toDF("keys", "values")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/mapfromarraystesttext")
+      // map_from_arrays should be part of ProjectExec
+      df2.select(map_from_arrays(df2("keys"), df2("values")).as("map"))
+    }})
+  )
+
+  /**
+   * Tests whether a given Spark SQL expression is supported in ProjectExec.
+   *
+   * @param appName Name of the Spark application.
+   * @param fileToSparkSession Function that maps a temporary Parquet directory to a function that
+   *                           takes a SparkSession and returns a DataFrame.
+   * @param parquetOutputLoc Temporary directory used for writing and reading Parquet files.
+   */
+  private def testExpressionInProjectExec(appName: String,
+      fileToSparkSession: File => (SparkSession => DataFrame),
+      parquetOutputLoc: File): Unit = {
+    TrampolineUtil.withTempDir { eventLogDir =>
+      val (eventLog, _) =
+        ToolTestUtils.generateEventLog(eventLogDir, appName)(fileToSparkSession(parquetOutputLoc))
+      val pluginTypeChecker = new PluginTypeChecker()
+      val app = createAppFromEventlog(eventLog)
+      assert(app.sqlPlans.size == 2)
+      val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
+        SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app)
+      }
+      verifyExecToStageMapping(parsedPlans.toSeq, app)
+      val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq)
+      val wholeStages = allExecInfo.filter(_.exec.contains("WholeStageCodegen"))
+      assert(wholeStages.size == 1)
+      assert(wholeStages.forall(_.duration.nonEmpty))
+      val allChildren = wholeStages.flatMap(_.children).flatten
+      val projects = allChildren.filter(_.exec == "Project")
+      assertSizeAndSupported(1, projects)
+    }
   }
 
-  test("Expressions supported in ProjectExec") {
-    TrampolineUtil.withTempDir { parquetoutputLoc =>
-      for ((appName, sparkSession) <- appNamesAndSparkSessions(parquetoutputLoc)) {
-        TrampolineUtil.withTempDir { eventLogDir =>
-          val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, appName)(sparkSession)
-          val pluginTypeChecker = new PluginTypeChecker()
-          val app = createAppFromEventlog(eventLog)
-          assert(app.sqlPlans.size == 2)
-          val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
-            SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app)
-          }
-          verifyExecToStageMapping(parsedPlans.toSeq, app)
-          val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq)
-          val wholeStages = allExecInfo.filter(_.exec.contains("WholeStageCodegen"))
-          assert(wholeStages.size == 1)
-          assert(wholeStages.forall(_.duration.nonEmpty))
-          val allChildren = wholeStages.flatMap(_.children).flatten
-          val projects = allChildren.filter(_.exec == "Project")
-          assertSizeAndSupported(1, projects)
-        }
+  forAll(projectExecTestCases) { (exprName, fileToSparkSession) =>
+    test(s"$exprName is supported in ProjectExec") {
+      TrampolineUtil.withTempDir { parquetOutputLoc =>
+        testExpressionInProjectExec(s"{$exprName}SupportedInProjectExec", fileToSparkSession,
+          parquetOutputLoc)
       }
     }
   }
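
Note: the parametrization above relies on ScalaTest's table-driven property checks. Below is a minimal, self-contained sketch of that pattern; the suite and row names are illustrative, not from this commit.

import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.prop.TableFor2

// Hypothetical suite showing the pattern this commit adopts.
class TableDrivenSketchSuite extends AnyFunSuite {
  // The first tuple names the columns; each following row is one test case.
  val cases: TableFor2[String, Int] = Table(
    ("name", "value"),
    ("one", 1),
    ("hundred", 100)
  )

  // Running forAll at suite-construction time registers one independent
  // test per row, so a failing case no longer hides the remaining ones.
  forAll(cases) { (name, value) =>
    test(s"$name is positive") {
      assert(value > 0)
    }
  }
}

The test cases also round-trip every DataFrame through Parquet before applying the expression under test. A minimal sketch of why, assuming a local SparkSession (names here are illustrative): a DataFrame built from a local Seq plans as a LocalTableScan, while reading it back from Parquet yields a file scan followed by the ProjectExec node the suite inspects.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.upper

object ParquetRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._
    val df1 = Seq("a", "b").toDF("value")
    df1.explain() // physical plan is a LocalTableScan
    val path = java.nio.file.Files.createTempDirectory("sketch").resolve("data").toString
    df1.write.parquet(path)
    val df2 = spark.read.parquet(path)
    df2.select(upper(df2("value"))).explain() // Project over a Parquet file scan
    spark.stop()
  }
}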
