@@ -26,6 +26,8 @@ import com.nvidia.spark.rapids.tool.planparser.ops.{ExprOpRef, OpRef}
 import com.nvidia.spark.rapids.tool.qualification._
 import org.scalatest.Matchers.{be, contain, convertToAnyShouldWrapper}
 import org.scalatest.exceptions.TestFailedException
+import org.scalatest.prop.TableDrivenPropertyChecks._
+import org.scalatest.prop.TableFor2

 import org.apache.spark.sql.{DataFrame, SparkSession, TrampolineUtil}
 import org.apache.spark.sql.execution.ui.SQLPlanMetric
@@ -1050,141 +1052,151 @@ class SQLPlanParserSuite extends BasePlanParserSuite {
   }

   /**
-   * Generates a sequence of app names and their corresponding SparkSession operations.
-   * Each application performs a transformation on DataFrames and writes/reads them via Parquet.
+   * Helper function to write a DataFrame to Parquet and then read it back.
    *
-   * @param parquetOutputLoc The output directory where Parquet files will be stored.
-   * @return A sequence of tuples containing the app name and a function that takes a SparkSession
-   *         and returns a DataFrame.
+   * @param spark The SparkSession instance.
+   * @param df The input DataFrame to be written.
+   * @param path The file path to store the Parquet file.
+   * @return A DataFrame read from the Parquet file.
    */
-  private def appNamesAndSparkSessions(parquetOutputLoc: File):
-      Seq[(String, SparkSession => DataFrame)] = {
-    /**
-     * Helper function to write a DataFrame to Parquet and then read it back.
-     *
-     * @param spark The SparkSession instance.
-     * @param df The input DataFrame to be written.
-     * @param path The file path to store the Parquet file.
-     * @return A DataFrame read from the Parquet file.
-     */
-    def writeAndReadParquet(spark: SparkSession, df: DataFrame, path: String): DataFrame = {
-      df.write.parquet(path)
-      spark.read.parquet(path)
-    }
+  private def writeAndReadParquet(spark: SparkSession, df: DataFrame, path: String): DataFrame = {
+    df.write.parquet(path)
+    spark.read.parquet(path)
+  }

-    Seq(// MonthsBetween is supported in ProjectExec
-      ("MonthsBetweenSupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq(("2024-12-01", "2024-01-01"),
-          ("2024-12-01", "2023-12-01"),
-          ("2024-12-01", "2024-12-01")).toDF("date1", "date2")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/monthsbetweentesttext")
-        // months_between should be part of ProjectExec
-        df2.select(months_between(df2("date1"), df2("date2")))
-      }),
-      // TruncDate is supported in ProjectExec
-      ("TruncDateSupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq("2024-12-15", "2024-01-10", "2023-11-05").toDF("date")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/truncdatetesttext")
-        // trunc should be part of ProjectExec
-        df2.select(trunc(df2("date"), "month"))
-      }),
-      // TruncTimestamp is supported in ProjectExec
-      ("TruncTimestampSupportedInProject", { spark =>
-        import spark.implicits._
-        val data = Seq("2024-12-15 14:30:45",
-          "2024-01-10 08:15:00",
-          "2023-11-05 20:45:30").toDF("timestamp")
-        val df1 = data.withColumn("timestamp", to_timestamp(col("timestamp")))
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/trunctimestamptesttext")
-        // date_trunc should be part of ProjectExec
-        df2.select(date_trunc("month", df2("timestamp")))
-      }),
-      // Ceil is supported in ProjectExec
-      ("CeilSupportedInProject", { spark =>
-        import spark.implicits._
-        import org.apache.spark.sql.types.StringType
-        val df1 = Seq(9.9, 10.2, 11.6, 12.5).toDF("value")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/ceiltesttext")
-        // ceil should be part of ProjectExec
-        df2.select(df2("value").cast(StringType), ceil(df2("value")), df2("value"))
-      }),
-      // Translate is supported in ProjectExec
-      ("TranslateSupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq("", "abc", "ABC", "AaBbCc").toDF("value")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/translatetesttext")
-        // translate should be part of ProjectExec
-        df2.select(translate(df2("value"), "ABC", "123"))
-      }),
-      // Timestamp functions are supported in ProjectExec
-      ("TimestampFunctionsSupportedInProject", { spark =>
-        import spark.implicits._
-        val init_df = Seq((1230219000123123L, 1230219000123L, 1230219000.123))
-        val df1 = init_df.toDF("micro", "millis", "seconds")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/timestampfunctesttext")
-        // timestamp functions should be part of ProjectExec
-        df2.selectExpr("timestamp_micros(micro)", "timestamp_millis(millis)",
-          "timestamp_seconds(seconds)")
-      }),
-      // Flatten is supported in ProjectExec
-      ("FlattenSupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq(Seq(Seq(1, 2), Seq(3, 4))).toDF("value")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/flattentesttext")
-        // flatten should be part of ProjectExec
-        df2.select(flatten(df2("value")))
-      }),
-      // Xxhash64 is supported in ProjectExec
-      ("Xxhash64SupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq("spark", "", "abc").toDF("value")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/xxhash64testtext")
-        // xxhash64 should be part of ProjectExec
-        df2.select(xxhash64(df2("value")))
-      }),
-      // MapFromArrays is supported in ProjectExec
-      ("MapFromArraysSupportedInProject", { spark =>
-        import spark.implicits._
-        val df1 = Seq((Array("a", "b", "c"), Array(1, 2, 3)),
-          (Array("x", "y", "z"), Array(10, 20, 30))).toDF("keys", "values")
-        // write df1 to parquet to transform LocalTableScan to ProjectExec
-        val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/mapfromarraystesttext")
-        // map_from_arrays should be part of ProjectExec
-        df2.select(map_from_arrays(df2("keys"), df2("values")).as("map"))
-      })
-    )
+  /**
+   * Table-driven test cases for verifying Spark SQL expressions in ProjectExec.
+   */
+  val projectExecTestCases: TableFor2[String, (File => (SparkSession => DataFrame))] = Table(
+    ("Expression", "FileToSparkSession"),
+    // MonthsBetween is supported in ProjectExec
+    ("MonthsBetween", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq(("2024-12-01", "2024-01-01"),
+        ("2024-12-01", "2023-12-01"),
+        ("2024-12-01", "2024-12-01")).toDF("date1", "date2")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/testtext")
+      // months_between should be part of ProjectExec
+      df2.select(months_between(df2("date1"), df2("date2")))
+    }}),
+    // TruncDate is supported in ProjectExec
+    ("TruncDate", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq("2024-12-15", "2024-01-10", "2023-11-05").toDF("date")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/testtext")
+      // trunc should be part of ProjectExec
+      df2.select(trunc(df2("date"), "month"))
+    }}),
+    // TruncTimestamp is supported in ProjectExec
+    ("TruncTimestamp", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val data = Seq("2024-12-15 14:30:45",
+        "2024-01-10 08:15:00",
+        "2023-11-05 20:45:30").toDF("timestamp")
+      val df1 = data.withColumn("timestamp", to_timestamp(col("timestamp")))
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/trunctimestamptesttext")
+      // date_trunc should be part of ProjectExec
+      df2.select(date_trunc("month", df2("timestamp")))
+    }}),
+    // Ceil is supported in ProjectExec
+    ("Ceil", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      import org.apache.spark.sql.types.StringType
+      val df1 = Seq(9.9, 10.2, 11.6, 12.5).toDF("value")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/ceiltesttext")
+      // ceil should be part of ProjectExec
+      df2.select(df2("value").cast(StringType), ceil(df2("value")), df2("value"))
+    }}),
+    // Translate is supported in ProjectExec
+    ("Translate", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq("", "abc", "ABC", "AaBbCc").toDF("value")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/translatetesttext")
+      // translate should be part of ProjectExec
+      df2.select(translate(df2("value"), "ABC", "123"))
+    }}),
+    // Timestamp functions are supported in ProjectExec
+    ("TimestampFunctions", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val init_df = Seq((1230219000123123L, 1230219000123L, 1230219000.123))
+      val df1 = init_df.toDF("micro", "millis", "seconds")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/timestampfunctesttext")
+      // timestamp functions should be part of ProjectExec
+      df2.selectExpr("timestamp_micros(micro)", "timestamp_millis(millis)",
+        "timestamp_seconds(seconds)")
+    }}),
+    // Flatten is supported in ProjectExec
+    ("Flatten", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq(Seq(Seq(1, 2), Seq(3, 4))).toDF("value")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/flattentesttext")
+      // flatten should be part of ProjectExec
+      df2.select(flatten(df2("value")))
+    }}),
+    // Xxhash64 is supported in ProjectExec
+    ("Xxhash64", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq("spark", "", "abc").toDF("value")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/xxhash64testtext")
+      // xxhash64 should be part of ProjectExec
+      df2.select(xxhash64(df2("value")))
+    }}),
+    // MapFromArrays is supported in ProjectExec
+    ("MapFromArrays", { parquetOutputLoc => { spark =>
+      import spark.implicits._
+      val df1 = Seq((Array("a", "b", "c"), Array(1, 2, 3)),
+        (Array("x", "y", "z"), Array(10, 20, 30))).toDF("keys", "values")
+      // write df1 to parquet to transform LocalTableScan to ProjectExec
+      val df2 = writeAndReadParquet(spark, df1, s"$parquetOutputLoc/mapfromarraystesttext")
+      // map_from_arrays should be part of ProjectExec
+      df2.select(map_from_arrays(df2("keys"), df2("values")).as("map"))
+    }})
+  )
+
+  /**
+   * Tests whether a given Spark SQL expression is supported in ProjectExec.
+   *
+   * @param appName Name of the Spark application.
+   * @param fileToSparkSession Function that maps a temporary Parquet directory to a function that
+   *                           takes a SparkSession and returns a DataFrame.
+   * @param parquetOutputLoc Temporary directory used for writing and reading Parquet files.
+   */
+  private def testExpressionInProjectExec(appName: String,
+      fileToSparkSession: File => (SparkSession => DataFrame),
+      parquetOutputLoc: File): Unit = {
+    TrampolineUtil.withTempDir { eventLogDir =>
+      val (eventLog, _) =
+        ToolTestUtils.generateEventLog(eventLogDir, appName)(fileToSparkSession(parquetOutputLoc))
+      val pluginTypeChecker = new PluginTypeChecker()
+      val app = createAppFromEventlog(eventLog)
+      assert(app.sqlPlans.size == 2)
+      val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
+        SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app)
+      }
+      verifyExecToStageMapping(parsedPlans.toSeq, app)
+      val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq)
+      val wholeStages = allExecInfo.filter(_.exec.contains("WholeStageCodegen"))
+      assert(wholeStages.size == 1)
+      assert(wholeStages.forall(_.duration.nonEmpty))
+      val allChildren = wholeStages.flatMap(_.children).flatten
+      val projects = allChildren.filter(_.exec == "Project")
+      assertSizeAndSupported(1, projects)
+    }
   }

-  test("Expressions supported in ProjectExec") {
-    TrampolineUtil.withTempDir { parquetoutputLoc =>
-      for ((appName, sparkSession) <- appNamesAndSparkSessions(parquetoutputLoc)) {
-        TrampolineUtil.withTempDir { eventLogDir =>
-          val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, appName)(sparkSession)
-          val pluginTypeChecker = new PluginTypeChecker()
-          val app = createAppFromEventlog(eventLog)
-          assert(app.sqlPlans.size == 2)
-          val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
-            SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app)
-          }
-          verifyExecToStageMapping(parsedPlans.toSeq, app)
-          val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq)
-          val wholeStages = allExecInfo.filter(_.exec.contains("WholeStageCodegen"))
-          assert(wholeStages.size == 1)
-          assert(wholeStages.forall(_.duration.nonEmpty))
-          val allChildren = wholeStages.flatMap(_.children).flatten
-          val projects = allChildren.filter(_.exec == "Project")
-          assertSizeAndSupported(1, projects)
-        }
+  forAll(projectExecTestCases) { (exprName, fileToSparkSession) =>
+    test(s"$exprName is supported in ProjectExec") {
+      TrampolineUtil.withTempDir { parquetOutputLoc =>
+        testExpressionInProjectExec(s"${exprName}SupportedInProjectExec", fileToSparkSession,
+          parquetOutputLoc)
       }
     }
   }