From c580b854650a87735113f30a0b1f8b2c9858b0ba Mon Sep 17 00:00:00 2001
From: Firestarman
Date: Thu, 2 Jan 2025 14:49:35 +0800
Subject: [PATCH 1/5] Fix two potential OOM issues in agg.

The first is fixed by taking nested literals into account when
calculating the output size for the pre-split.

The second is fixed by using the correct size in the buffer-size
comparison when collecting the next bundle of batches in aggregate.

Signed-off-by: Firestarman
---
 .../nvidia/spark/rapids/DataTypeUtils.scala   |  7 +-
 .../spark/rapids/GpuAggregateExec.scala       | 21 ++---
 .../spark/rapids/basicPhysicalOperators.scala | 92 ++++++++++++++++++-
 3 files changed, 105 insertions(+), 15 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala
index a031a2aaeed..e3d71818315 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,6 +25,11 @@ object DataTypeUtils {
     case _ => false
   }
 
+  def hasOffset(dataType: DataType): Boolean = dataType match {
+    case _: ArrayType | _: StringType | _: BinaryType => true
+    case _ => false
+  }
+
   def hasNestedTypes(schema: StructType): Boolean =
     schema.exists(f => isNestedType(f.dataType))
 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
index 8fc5326705e..847a21d81e6 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -1095,16 +1095,13 @@ class GpuMergeAggregateIterator( closeOnExcept(new ArrayBuffer[AutoClosableArrayBuffer[SpillableColumnarBatch]]) { toAggregateBuckets => var currentSize = 0L - while (batchesByBucket.nonEmpty && - ( - // for some test cases targetMergeBatchSize is too small to fit any bucket, - // in this case we put the first bucket into toAggregateBuckets anyway - // refer to https://github.com/NVIDIA/spark-rapids/issues/11790 for examples - toAggregateBuckets.isEmpty || - batchesByBucket.last.size() + currentSize <= targetMergeBatchSize)) { - val bucket = batchesByBucket.remove(batchesByBucket.size - 1) - currentSize += bucket.map(_.sizeInBytes).sum - toAggregateBuckets += bucket + var keepGoing = true + while (batchesByBucket.nonEmpty && keepGoing) { + currentSize += batchesByBucket.last.map(_.sizeInBytes).sum + keepGoing = currentSize <= targetMergeBatchSize || toAggregateBuckets.isEmpty + if (keepGoing) { + toAggregateBuckets += batchesByBucket.remove(batchesByBucket.size - 1) + } } AggregateUtils.concatenateAndMerge( @@ -2225,4 +2222,4 @@ class DynamicGpuPartialAggregateIterator( throw new NoSuchElementException() } } -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index 891e837d7e1..29e0a56e5dc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import ai.rapids.cudf import ai.rapids.cudf._ import com.nvidia.spark.Retryable import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.DataTypeUtils.hasOffset import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRestoreOnRetry, withRetry, withRetryNoSplit} @@ -35,11 +36,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, RangePartitioning, SinglePartition, UnknownPartitioning} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SampleExec, SparkPlan} import org.apache.spark.sql.rapids.{GpuPartitionwiseSampledRDD, GpuPoissonSampler} import org.apache.spark.sql.rapids.execution.TrampolineUtil -import org.apache.spark.sql.types.{DataType, LongType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.random.BernoulliCellSampler class GpuProjectExecMeta( @@ -308,12 +311,97 @@ object PreProjectSplitIterator { } else { boundExprs.getPassThroughIndex(index).map { inputIndex => cb.column(inputIndex).asInstanceOf[GpuColumnVector].getBase.getDeviceMemorySize + }.orElse { + // A literal has an exact size that should be taken into account. 
+          extractGpuLit(boundExprs.exprTiers.last(index)).map { gpuLit =>
+            calcMemorySizeForLiteral(gpuLit.value, gpuLit.dataType, numRows)
+          }
         }.getOrElse {
           GpuBatchUtils.minGpuMemory(dataType, true, numRows)
         }
       }
     }.sum
   }
+
+  private def calcMemorySizeForLiteral(litVal: Any, litType: DataType, numRows: Int): Long = {
+    // One literal size = value size + additional buffers size (offset and/or validity)
+    val litSize = calcLitValueSize(litVal, litType) + {
+      val pickRowNum: Int => Int = rowNum => if (litVal == null) 0 else rowNum
+      litType match {
+        case ArrayType(elemType, hasNullElem) =>
+          val numElems = pickRowNum(litVal.asInstanceOf[ArrayData].numElements())
+          // A GPU array literal requires only one column as the child
+          estimateLitAdditionSize(hasNullElem, hasOffset(elemType), numElems)
+        case StructType(fields) =>
+          val childrenNumRows = pickRowNum(1)
+          // A GPU struct literal requires "fields.size" columns as the children.
+          fields.map(f =>
+            estimateLitAdditionSize(f.nullable, hasOffset(f.dataType), childrenNumRows)
+          ).sum
+        case MapType(keyType, valType, hasNullValue) =>
+          val mapRowsNum = pickRowNum(litVal.asInstanceOf[MapData].numElements())
+          // A GPU map literal requires 4 columns as the children:
+          // the key and value columns, along with two wrapper columns as below
+          //   "list<struct<key, value>>"
+          estimateLitAdditionSize(false, hasOffset(keyType), mapRowsNum) + // key
+          estimateLitAdditionSize(hasNullValue, hasOffset(valType), mapRowsNum) + // value
+          estimateLitAdditionSize(false, false, mapRowsNum) + // struct
+          estimateLitAdditionSize(false, true, mapRowsNum) // top list
+        case _ => 0L // primitive types have no nested additional buffers
+      }
+    }
+    // totalSize = litSize * numRows + top additional buffers size after expanding to a column.
+    litSize * numRows + estimateLitAdditionSize(litVal == null, hasOffset(litType), numRows)
+  }
+
+  private def estimateLitAdditionSize(hasNull: Boolean, hasOffset: Boolean, rows: Int): Long = {
+    // Additional buffers size; it is not nested for literals,
+    // so no need to do it recursively.
+ var totalSize = 0L + if (hasNull) { + totalSize += GpuBatchUtils.calculateValidityBufferSize(rows) + } + if (hasOffset) { + totalSize += GpuBatchUtils.calculateOffsetBufferSize(rows) + } + totalSize + } + + private def calcLitValueSize(lit: Any, litTp: DataType): Long = if (lit == null) { + if (GpuBatchUtils.isFixedWidth(litTp)) { + litTp.defaultSize + } else { + 0L + } + } else { + litTp match { + case StringType => lit.asInstanceOf[UTF8String].numBytes() + case BinaryType => lit.asInstanceOf[Array[Byte]].length + case ArrayType(elemType, _) => + lit.asInstanceOf[ArrayData].array.map(calcLitValueSize(_, elemType)).sum + case StructType(fields) => + val stData = lit.asInstanceOf[InternalRow] + fields.zipWithIndex.map { case (f, i) => + calcLitValueSize(stData.get(i, f.dataType), f.dataType) + }.sum + case MapType(keyType, valType, _) => + val mapData = lit.asInstanceOf[MapData] + val keyData = mapData.keyArray() + val valData = mapData.valueArray() + (0 until mapData.numElements()).map { i => + calcLitValueSize(keyData.get(i, keyType), keyType) + + calcLitValueSize(valData.get(i, valType), valType) + }.sum + case _ => litTp.defaultSize + } + } + + @scala.annotation.tailrec + def extractGpuLit(exp: Expression): Option[GpuLiteral] = exp match { + case gl: GpuLiteral => Some(gl) + case ga: GpuAlias => extractGpuLit(ga.child) + case _ => None + } } /** From 97ca823c9f9cf047dbb4b1aff32c129d4f398c81 Mon Sep 17 00:00:00 2001 From: Firestarman Date: Wed, 8 Jan 2025 08:39:32 +0000 Subject: [PATCH 2/5] Address the comments Signed-off-by: Firestarman --- .../spark/rapids/basicPhysicalOperators.scala | 170 +++++++++++++----- 1 file changed, 126 insertions(+), 44 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index 29e0a56e5dc..8b3465c2c52 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -314,7 +314,7 @@ object PreProjectSplitIterator { }.orElse { // A literal has an exact size that should be taken into account. 
           extractGpuLit(boundExprs.exprTiers.last(index)).map { gpuLit =>
-            calcMemorySizeForLiteral(gpuLit.value, gpuLit.dataType, numRows)
+            calcSizeForLiteral(gpuLit.value, gpuLit.dataType, numRows)
           }
         }.getOrElse {
           GpuBatchUtils.minGpuMemory(dataType, true, numRows)
@@ -323,48 +323,134 @@ object PreProjectSplitIterator {
     }.sum
   }
 
-  private def calcMemorySizeForLiteral(litVal: Any, litType: DataType, numRows: Int): Long = {
-    // One literal size = value size + additional buffers size (offset and/or validity)
-    val litSize = calcLitValueSize(litVal, litType) + {
-      val pickRowNum: Int => Int = rowNum => if (litVal == null) 0 else rowNum
-      litType match {
+  @scala.annotation.tailrec
+  def extractGpuLit(exp: Expression): Option[GpuLiteral] = exp match {
+    case gl: GpuLiteral => Some(gl)
+    case ga: GpuAlias => extractGpuLit(ga.child)
+    case _ => None
+  }
+
+  private def calcSizeForLiteral(litVal: Any, litType: DataType, numRows: Int): Long = {
+    // First calculate the meta buffers size
+    val metaSize = new LitMetaCollector(litVal, litType).collect.map { litMeta =>
+      val expandedRowsNum = litMeta.getRowsNum * numRows
+      var totalSize = 0L
+      if (litMeta.hasNull) {
+        totalSize += GpuBatchUtils.calculateValidityBufferSize(expandedRowsNum)
+      }
+      if (litMeta.hasOffset) {
+        totalSize += GpuBatchUtils.calculateOffsetBufferSize(expandedRowsNum)
+      }
+      totalSize
+    }.sum
+    // finalSize = oneLitValueSize * numRows + metadata size
+    calcLitValueSize(litVal, litType) * numRows + metaSize
+  }
+
+  /**
+   * Represent the metadata information of a literal or one of its children,
+   * which will be used to calculate the final metadata size after expanding
+   * this literal to a column.
+   */
+  private class LitMeta(val hasNull: Boolean, val hasOffset: Boolean) {
+    private var rowsNum: Int = 0
+    def incRowsNum(rows: Int = 1): Unit = rowsNum += rows
+    def getRowsNum: Int = rowsNum
+  }
+
+  /**
+   * Collect the metadata information of a literal. For a nested type literal,
+   * the result also includes its children.
+   */
+  private class LitMetaCollector(litValue: Any, litType: DataType) {
+    private var collected = false
+    private val metaInfos: ArrayBuffer[LitMeta] = ArrayBuffer.empty
+
+    def collect: Seq[LitMeta] = {
+      if (!collected) {
+        executeCollect(litValue, litType, litValue == null, 0)
+        collected = true
+      }
+      metaInfos.filter(_ != null).toSeq
+    }
+
+    /**
+     * Go through the literal and all its children to collect the meta information
+     * and save it to the cache; call "collect" to get the result.
+     * Each LitMeta indicates whether the literal or a child will have offset/validity
+     * buffers after being expanded to a column, along with the number of original rows.
+     * For nested types, it follows the type definition from
+     * https://github.com/rapidsai/cudf/blob/a0487be669326175982c8bfcdab4d61184c88e27/
+     * cpp/doxygen/developer_guide/DEVELOPER_GUIDE.md#list-columns
+     */
+    private def executeCollect(lit: Any, litTp: DataType, nullable: Boolean,
+        depth: Int): Unit = {
+      litTp match {
         case ArrayType(elemType, hasNullElem) =>
-          val numElems = pickRowNum(litVal.asInstanceOf[ArrayData].numElements())
-          // A GPU array literal requires only one column as the child
-          estimateLitAdditionSize(hasNullElem, hasOffset(elemType), numElems)
+          // It may be a middle element of a nested array, so use the nullable flag
+          // from the parent.
+          getOrInitAt(depth, new LitMeta(nullable, true)).incRowsNum()
+          // Go into the child
+          val arrayData = lit.asInstanceOf[ArrayData]
+          if (arrayData == null || arrayData.numElements() <= 0) {
+            // Null or an empty list, still need to check the child meta
+            executeCollect(null, elemType, hasNullElem, depth + 1)
+          } else { // a nonempty list, normal case
+            (0 until arrayData.numElements()).foreach( i =>
+              executeCollect(arrayData.get(i, elemType), elemType, hasNullElem, depth + 1)
+            )
+          }
         case StructType(fields) =>
-          val childrenNumRows = pickRowNum(1)
-          // A GPU struct literal requires "fields.size" columns as the children.
-          fields.map(f =>
-            estimateLitAdditionSize(f.nullable, hasOffset(f.dataType), childrenNumRows)
-          ).sum
+          if (nullable && depth == 0) {
+            // Add a meta for only the top nullable struct according to
+            // the struct construction in cudf, see
+            // https://github.com/rapidsai/cudf/blob/v24.12.00/cpp/src/column/
+            // column_factories.cu#L117, and
+            // https://github.com/rapidsai/cudf/blob/v24.12.00/java/src/main/
+            // native/src/ColumnViewJni.cpp#L2574
+            // and a struct doesn't have offsets.
+            getOrInitAt(depth, new LitMeta(nullable, false)).incRowsNum()
+          }
+          // Go into children
+          fields.zipWithIndex.foreach { case (f, i) =>
+            val fLit = if (lit == null) {
+              null
+            } else {
+              lit.asInstanceOf[InternalRow].get(i, f.dataType)
+            }
+            executeCollect(fLit, f.dataType, f.nullable, depth + 1 + i)
+          }
         case MapType(keyType, valType, hasNullValue) =>
-          val mapRowsNum = pickRowNum(litVal.asInstanceOf[MapData].numElements())
-          // A GPU map literal requires 4 columns as the children:
-          // the key and value columns, along with two wrapper columns as below
-          //   "list<struct<key, value>>"
-          estimateLitAdditionSize(false, hasOffset(keyType), mapRowsNum) + // key
-          estimateLitAdditionSize(hasNullValue, hasOffset(valType), mapRowsNum) + // value
-          estimateLitAdditionSize(false, false, mapRowsNum) + // struct
-          estimateLitAdditionSize(false, true, mapRowsNum) // top list
-        case _ => 0L // primitive types have no nested additional buffers
+          // A map is a list of structs in cudf, but the nested struct has no offset or
+          // validity, so only the top list needs a meta.
+          getOrInitAt(depth, new LitMeta(nullable, true)).incRowsNum()
+          val mapData = lit.asInstanceOf[MapData]
+          if (mapData == null || mapData.numElements() == 0) {
+            executeCollect(null, keyType, false, depth + 1)
+            executeCollect(null, valType, hasNullValue, depth + 2)
+          } else {
+            mapData.foreach(keyType, valType, { case (key, value) =>
+              executeCollect(key, keyType, false, depth + 1)
+              executeCollect(value, valType, hasNullValue, depth + 2)
+            })
+          }
+        case otherType => // primitive types
+          val hasOffset = hasOffset(otherType)
+          if (nullable || hasOffset) {
+            getOrInitAt(depth, new LitMeta(nullable, hasOffset)).incRowsNum()
+          }
       }
     }
-    // totalSize = litSize * numRows + top additional buffers size after expanding to a column.
-    litSize * numRows + estimateLitAdditionSize(litVal == null, hasOffset(litType), numRows)
-  }
 
-  private def estimateLitAdditionSize(hasNull: Boolean, hasOffset: Boolean, rows: Int): Long = {
-    // Additional buffers size; it is not nested for literals,
-    // so no need to do it recursively.
- var totalSize = 0L - if (hasNull) { - totalSize += GpuBatchUtils.calculateValidityBufferSize(rows) - } - if (hasOffset) { - totalSize += GpuBatchUtils.calculateOffsetBufferSize(rows) + private def getOrInitAt(pos: Int, initMeta: LitMeta): LitMeta = { + if (pos >= metaInfos.length) { + (metaInfos.length until pos).foreach { _ => + metaInfos.append(null) + } + metaInfos.append(initMeta) + } + metaInfos(pos) } - totalSize } private def calcLitValueSize(lit: Any, litTp: DataType): Long = if (lit == null) { @@ -378,7 +464,10 @@ object PreProjectSplitIterator { case StringType => lit.asInstanceOf[UTF8String].numBytes() case BinaryType => lit.asInstanceOf[Array[Byte]].length case ArrayType(elemType, _) => - lit.asInstanceOf[ArrayData].array.map(calcLitValueSize(_, elemType)).sum + val arrayData = lit.asInstanceOf[ArrayData] + (0 until arrayData.numElements()).map(idx => + calcLitValueSize(arrayData.get(idx, elemType), elemType) + ).sum case StructType(fields) => val stData = lit.asInstanceOf[InternalRow] fields.zipWithIndex.map { case (f, i) => @@ -395,13 +484,6 @@ object PreProjectSplitIterator { case _ => litTp.defaultSize } } - - @scala.annotation.tailrec - def extractGpuLit(exp: Expression): Option[GpuLiteral] = exp match { - case gl: GpuLiteral => Some(gl) - case ga: GpuAlias => extractGpuLit(ga.child) - case _ => None - } } /** From 0a23ffeb522dd8d2549396602710b46fde4a7987 Mon Sep 17 00:00:00 2001 From: Firestarman Date: Wed, 8 Jan 2025 16:47:58 +0800 Subject: [PATCH 3/5] Fix a build error Signed-off-by: Firestarman --- .../com/nvidia/spark/rapids/basicPhysicalOperators.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index 8b3465c2c52..b2f1765b454 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -435,9 +435,9 @@ object PreProjectSplitIterator { }) } case otherType => // primitive types - val hasOffset = hasOffset(otherType) - if (nullable || hasOffset) { - getOrInitAt(depth, new LitMeta(nullable, hasOffset)).incRowsNum() + val hasOffsetBuf = hasOffset(otherType) + if (nullable || hasOffsetBuf) { + getOrInitAt(depth, new LitMeta(nullable, hasOffsetBuf)).incRowsNum() } } } From 626cccc26446d33acbbaaaae20eca2bbe8836997 Mon Sep 17 00:00:00 2001 From: Firestarman Date: Wed, 8 Jan 2025 17:05:17 +0800 Subject: [PATCH 4/5] fix a potential NPE Signed-off-by: Firestarman --- .../nvidia/spark/rapids/basicPhysicalOperators.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index b2f1765b454..b6135a8ca79 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -448,8 +448,18 @@ object PreProjectSplitIterator { metaInfos.append(null) } metaInfos.append(initMeta) + initMeta + } else { + val meta = metaInfos(pos) + if (meta == null) { + metaInfos(pos) = initMeta + initMeta + } else { + meta + } + } - metaInfos(pos) + } } From af5b90e5923f381181c2d4753101e94810792f93 Mon Sep 17 00:00:00 2001 From: Firestarman Date: Thu, 9 Jan 2025 18:17:10 +0800 Subject: [PATCH 5/5] add tests 
Signed-off-by: Firestarman
---
 .../spark/rapids/basicPhysicalOperators.scala |  99 +++++++-------
 .../unit/LiteralSizeEstimationTest.scala      | 121 ++++++++++++++++++
 2 files changed, 164 insertions(+), 56 deletions(-)
 create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/unit/LiteralSizeEstimationTest.scala

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
index b6135a8ca79..9330924e293 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
@@ -330,7 +330,7 @@ object PreProjectSplitIterator {
     case _ => None
   }
 
-  private def calcSizeForLiteral(litVal: Any, litType: DataType, numRows: Int): Long = {
+  private[rapids] def calcSizeForLiteral(litVal: Any, litType: DataType, numRows: Int): Long = {
     // First calculate the meta buffers size
     val metaSize = new LitMetaCollector(litVal, litType).collect.map { litMeta =>
       val expandedRowsNum = litMeta.getRowsNum * numRows
@@ -356,6 +356,9 @@ object PreProjectSplitIterator {
     private var rowsNum: Int = 0
     def incRowsNum(rows: Int = 1): Unit = rowsNum += rows
     def getRowsNum: Int = rowsNum
+
+    override def toString: String =
+      s"LitMeta{rowsNum: $rowsNum, hasNull: $hasNull, hasOffset: $hasOffset}"
   }
 
   /**
@@ -392,32 +395,20 @@ object PreProjectSplitIterator {
         getOrInitAt(depth, new LitMeta(nullable, true)).incRowsNum()
         // Go into the child
         val arrayData = lit.asInstanceOf[ArrayData]
-        if (arrayData == null || arrayData.numElements() <= 0) {
-          // Null or an empty list, still need to check the child meta
-          executeCollect(null, elemType, hasNullElem, depth + 1)
-        } else { // a nonempty list, normal case
-          (0 until arrayData.numElements()).foreach( i =>
+        if (arrayData != null) { // Only need to go into the child when nonempty
+          (0 until arrayData.numElements()).foreach(i =>
            executeCollect(arrayData.get(i, elemType), elemType, hasNullElem, depth + 1)
           )
         }
       case StructType(fields) =>
-        if (nullable && depth == 0) {
-          // Add a meta for only the top nullable struct according to
-          // the struct construction in cudf, see
-          // https://github.com/rapidsai/cudf/blob/v24.12.00/cpp/src/column/
-          // column_factories.cu#L117, and
-          // https://github.com/rapidsai/cudf/blob/v24.12.00/java/src/main/
-          // native/src/ColumnViewJni.cpp#L2574
-          // and a struct doesn't have offsets.
+        if (nullable) {
+          // Add a meta for only a nullable struct, and a struct doesn't have offsets.
           getOrInitAt(depth, new LitMeta(nullable, false)).incRowsNum()
         }
-        // Go into children
+        // Always go into the children, which is different from the array case.
+        val stData = lit.asInstanceOf[InternalRow]
         fields.zipWithIndex.foreach { case (f, i) =>
-          val fLit = if (lit == null) {
-            null
-          } else {
-            lit.asInstanceOf[InternalRow].get(i, f.dataType)
-          }
+          val fLit = if (stData != null) stData.get(i, f.dataType) else null
           executeCollect(fLit, f.dataType, f.nullable, depth + 1 + i)
         }
       case MapType(keyType, valType, hasNullValue) =>
@@ -425,10 +416,7 @@ object PreProjectSplitIterator {
         // A map is a list of structs in cudf, but the nested struct has no offset or
         // validity, so only the top list needs a meta.
         getOrInitAt(depth, new LitMeta(nullable, true)).incRowsNum()
         val mapData = lit.asInstanceOf[MapData]
-        if (mapData == null || mapData.numElements() == 0) {
-          executeCollect(null, keyType, false, depth + 1)
-          executeCollect(null, valType, hasNullValue, depth + 2)
-        } else {
+        if (mapData != null) {
           mapData.foreach(keyType, valType, { case (key, value) =>
             executeCollect(key, keyType, false, depth + 1)
             executeCollect(value, valType, hasNullValue, depth + 2)
           })
         }
@@ -448,48 +436,47 @@ object PreProjectSplitIterator {
           metaInfos.append(null)
         }
         metaInfos.append(initMeta)
-        initMeta
-      } else {
-        val meta = metaInfos(pos)
-        if (meta == null) {
-          metaInfos(pos) = initMeta
-          initMeta
-        } else {
-          meta
-        }
-
+      } else if (metaInfos(pos) == null) {
+        metaInfos(pos) = initMeta
       }
-
+      metaInfos(pos)
     }
   }
 
-  private def calcLitValueSize(lit: Any, litTp: DataType): Long = if (lit == null) {
-    if (GpuBatchUtils.isFixedWidth(litTp)) {
-      litTp.defaultSize
-    } else {
-      0L
-    }
-  } else {
+  private def calcLitValueSize(lit: Any, litTp: DataType): Long = {
     litTp match {
-      case StringType => lit.asInstanceOf[UTF8String].numBytes()
-      case BinaryType => lit.asInstanceOf[Array[Byte]].length
+      case StringType =>
+        if (lit == null) 0L else lit.asInstanceOf[UTF8String].numBytes()
+      case BinaryType =>
+        if (lit == null) 0L else lit.asInstanceOf[Array[Byte]].length
       case ArrayType(elemType, _) =>
         val arrayData = lit.asInstanceOf[ArrayData]
-        (0 until arrayData.numElements()).map(idx =>
-          calcLitValueSize(arrayData.get(idx, elemType), elemType)
-        ).sum
+        if (arrayData == null) {
+          0L
+        } else {
+          (0 until arrayData.numElements()).map { idx =>
+            calcLitValueSize(arrayData.get(idx, elemType), elemType)
+          }.sum
+        }
+      case MapType(keyType, valType, _) =>
+        val mapData = lit.asInstanceOf[MapData]
+        if (mapData == null) {
+          0L
+        } else {
+          val keyData = mapData.keyArray()
+          val valData = mapData.valueArray()
+          (0 until mapData.numElements()).map { i =>
+            calcLitValueSize(keyData.get(i, keyType), keyType) +
+              calcLitValueSize(valData.get(i, valType), valType)
+          }.sum
+        }
       case StructType(fields) =>
+        // A special case: always go into the children even if lit is null,
+        // because fixed-width children always take some memory.
         val stData = lit.asInstanceOf[InternalRow]
         fields.zipWithIndex.map { case (f, i) =>
-          calcLitValueSize(stData.get(i, f.dataType), f.dataType)
-        }.sum
-      case MapType(keyType, valType, _) =>
-        val mapData = lit.asInstanceOf[MapData]
-        val keyData = mapData.keyArray()
-        val valData = mapData.valueArray()
-        (0 until mapData.numElements()).map { i =>
-          calcLitValueSize(keyData.get(i, keyType), keyType) +
-            calcLitValueSize(valData.get(i, valType), valType)
+          val fLit = if (stData == null) null else stData.get(i, f.dataType)
+          calcLitValueSize(fLit, f.dataType)
         }.sum
       case _ => litTp.defaultSize
     }
   }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/unit/LiteralSizeEstimationTest.scala b/tests/src/test/scala/com/nvidia/spark/rapids/unit/LiteralSizeEstimationTest.scala
new file mode 100644
index 00000000000..b50e23ae78e
--- /dev/null
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/unit/LiteralSizeEstimationTest.scala
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.unit + +import ai.rapids.cudf.ColumnVector +import com.nvidia.spark.rapids.{GpuScalar, GpuUnitTests, PreProjectSplitIterator} +import com.nvidia.spark.rapids.Arm.withResource + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** These tests only cover nested type literals for the PreProjectSplitIterator case */ +class LiteralSizeEstimationTest extends GpuUnitTests { + private val numRows = 1000 + + private def testLiteralSizeEstimate(lit: Any, litType: DataType): Unit = { + val col = withResource(GpuScalar.from(lit, litType))(ColumnVector.fromScalar(_, numRows)) + val actualSize = withResource(col)(_.getDeviceMemorySize) + val estimatedSize = PreProjectSplitIterator.calcSizeForLiteral(lit, litType, numRows) + assertResult(actualSize)(estimatedSize) + } + + test("estimate the array(int) literal size") { + val litType = ArrayType(IntegerType, true) + val lit = ArrayData.toArrayData(Array(null, 1, 2, null)) + testLiteralSizeEstimate(lit, litType) + } + + test("estimate the array(string) literal size") { + val litType = ArrayType(StringType, true) + val lit = ArrayData.toArrayData( + Array(null, UTF8String.fromString("s1"), UTF8String.fromString("s2"))) + testLiteralSizeEstimate(lit, litType) + } + + test("estimate the array(array(array(int))) literal size") { + val litType = ArrayType(ArrayType(ArrayType(IntegerType, true), true), true) + val nestedElem1 = ArrayData.toArrayData(Array(null, 1, 2, null)) + val nestedElem2 = ArrayData.toArrayData(Array(null)) + val nestedElem3 = ArrayData.toArrayData(Array()) + val elem1 = ArrayData.toArrayData(Array(nestedElem1, null)) + val elem2 = ArrayData.toArrayData(Array(nestedElem2, null, nestedElem3)) + val lit = ArrayData.toArrayData(Array(null, elem1, null, elem2, null)) + testLiteralSizeEstimate(lit, litType) + } + + test("estimate the array(array(array(string))) literal size") { + val litType = ArrayType(ArrayType(ArrayType(StringType, true), true), true) + val nestedElem1 = ArrayData.toArrayData( + Array(null, UTF8String.fromString("s1"), UTF8String.fromString("s2"))) + val nestedElem2 = ArrayData.toArrayData(Array(null)) + val nestedElem3 = ArrayData.toArrayData(Array()) + val elem1 = ArrayData.toArrayData(Array(nestedElem1, null)) + val elem2 = ArrayData.toArrayData(Array(nestedElem2, null, nestedElem3)) + val lit = ArrayData.toArrayData(Array(null, elem1, null, elem2, null)) + testLiteralSizeEstimate(lit, litType) + } + + test("estimate the struct(int, string) literal size") { + val litType = StructType(Seq( + StructField("int1", IntegerType), + StructField("string2", StringType) + )) + // null + testLiteralSizeEstimate(InternalRow(null, null), litType) + // normal case + testLiteralSizeEstimate(InternalRow(1, UTF8String.fromString("s1")), litType) + } + + test("estimate the struct(int, array(string)) literal size") { + val litType = StructType(Seq( + StructField("int1", IntegerType), + StructField("string2", 
ArrayType(StringType, true)) + )) + testLiteralSizeEstimate(InternalRow(null, null), litType) + val arrayLit = ArrayData.toArrayData( + Array(null, UTF8String.fromString("s1"), UTF8String.fromString("s2"))) + // normal case + testLiteralSizeEstimate(InternalRow(1, arrayLit), litType) + } + + test("estimate the list(struct(int, array(string))) literal size") { + val litType = ArrayType( + StructType(Seq( + StructField("int1", IntegerType), + StructField("string2", ArrayType(StringType, true)) + )), true) + val arrayLit = ArrayData.toArrayData( + Array(null, UTF8String.fromString("a1"), UTF8String.fromString("a2"))) + val elem1 = InternalRow(1, arrayLit) + val elem2 = InternalRow(null, null) + val lit = ArrayData.toArrayData(Array(null, elem1, elem2)) + testLiteralSizeEstimate(lit, litType) + } + + test("estimate the map(int, array(string)) literal size") { + val litType = MapType(IntegerType, ArrayType(StringType, true), true) + val arrayLit = ArrayData.toArrayData( + Array(null, UTF8String.fromString("s1"), UTF8String.fromString("s2"))) + val valueLit = ArrayData.toArrayData(Array(null, arrayLit)) + val keyLit = ArrayData.toArrayData(Array(1, 2)) + val lit = new ArrayBasedMapData(keyLit, valueLit) + testLiteralSizeEstimate(lit, litType) + } +}
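
To make the size model in these patches concrete: for an array(int) literal such as array(null, 1, 2, null) expanded to numRows rows, the estimate is the per-row value bytes plus the offset buffer of the top list column and the validity buffer of the nullable child column. The sketch below (not part of the patch) illustrates that arithmetic; validitySize and offsetSize are hypothetical stand-ins for GpuBatchUtils.calculateValidityBufferSize and GpuBatchUtils.calculateOffsetBufferSize, whose exact padding is not shown here, so treat the result as a lower bound on the real cudf allocation.

    import org.apache.spark.sql.catalyst.util.ArrayData

    object LiteralSizeModelSketch {
      // cudf keeps one validity bit per row and one INT32 offset per row plus a
      // terminating offset; real allocations are padded to larger boundaries.
      private def validitySize(rows: Long): Long = (rows + 7) / 8
      private def offsetSize(rows: Long): Long = (rows + 1) * 4L

      def main(args: Array[String]): Unit = {
        val numRows = 1000L
        val lit = ArrayData.toArrayData(Array(null, 1, 2, null))
        // Value bytes: every INT32 slot is materialized, null or not (4 bytes each).
        val valueBytes = lit.numElements() * 4L * numRows
        // Meta bytes: the top list column needs offsets sized by numRows, and the
        // child column needs validity sized by (elements per row) * numRows, matching
        // how LitMetaCollector scales each LitMeta's row count by numRows.
        val metaBytes = offsetSize(numRows) + validitySize(lit.numElements() * numRows)
        println(s"estimated device size for $numRows rows: ${valueBytes + metaBytes} bytes")
      }
    }

This mirrors what LiteralSizeEstimationTest verifies against ColumnVector.fromScalar: the value bytes scale linearly with the row count, while each level of nesting contributes at most one offset and one validity buffer per column.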