Skip to content

Commit

Permalink
Fix issue with CustomerShuffleReaderExec metadata copy (#11917)
Browse files Browse the repository at this point in the history
* Fix nvbugs 5028393

Signed-off-by: Renjie Liu <[email protected]>

* Fix license header

---------

Signed-off-by: Renjie Liu <[email protected]>
  • Loading branch information
liurenjie1024 authored Jan 7, 2025
1 parent 4df6d60 commit 50b14de
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,14 +16,13 @@

package com.nvidia.spark.rapids

import java.time.ZoneId

import scala.collection.mutable

import com.nvidia.spark.rapids.RapidsMeta.noNeedToReplaceReason
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.shims.{DistributionUtil, SparkShimImpl}
import java.time.ZoneId
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Cast, ComplexTypeMergingExpression, Expression, QuaternaryExpression, RuntimeReplaceable, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Cast, ComplexTypeMergingExpression, Expression, QuaternaryExpression, RuntimeReplaceable, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, UnaryLike}
Expand Down Expand Up @@ -64,6 +63,8 @@ final class NoRuleDataFromReplacementRule extends DataFromReplacementRule {

object RapidsMeta {
val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported")

def noNeedToReplaceReason(klass: Class[_]) = s"there is no need to replace $klass"
}

/**
Expand Down Expand Up @@ -936,7 +937,7 @@ final class DoNotReplaceOrWarnSparkPlanMeta[INPUT <: SparkPlan](
override def suppressWillWorkOnGpuInfo: Boolean = true

override def tagPlanForGpu(): Unit =
willNotWorkOnGpu(s"there is no need to replace ${plan.getClass}")
willNotWorkOnGpu(noNeedToReplaceReason(plan.getClass))

override def convertToGpu(): GpuExec =
throw new IllegalStateException("Cannot be converted to GPU")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -46,6 +46,7 @@ spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsMeta.noNeedToReplaceReason

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.adaptive._
Expand Down Expand Up @@ -86,4 +87,18 @@ class GpuCustomShuffleReaderMeta(reader: AQEShuffleReadExec,
shuffleEx.getTagValue(GpuShuffleMetaBase.availableRuntimeDataTransition)
.getOrElse(false)
}

override def checkExistingTags(): Unit = {
// Some rules perform a transform and may replace ShuffleQueryStageExec
// with CustomShuffleReaderExec, causing tags to be copied from ShuffleQueryStageExec to
// CustomShuffleReaderExec, including the "no need to replace ShuffleQueryStageExec" tag.

val noNeedReason = noNeedToReplaceReason(classOf[ShuffleQueryStageExec])

wrapped.getTagValue(RapidsMeta.gpuSupportedTag)
.foreach(_.diff(cannotBeReplacedReasons.get)
.filterNot(s => noNeedReason.equals(s))
.foreach(willNotWorkOnGpu))
}

}

0 comments on commit 50b14de

Please sign in to comment.