@@ -18,7 +18,7 @@ package org.apache.gluten.execution
18
18
19
19
import org .apache .gluten .GlutenConfig
20
20
import org .apache .gluten .benchmarks .GenTPCDSTableScripts
21
- import org .apache .gluten .utils .UTSystemParameters
21
+ import org .apache .gluten .utils .{ Arm , UTSystemParameters }
22
22
23
23
import org .apache .spark .SparkConf
24
24
import org .apache .spark .internal .Logging
@@ -46,8 +46,8 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
46
46
rootPath + " ../../../../gluten-core/src/test/resources/tpcds-queries/tpcds.queries.original"
47
47
protected val queriesResults : String = rootPath + " tpcds-decimal-queries-output"
48
48
49
- /** Return values: (sql num, is fall back, skip fall back assert ) */
50
- def tpcdsAllQueries (isAqe : Boolean ): Seq [(String , Boolean , Boolean )] =
49
+ /** Return values: (sql num, is native, i.e. the value passed as noFallBack) */
50
+ def tpcdsAllQueries (isAqe : Boolean ): Seq [(String , Boolean )] =
51
51
Range
52
52
.inclusive(1 , 99 )
53
53
.flatMap(
@@ -57,37 +57,37 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
57
57
} else {
58
58
Seq (" q" + " %d" .format(queryNum))
59
59
}
60
- val noFallBack = queryNum match {
61
- case i if ! isAqe && (i == 10 || i == 16 || i == 35 || i == 94 ) =>
62
- // q10 smj + existence join
63
- // q16 smj + left semi + not condition
64
- // q35 smj + existence join
65
- // Q94 BroadcastHashJoin, LeftSemi, NOT condition
66
- (false , false )
67
- case i if isAqe && (i == 16 || i == 94 ) =>
68
- (false , false )
69
- case other => (true , false )
70
- }
71
- sqlNums.map((_, noFallBack._1, noFallBack._2))
60
+ val native = ! fallbackSets(isAqe).contains(queryNum)
61
+ sqlNums.map((_, native))
72
62
})
73
63
74
- // FIXME "q17", stddev_samp inconsistent results, CH return NaN, Spark return null
64
+ protected def fallbackSets (isAqe : Boolean ): Set [Int ] = { // query numbers that tpcdsAllQueries marks as not native for the given isAqe
65
+ val more = if (isSparkVersionGE(" 3.5" )) Set (44 , 67 , 70 ) else Set .empty[Int ] // extras merged into both branches below only when the isSparkVersionGE check holds
66
+
67
+ // q16 smj + left semi + not condition
68
+ // Q94 BroadcastHashJoin, LeftSemi, NOT condition
69
+ if (isAqe) {
70
+ Set (16 , 94 ) | more // union of the AQE base set with the version-dependent extras
71
+ } else {
72
+ // q10, q35 smj + existence join
73
+ Set (10 , 16 , 35 , 94 ) | more // non-AQE base set additionally contains 10 and 35
74
+ }
75
+ }
75
76
protected def excludedTpcdsQueries : Set [String ] = Set ( // query names that executeTPCDSTest registers with ignore(...) instead of test(...)
76
- " q61" , // inconsistent results
77
- " q66" , // inconsistent results
78
- " q67" // inconsistent results
77
+ " q66" // inconsistent results
79
78
)
80
79
81
80
def executeTPCDSTest (isAqe : Boolean ): Unit = { // registers one test case per entry returned by tpcdsAllQueries(isAqe)
82
81
tpcdsAllQueries(isAqe).foreach(
83
82
s => // s._1 is the query name, s._2 is the flag forwarded as noFallBack
84
83
if (excludedTpcdsQueries.contains(s._1)) { // excluded names are registered via ignore(...) rather than test(...)
85
84
ignore(s " TPCDS ${s._1.toUpperCase()}" ) {
86
- runTPCDSQuery(s._1, noFallBack = s._2, skipFallBackAssert = s._3 ) { df => }
85
+ runTPCDSQuery(s._1, noFallBack = s._2) { df => }
87
86
}
88
87
} else {
89
- test(s " TPCDS ${s._1.toUpperCase()}" ) {
90
- runTPCDSQuery(s._1, noFallBack = s._2, skipFallBackAssert = s._3) { df => }
88
+ val tag = if (s._2) " Native" else " Fallback" // label embedded in the test name, chosen from the same flag passed as noFallBack
89
+ test(s " TPCDS[ $tag] ${s._1.toUpperCase()}" ) {
90
+ runTPCDSQuery(s._1, noFallBack = s._2) { df => } // empty customCheck: no extra assertions on the DataFrame here
91
91
}
92
92
})
93
93
}
@@ -152,7 +152,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
152
152
}
153
153
154
154
override protected def afterAll (): Unit = {
155
- ClickhouseSnapshot .clearAllFileStatusCache
155
+ ClickhouseSnapshot .clearAllFileStatusCache()
156
156
DeltaLog .clearCache()
157
157
158
158
try {
@@ -183,11 +183,10 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
183
183
tpcdsQueries : String = tpcdsQueries,
184
184
queriesResults : String = queriesResults,
185
185
compareResult : Boolean = true ,
186
- noFallBack : Boolean = true ,
187
- skipFallBackAssert : Boolean = false )(customCheck : DataFrame => Unit ): Unit = {
186
+ noFallBack : Boolean = true )(customCheck : DataFrame => Unit ): Unit = {
188
187
189
188
val sqlFile = tpcdsQueries + " /" + queryNum + " .sql"
190
- val sql = Source .fromFile(new File (sqlFile), " UTF-8" ).mkString
189
+ val sql = Arm .withResource( Source .fromFile(new File (sqlFile), " UTF-8" ))(_ .mkString)
191
190
val df = spark.sql(sql)
192
191
193
192
if (compareResult) {
@@ -212,13 +211,13 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
212
211
// using WARN to guarantee printed
213
212
log.warn(s " query: $queryNum, finish comparing with saved result " )
214
213
} else {
215
- val start = System .currentTimeMillis();
214
+ val start = System .currentTimeMillis()
216
215
val ret = df.collect()
217
216
// using WARN to guarantee printed
218
217
log.warn(s " query: $queryNum skipped comparing, time cost to collect: ${System
219
218
.currentTimeMillis() - start} ms, ret size: ${ret.length}" )
220
219
}
221
- WholeStageTransformerSuite .checkFallBack(df, noFallBack, skipFallBackAssert )
220
+ WholeStageTransformerSuite .checkFallBack(df, noFallBack)
222
221
customCheck(df)
223
222
}
224
223
}
0 commit comments