diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index 5251bf49390..918f981975f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,14 +16,13 @@ package com.nvidia.spark.rapids -import java.time.ZoneId - -import scala.collection.mutable - +import com.nvidia.spark.rapids.RapidsMeta.noNeedToReplaceReason import com.nvidia.spark.rapids.jni.GpuTimeZoneDB import com.nvidia.spark.rapids.shims.{DistributionUtil, SparkShimImpl} +import java.time.ZoneId +import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Cast, ComplexTypeMergingExpression, Expression, QuaternaryExpression, RuntimeReplaceable, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Cast, ComplexTypeMergingExpression, Expression, QuaternaryExpression, RuntimeReplaceable, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, UnaryLike} @@ -64,6 +63,8 @@ final class NoRuleDataFromReplacementRule extends DataFromReplacementRule { object RapidsMeta { val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported") + + def noNeedToReplaceReason(klass: Class[_]) = s"there is no need to replace $klass" } /** @@ -936,7 +937,7 @@ final class DoNotReplaceOrWarnSparkPlanMeta[INPUT <: SparkPlan]( override def suppressWillWorkOnGpuInfo: Boolean = true override def tagPlanForGpu(): Unit = - willNotWorkOnGpu(s"there is no need to replace ${plan.getClass}") + willNotWorkOnGpu(noNeedToReplaceReason(plan.getClass)) override def convertToGpu(): GpuExec = throw new IllegalStateException("Cannot be converted to GPU") diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/ShimAQEShuffleReadExec.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/ShimAQEShuffleReadExec.scala index 213f1205db7..7d4a2ab621f 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/ShimAQEShuffleReadExec.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/ShimAQEShuffleReadExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2021-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,6 +46,7 @@ spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.RapidsMeta.noNeedToReplaceReason import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.adaptive._ @@ -86,4 +87,18 @@ class GpuCustomShuffleReaderMeta(reader: AQEShuffleReadExec, shuffleEx.getTagValue(GpuShuffleMetaBase.availableRuntimeDataTransition) .getOrElse(false) } + + override def checkExistingTags(): Unit = { + // Some rules perform a transform and may replace ShuffleQueryStageExec + // with CustomShuffleReaderExec, causing tags to be copied from ShuffleQueryStageExec to + // CustomShuffleReaderExec, including the "no need to replace ShuffleQueryStageExec" tag. + + val noNeedReason = noNeedToReplaceReason(classOf[ShuffleQueryStageExec]) + + wrapped.getTagValue(RapidsMeta.gpuSupportedTag) + .foreach(_.diff(cannotBeReplacedReasons.get) + .filterNot(s => noNeedReason.equals(s)) + .foreach(willNotWorkOnGpu)) + } + } \ No newline at end of file