Fix two potential OOM issues in GPU aggregate. #11908

Merged · 5 commits · Jan 10, 2025
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -25,6 +25,11 @@ object DataTypeUtils {
    case _ => false
  }

+  def hasOffset(dataType: DataType): Boolean = dataType match {
+    case _: ArrayType | _: StringType | _: BinaryType => true
+    case _ => false
+  }
+
  def hasNestedTypes(schema: StructType): Boolean =
    schema.exists(f => isNestedType(f.dataType))
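
For illustration only (this snippet is not part of the diff): the new helper classifies types following the cudf layout, where variable-width columns (strings, binary, lists) carry a 32-bit offset buffer and fixed-width columns do not. Note that MapType falls through to false; the map case is handled explicitly where the helper is used (see executeCollect further down).

  import org.apache.spark.sql.types._

  // Expected classification under the new match arms:
  assert(DataTypeUtils.hasOffset(StringType))             // offsets into character data
  assert(DataTypeUtils.hasOffset(BinaryType))             // offsets into byte data
  assert(DataTypeUtils.hasOffset(ArrayType(IntegerType))) // list offsets into the child column
  assert(!DataTypeUtils.hasOffset(IntegerType))           // fixed-width: data + validity only
  assert(!DataTypeUtils.hasOffset(MapType(IntegerType, StringType))) // maps handled by callers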

@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1095,16 +1095,13 @@ class GpuMergeAggregateIterator(
    closeOnExcept(new ArrayBuffer[AutoClosableArrayBuffer[SpillableColumnarBatch]]) {
      toAggregateBuckets =>
        var currentSize = 0L
-       while (batchesByBucket.nonEmpty &&
-         (
-           // for some test cases targetMergeBatchSize is too small to fit any bucket,
-           // in this case we put the first bucket into toAggregateBuckets anyway
-           // refer to https://github.com/NVIDIA/spark-rapids/issues/11790 for examples
-           toAggregateBuckets.isEmpty ||
-             batchesByBucket.last.size() + currentSize <= targetMergeBatchSize)) {
-         val bucket = batchesByBucket.remove(batchesByBucket.size - 1)
-         currentSize += bucket.map(_.sizeInBytes).sum
-         toAggregateBuckets += bucket
+       var keepGoing = true
+       while (batchesByBucket.nonEmpty && keepGoing) {
+         currentSize += batchesByBucket.last.map(_.sizeInBytes).sum
+         keepGoing = currentSize <= targetMergeBatchSize || toAggregateBuckets.isEmpty
+         if (keepGoing) {
+           toAggregateBuckets += batchesByBucket.remove(batchesByBucket.size - 1)
+         }
        }

[Review thread on the new loop]

Collaborator: This may need a separate PR, as it is unrelated to the description in #11903.

Collaborator (author): Thanks for the review. @binmahone, could you help file an issue for this? Then I will follow your suggestion.

Collaborator: You could also consider modifying the description in #11903 :-)

Collaborator (author): Updated.

        AggregateUtils.concatenateAndMerge(
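
For illustration only (not part of the diff), the accumulation policy of the new loop above can be read as this standalone sketch; the takeBuckets and sizeOf names are hypothetical stand-ins for the surrounding types. Buckets are taken greedily from the end while the running total stays within the target, and at least one bucket is always taken so that undersized targets (issue #11790) still make progress:

  import scala.collection.mutable.ArrayBuffer

  // Hypothetical standalone rendering of the loop's policy.
  def takeBuckets[T](buckets: ArrayBuffer[T], sizeOf: T => Long, target: Long): Seq[T] = {
    val taken = ArrayBuffer.empty[T]
    var total = 0L
    var keepGoing = true
    while (buckets.nonEmpty && keepGoing) {
      total += sizeOf(buckets.last)
      keepGoing = total <= target || taken.isEmpty
      if (keepGoing) {
        taken += buckets.remove(buckets.size - 1)
      }
    }
    taken.toSeq
  }

  // e.g. takeBuckets(ArrayBuffer(8L, 4L, 2L), identity[Long], 6L) returns Seq(2, 4):
  // taking 8 as well would push the total to 14, past the target of 6.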
@@ -2225,4 +2222,4 @@ class DynamicGpuPartialAggregateIterator(
      throw new NoSuchElementException()
    }
  }
-}
+}
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ import ai.rapids.cudf
import ai.rapids.cudf._
import com.nvidia.spark.Retryable
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.DataTypeUtils.hasOffset
import com.nvidia.spark.rapids.GpuMetric._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRestoreOnRetry, withRetry, withRetryNoSplit}
@@ -35,11 +36,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, RangePartitioning, SinglePartition, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SampleExec, SparkPlan}
import org.apache.spark.sql.rapids.{GpuPartitionwiseSampledRDD, GpuPoissonSampler}
import org.apache.spark.sql.rapids.execution.TrampolineUtil
-import org.apache.spark.sql.types.{DataType, LongType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.random.BernoulliCellSampler

class GpuProjectExecMeta(
@@ -308,12 +311,176 @@ object PreProjectSplitIterator {
      } else {
        boundExprs.getPassThroughIndex(index).map { inputIndex =>
          cb.column(inputIndex).asInstanceOf[GpuColumnVector].getBase.getDeviceMemorySize
+       }.orElse {
+         // A literal has an exact size that should be taken into account.
+         extractGpuLit(boundExprs.exprTiers.last(index)).map { gpuLit =>
+           calcSizeForLiteral(gpuLit.value, gpuLit.dataType, numRows)
+         }
        }.getOrElse {
          GpuBatchUtils.minGpuMemory(dataType, true, numRows)
        }
      }
    }.sum
  }

  @scala.annotation.tailrec
  def extractGpuLit(exp: Expression): Option[GpuLiteral] = exp match {
    case gl: GpuLiteral => Some(gl)
    case ga: GpuAlias => extractGpuLit(ga.child)
    case _ => None
  }

  private[rapids] def calcSizeForLiteral(litVal: Any, litType: DataType, numRows: Int): Long = {
    // First calculate the size of the metadata buffers (validity and offsets)
    val metaSize = new LitMetaCollector(litVal, litType).collect.map { litMeta =>
      val expandedRowsNum = litMeta.getRowsNum * numRows
      var totalSize = 0L
      if (litMeta.hasNull) {
        totalSize += GpuBatchUtils.calculateValidityBufferSize(expandedRowsNum)
      }
      if (litMeta.hasOffset) {
        totalSize += GpuBatchUtils.calculateOffsetBufferSize(expandedRowsNum)
      }
      totalSize
    }.sum
    // finalSize = oneLitValueSize * numRows + metadata size
    calcLitValueSize(litVal, litType) * numRows + metaSize
  }

  /**
   * Represents the metadata of a literal or one of its children, used to
   * calculate the final metadata size after expanding the literal to a column.
   */
  private class LitMeta(val hasNull: Boolean, val hasOffset: Boolean) {
    private var rowsNum: Int = 0
    def incRowsNum(rows: Int = 1): Unit = rowsNum += rows
    def getRowsNum: Int = rowsNum

    override def toString: String =
      s"LitMeta{rowsNum: $rowsNum, hasNull: $hasNull, hasOffset: $hasOffset}"
  }

  /**
   * Collects the metadata of a literal; for a nested type literal the result
   * also includes entries for its children.
   */
  private class LitMetaCollector(litValue: Any, litType: DataType) {
    private var collected = false
    private val metaInfos: ArrayBuffer[LitMeta] = ArrayBuffer.empty

    def collect: Seq[LitMeta] = {
      if (!collected) {
        executeCollect(litValue, litType, litValue == null, 0)
        collected = true
      }
      metaInfos.filter(_ != null).toSeq
    }

    /**
     * Goes through the literal and all its children to collect the meta information
     * and save it to the cache; call "collect" to get the result.
     * Each LitMeta indicates whether the literal or a child will have offset/validity
     * buffers after being expanded to a column, along with the number of original rows.
     * For nested types, it follows the column layout described in
     * https://github.com/rapidsai/cudf/blob/a0487be669326175982c8bfcdab4d61184c88e27/
     * cpp/doxygen/developer_guide/DEVELOPER_GUIDE.md#list-columns
     */
    private def executeCollect(lit: Any, litTp: DataType, nullable: Boolean,
        depth: Int): Unit = {
      litTp match {
        case ArrayType(elemType, hasNullElem) =>
          // This may be an inner element of a nested array, so use the
          // nullability passed down from the parent.
          getOrInitAt(depth, new LitMeta(nullable, true)).incRowsNum()
          // Go into the child
          val arrayData = lit.asInstanceOf[ArrayData]
          if (arrayData != null) { // Only need to go into the child when non-null
            (0 until arrayData.numElements()).foreach(i =>
              executeCollect(arrayData.get(i, elemType), elemType, hasNullElem, depth + 1)
            )
          }
        case StructType(fields) =>
          if (nullable) {
            // Add a meta only for a nullable struct; a struct has no offsets.
            getOrInitAt(depth, new LitMeta(nullable, false)).incRowsNum()
          }
          // Always go into the children, which is different from array.
          val stData = lit.asInstanceOf[InternalRow]
          fields.zipWithIndex.foreach { case (f, i) =>
            val fLit = if (stData != null) stData.get(i, f.dataType) else null
            executeCollect(fLit, f.dataType, f.nullable, depth + 1 + i)
          }
        case MapType(keyType, valType, hasNullValue) =>
          // A map is a list of structs in cudf, but the nested struct has no
          // offset or validity buffer, so only the top list needs a meta.
          getOrInitAt(depth, new LitMeta(nullable, true)).incRowsNum()
          val mapData = lit.asInstanceOf[MapData]
          if (mapData != null) {
            mapData.foreach(keyType, valType, { case (key, value) =>
              executeCollect(key, keyType, false, depth + 1)
              executeCollect(value, valType, hasNullValue, depth + 2)
            })
          }
        case otherType => // primitive types
          val hasOffsetBuf = hasOffset(otherType)
          if (nullable || hasOffsetBuf) {
            getOrInitAt(depth, new LitMeta(nullable, hasOffsetBuf)).incRowsNum()
          }
      }
    }

    private def getOrInitAt(pos: Int, initMeta: LitMeta): LitMeta = {
      if (pos >= metaInfos.length) {
        (metaInfos.length until pos).foreach { _ =>
          metaInfos.append(null)
        }
        metaInfos.append(initMeta)
      } else if (metaInfos(pos) == null) {
        metaInfos(pos) = initMeta
      }
      metaInfos(pos)
    }
  }
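
  // For illustration only (not part of the diff): how the depth slots fill in
  // for a non-null literal of type struct<int1: int, string2: array<string>>
  // with all levels nullable:
  //   slot 0: LitMeta(hasNull = true, hasOffset = false) // the struct itself
  //   slot 1: LitMeta(hasNull = true, hasOffset = false) // field int1 (depth + 1 + 0)
  //   slot 2: LitMeta(hasNull = true, hasOffset = true)  // field string2, a list (depth + 1 + 1)
  //   slot 3: LitMeta(hasNull = true, hasOffset = true)  // the string elements (list depth + 1)
  // Each slot's rowsNum counts the original rows at that level and is multiplied
  // by numRows in calcSizeForLiteral to size that level's validity/offset buffers.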

  private def calcLitValueSize(lit: Any, litTp: DataType): Long = {
    litTp match {
      case StringType =>
        if (lit == null) 0L else lit.asInstanceOf[UTF8String].numBytes()
      case BinaryType =>
        if (lit == null) 0L else lit.asInstanceOf[Array[Byte]].length
      case ArrayType(elemType, _) =>
        val arrayData = lit.asInstanceOf[ArrayData]
        if (arrayData == null) {
          0L
        } else {
          (0 until arrayData.numElements()).map { idx =>
            calcLitValueSize(arrayData.get(idx, elemType), elemType)
          }.sum
        }
      case MapType(keyType, valType, _) =>
        val mapData = lit.asInstanceOf[MapData]
        if (mapData == null) {
          0L
        } else {
          val keyData = mapData.keyArray()
          val valData = mapData.valueArray()
          (0 until mapData.numElements()).map { i =>
            calcLitValueSize(keyData.get(i, keyType), keyType) +
              calcLitValueSize(valData.get(i, valType), valType)
          }.sum
        }
      case StructType(fields) =>
        // A special case: always go into the children even when lit is null,
        // because fixed-width children still take memory for every row.
        val stData = lit.asInstanceOf[InternalRow]
        fields.zipWithIndex.map { case (f, i) =>
          val fLit = if (stData == null) null else stData.get(i, f.dataType)
          calcLitValueSize(fLit, f.dataType)
        }.sum
      case _ => litTp.defaultSize
    }
  }
}
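
For illustration only (not part of the diff), a few spot checks of the value-size accounting above, using Spark's defaultSize for fixed-width types; calcLitValueSize is private, so these are stated as expectations rather than runnable asserts:

  calcLitValueSize(UTF8String.fromString("abc"), StringType) == 3   // UTF-8 payload only
  calcLitValueSize(null, StringType)                         == 0   // no bytes to copy
  calcLitValueSize(null, IntegerType)                        == 4   // fixed-width slot still counted
  a null struct<int, long> recurses into its children: 4 + 8 == 12  // bytes per row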

@@ -0,0 +1,121 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.unit

import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.{GpuScalar, GpuUnitTests, PreProjectSplitIterator}
import com.nvidia.spark.rapids.Arm.withResource

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/** These tests only cover nested type literals for the PreProjectSplitIterator case */
class LiteralSizeEstimationTest extends GpuUnitTests {
  private val numRows = 1000

  private def testLiteralSizeEstimate(lit: Any, litType: DataType): Unit = {
    val col = withResource(GpuScalar.from(lit, litType))(ColumnVector.fromScalar(_, numRows))
    val actualSize = withResource(col)(_.getDeviceMemorySize)
    val estimatedSize = PreProjectSplitIterator.calcSizeForLiteral(lit, litType, numRows)
    assertResult(actualSize)(estimatedSize)
  }

test("estimate the array(int) literal size") {
val litType = ArrayType(IntegerType, true)
val lit = ArrayData.toArrayData(Array(null, 1, 2, null))
testLiteralSizeEstimate(lit, litType)
}

test("estimate the array(string) literal size") {
val litType = ArrayType(StringType, true)
val lit = ArrayData.toArrayData(
Array(null, UTF8String.fromString("s1"), UTF8String.fromString("s2")))
testLiteralSizeEstimate(lit, litType)
}

test("estimate the array(array(array(int))) literal size") {
val litType = ArrayType(ArrayType(ArrayType(IntegerType, true), true), true)
val nestedElem1 = ArrayData.toArrayData(Array(null, 1, 2, null))
val nestedElem2 = ArrayData.toArrayData(Array(null))
val nestedElem3 = ArrayData.toArrayData(Array())
val elem1 = ArrayData.toArrayData(Array(nestedElem1, null))
val elem2 = ArrayData.toArrayData(Array(nestedElem2, null, nestedElem3))
val lit = ArrayData.toArrayData(Array(null, elem1, null, elem2, null))
testLiteralSizeEstimate(lit, litType)
}

test("estimate the array(array(array(string))) literal size") {
val litType = ArrayType(ArrayType(ArrayType(StringType, true), true), true)
val nestedElem1 = ArrayData.toArrayData(
Array(null, UTF8String.fromString("s1"), UTF8String.fromString("s2")))
val nestedElem2 = ArrayData.toArrayData(Array(null))
val nestedElem3 = ArrayData.toArrayData(Array())
val elem1 = ArrayData.toArrayData(Array(nestedElem1, null))
val elem2 = ArrayData.toArrayData(Array(nestedElem2, null, nestedElem3))
val lit = ArrayData.toArrayData(Array(null, elem1, null, elem2, null))
testLiteralSizeEstimate(lit, litType)
}

test("estimate the struct(int, string) literal size") {
val litType = StructType(Seq(
StructField("int1", IntegerType),
StructField("string2", StringType)
))
// null
testLiteralSizeEstimate(InternalRow(null, null), litType)
// normal case
testLiteralSizeEstimate(InternalRow(1, UTF8String.fromString("s1")), litType)
}

test("estimate the struct(int, array(string)) literal size") {
val litType = StructType(Seq(
StructField("int1", IntegerType),
StructField("string2", ArrayType(StringType, true))
))
testLiteralSizeEstimate(InternalRow(null, null), litType)
val arrayLit = ArrayData.toArrayData(
Array(null, UTF8String.fromString("s1"), UTF8String.fromString("s2")))
// normal case
testLiteralSizeEstimate(InternalRow(1, arrayLit), litType)
}

test("estimate the list(struct(int, array(string))) literal size") {
val litType = ArrayType(
StructType(Seq(
StructField("int1", IntegerType),
StructField("string2", ArrayType(StringType, true))
)), true)
val arrayLit = ArrayData.toArrayData(
Array(null, UTF8String.fromString("a1"), UTF8String.fromString("a2")))
val elem1 = InternalRow(1, arrayLit)
val elem2 = InternalRow(null, null)
val lit = ArrayData.toArrayData(Array(null, elem1, elem2))
testLiteralSizeEstimate(lit, litType)
}

test("estimate the map(int, array(string)) literal size") {
val litType = MapType(IntegerType, ArrayType(StringType, true), true)
val arrayLit = ArrayData.toArrayData(
Array(null, UTF8String.fromString("s1"), UTF8String.fromString("s2")))
val valueLit = ArrayData.toArrayData(Array(null, arrayLit))
val keyLit = ArrayData.toArrayData(Array(1, 2))
val lit = new ArrayBasedMapData(keyLit, valueLit)
testLiteralSizeEstimate(lit, litType)
}
}