Support group-limit optimization for ROW_NUMBER in Qualification (#1487)
* Support group-limit optimization for ROW_NUMBER in Qualification

Fixes #1484

This change adds `row_number` to the list of supported expressions in
`WindowGroupLimit`, and updates the unit tests to verify the new behavior.

---------

Signed-off-by: Ahmed Hussein (amahussein) <[email protected]>
amahussein authored Jan 6, 2025
1 parent 12582b6 commit 87f7e65
Showing 2 changed files with 26 additions and 62 deletions.
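For context, the group-limit optimization targets top-N-per-group queries: when a filter bounds the value of a ranking window function, Spark 3.5+ can plan a WindowGroupLimit that drops rows early instead of ranking every row in each partition. The sketch below mirrors the query shape exercised by the updated unit tests, here with ROW_NUMBER; the table, columns, and app name are illustrative only, and a local Spark 3.5+ session is assumed.

import org.apache.spark.sql.SparkSession

object RowNumberGroupLimitExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("RowNumberGroupLimitExample")
      .master("local[*]")
      .getOrCreate()
    spark.sql("CREATE TABLE IF NOT EXISTS foobar_tbl (foo STRING, bar STRING) USING PARQUET")
    // Top-N per group: the predicate on the ranking column is what allows Spark
    // to plan WindowGroupLimit (Partial and Final) around the Window operator.
    // With this commit, the Qualification tool reports those execs as supported
    // for ROW_NUMBER, as it already did for RANK and DENSE_RANK.
    val topN = spark.sql(
      """
      SELECT foo, bar FROM (
        SELECT foo, bar,
        ROW_NUMBER() OVER (PARTITION BY foo ORDER BY bar) as rank
        FROM foobar_tbl)
      WHERE rank <= 2""")
    topN.explain()
    spark.stop()
  }
}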
WindowGroupLimitParser.scala

@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -27,9 +27,7 @@ case class WindowGroupLimitParser(
     sqlID: Long) extends ExecParser {
 
   val fullExecName: String = node.name + "Exec"
-  // row_number() is currently not supported by the plugin (v24.04)
-  // Ref: https://github.com/NVIDIA/spark-rapids/pull/10500
-  val supportedRankingExprs = Set("rank", "dense_rank")
+  val supportedRankingExprs = Set("rank", "dense_rank", "row_number")
 
   private def validateRankingExpr(rankingExprs: Array[String]): Boolean = {
     rankingExprs.length == 1 && supportedRankingExprs.contains(rankingExprs.head)
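The effect of this one-line parser change can be seen in isolation: `validateRankingExpr` accepts exactly one ranking expression, and that expression must be in `supportedRankingExprs`, which now includes `row_number`. A minimal standalone sketch of the check (names mirror the parser, but this is not the production class):

object RankingExprCheckSketch {
  // Mirrors the updated set in WindowGroupLimitParser.
  val supportedRankingExprs = Set("rank", "dense_rank", "row_number")

  // Mirrors validateRankingExpr: exactly one expression, and it must be supported.
  def validateRankingExpr(rankingExprs: Array[String]): Boolean =
    rankingExprs.length == 1 && supportedRankingExprs.contains(rankingExprs.head)

  def main(args: Array[String]): Unit = {
    println(validateRankingExpr(Array("row_number")))          // true: newly supported
    println(validateRankingExpr(Array("dense_rank")))          // true
    println(validateRankingExpr(Array("ntile")))               // false: not a supported ranking function
    println(validateRankingExpr(Array("rank", "row_number")))  // false: must be exactly one expression
  }
}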
SQLPlanParserSuite.scala

@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1761,62 +1761,24 @@ class SQLPlanParserSuite extends BasePlanParserSuite {
     pluginTypeChecker.getNotSupportedExprs(filterExprArray) shouldBe 'empty
   }
 
-  runConditionalTest("WindowGroupLimitExec is supported", execsSupportedSparkGTE350) {
+  /**
+   * Helper method to run tests for WindowGroupLimit expressions
+   * @param windowExpr the expression to be evaluated (i.e., rank, row_number).
+   * @param skipSqlID the SQL ID to skip the stage verification.
+   */
+  def runWindowGroupLimitTest(windowExpr: String, skipSqlID: Long): Unit = {
     val windowGroupLimitExecCmd = "WindowGroupLimit"
-    val tbl_name = "foobar_tbl"
+    val tbl_name = s"foobar_tbl_test_$windowExpr"
+    val appName = s"WindowGroupLimitExecTest_$windowExpr"
     TrampolineUtil.withTempDir { eventLogDir =>
-      val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir,
-        windowGroupLimitExecCmd) { spark =>
-        withTable(spark, tbl_name) {
-          spark.sql(s"CREATE TABLE $tbl_name (foo STRING, bar STRING) USING PARQUET")
-          val query =
-            s"""
-            SELECT foo, bar FROM (
-              SELECT foo, bar,
-              RANK() OVER (PARTITION BY foo ORDER BY bar) as rank
-              FROM $tbl_name)
-            WHERE rank <= 2"""
-          spark.sql(query)
-        }
-      }
-      val pluginTypeChecker = new PluginTypeChecker()
-      val app = createAppFromEventlog(eventLog)
-      val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
-        SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app)
-      }
-      // Note that the generated plan, there are skipped stages that causes some execs to appear
-      // without their relevant stages. so we skip the stage verification here.
-      verifyExecToStageMapping(parsedPlans.toSeq, app, Some( planInfo =>
-        if (planInfo.sqlID == 73) {
-          // Nodes should not have any stages
-          val allExecInfos = planInfo.execInfo.flatMap { e =>
-            e.children.getOrElse(Seq.empty) :+ e
-          }
-          // exclude all stages higher than 8 because those ones belong to a skipped stage
-          allExecInfos.filter(_.nodeId <= 8).forall(_.stages.nonEmpty)
-        })
-      )
-      val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq)
-      val windowGroupLimitExecs = allExecInfo.filter(_.exec.contains(windowGroupLimitExecCmd))
-      // We should have two WindowGroupLimitExec operators (Partial and Final).
-      assertSizeAndSupported(2, windowGroupLimitExecs)
-    }
-  }
-
-  runConditionalTest("row_number in WindowGroupLimitExec is not supported",
-    execsSupportedSparkGTE350) {
-    val windowGroupLimitExecCmd = "WindowGroupLimit"
-    val tbl_name = "foobar_tbl"
-    TrampolineUtil.withTempDir { eventLogDir =>
-      val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir,
-        windowGroupLimitExecCmd) { spark =>
+      val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, appName) { spark =>
         withTable(spark, tbl_name) {
           spark.sql(s"CREATE TABLE $tbl_name (foo STRING, bar STRING) USING PARQUET")
           val query =
             s"""
             SELECT foo, bar FROM (
               SELECT foo, bar,
-              ROW_NUMBER() OVER (PARTITION BY foo ORDER BY bar) as rank
+              $windowExpr() OVER (PARTITION BY foo ORDER BY bar) as rank
               FROM $tbl_name)
             WHERE rank <= 2"""
           spark.sql(query)
@@ -1830,7 +1792,7 @@ class SQLPlanParserSuite extends BasePlanParserSuite {
       // Note that the generated plan, there are skipped stages that causes some execs to appear
       // without their relevant stages. so we skip the stage verification here.
       verifyExecToStageMapping(parsedPlans.toSeq, app, Some( planInfo =>
-        if (planInfo.sqlID == 76) {
+        if (planInfo.sqlID == skipSqlID) {
           // Nodes should not have any stages
           val allExecInfos = planInfo.execInfo.flatMap { e =>
             e.children.getOrElse(Seq.empty) :+ e
@@ -1840,18 +1802,22 @@
         })
       )
       val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq)
-      val windowExecNotSupportedExprs = allExecInfo.filter(
-        _.exec.contains(windowGroupLimitExecCmd)).flatMap(x => x.unsupportedExprs)
-      windowExecNotSupportedExprs.head.getOpName shouldEqual "row_number"
-      windowExecNotSupportedExprs.head.unsupportedReason shouldEqual
-        "Ranking function row_number is not supported in WindowGroupLimitExec"
       val windowGroupLimitExecs = allExecInfo.filter(_.exec.contains(windowGroupLimitExecCmd))
-      // We should have two WindowGroupLimitExec operators (Partial and Final) which are
-      // not supported due to unsupported expression.
-      assertSizeAndNotSupported(2, windowGroupLimitExecs)
+      // We should have two WindowGroupLimitExec operators (Partial and Final).
+      assertSizeAndSupported(2, windowGroupLimitExecs)
     }
   }
 
+  runConditionalTest("WindowGroupLimit expression rank is supported",
+    execsSupportedSparkGTE350) {
+    runWindowGroupLimitTest("RANK", skipSqlID = 73)
+  }
+
+  runConditionalTest("WindowGroupLimit expression row_number is supported",
+    execsSupportedSparkGTE350) {
+    runWindowGroupLimitTest("ROW_NUMBER", skipSqlID = 76)
+  }
+
   runConditionalTest("CheckOverflowInsert should not exist in Physical Plan",
     execsSupportedSparkGTE331) {
     // This test verifies that the 'CheckOverflowInsert' expression exists in the logical plan
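Since `dense_rank` is already in `supportedRankingExprs`, the new helper could also cover it. The case below is hypothetical and not part of this commit; the `???` placeholder would need to be replaced with the SQL ID whose stages end up skipped in the generated event log, as was determined empirically for 73 and 76 above.

  // Hypothetical follow-up case, not part of this commit.
  runConditionalTest("WindowGroupLimit expression dense_rank is supported",
    execsSupportedSparkGTE350) {
    // Replace ??? with the skipped SQL ID observed for the DENSE_RANK event log.
    runWindowGroupLimitTest("DENSE_RANK", skipSqlID = ???)
  }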
