@@ -39,7 +39,7 @@ import org.apache.spark.metrics.source.CodegenMetrics
3939import org .apache .spark .sql .catalyst .InternalRow
4040import org .apache .spark .sql .catalyst .expressions ._
4141import org .apache .spark .sql .catalyst .expressions .codegen .Block ._
42- import org .apache .spark .sql .catalyst .util .{ArrayData , MapData }
42+ import org .apache .spark .sql .catalyst .util .{ArrayData , GenericArrayData , MapData }
4343import org .apache .spark .sql .internal .SQLConf
4444import org .apache .spark .sql .types ._
4545import org .apache .spark .unsafe .Platform
@@ -746,73 +746,6 @@ class CodegenContext {
746746 """ .stripMargin
747747 }
748748
/**
 * Generates Java code that allocates an [[UnsafeArrayData]] backed by a fresh byte array.
 *
 * The generated code computes the backing size with
 * `UnsafeArrayData.calculateSizeOfUnderlyingByteArray`, throws a `RuntimeException` when that
 * size exceeds `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`, and otherwise allocates the byte
 * array, writes the element count at `Platform.BYTE_ARRAY_OFFSET` and points the new
 * [[UnsafeArrayData]] at it.
 *
 * @param arrayName name of the array to create
 * @param numElements code representing the number of elements the array should contain
 * @param elementType data type of the elements in the array; its `defaultSize` is used as the
 *                    per-element size
 * @param additionalErrorMessage string to include in the error message; it is spliced into a
 *                               Java string literal as-is
 * @return the generated Java code
 */
def createUnsafeArray(
    arrayName: String,
    numElements: String,
    elementType: DataType,
    additionalErrorMessage: String): String = {
  // Fresh names keep the generated locals from clashing with other generated code.
  val arraySize = freshName("size")
  val arrayBytes = freshName("arrayBytes")

  s"""
     |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
     |  $numElements,
     |  ${elementType.defaultSize});
     |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
     |  throw new RuntimeException("Unsuccessful try create array with " + $arraySize +
     |    " bytes of data due to exceeding the limit " +
     |    "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." +
     |    "$additionalErrorMessage");
     |}
     |byte[] $arrayBytes = new byte[(int)$arraySize];
     |UnsafeArrayData $arrayName = new UnsafeArrayData();
     |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
     |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
   """.stripMargin
}
781-
/**
 * Generates Java code that allocates an [[UnsafeArrayData]] and fills it, or runs a fallback
 * when the backing byte array would exceed `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`.
 *
 * The array and its backing byte array are declared inside the `else` branch of the generated
 * code, so they are only visible to the code produced by `bodyCode`.
 *
 * @param arrayName a name of the array to create
 * @param numElements a piece of code representing the number of elements the array should contain
 * @param elementSize a size of an element in bytes
 * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]]; it receives
 *                 the name of the backing byte array variable
 * @param fallbackCode a piece of code executed when the array size limit is exceeded
 * @return the generated Java code
 */
def createUnsafeArrayWithFallback(
    arrayName: String,
    numElements: String,
    elementSize: Int,
    bodyCode: String => String,
    fallbackCode: String): String = {
  // Fresh names keep the generated locals from clashing with other generated code.
  val arraySize = freshName("size")
  val arrayBytes = freshName("arrayBytes")
  s"""
     |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
     |  $numElements,
     |  $elementSize);
     |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
     |  $fallbackCode
     |} else {
     |  final byte[] $arrayBytes = new byte[(int)$arraySize];
     |  UnsafeArrayData $arrayName = new UnsafeArrayData();
     |  Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
     |  $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
     |  ${bodyCode(arrayBytes)}
     |}
   """.stripMargin
}
815-
816749 /**
817750 * Generates code to do null safe execution, i.e. only execute the code when the input is not
818751 * null by adding null check if necessary.
@@ -1490,6 +1423,59 @@ object CodeGenerator extends Logging {
14901423 }
14911424 }
14921425
/**
 * Generates code creating a [[UnsafeArrayData]] or [[GenericArrayData]] based on
 * given parameters.
 *
 * The generated code delegates to `ArrayData.allocateArrayData`, passing the element size:
 * `elementType.defaultSize` when the element type is primitive, and `-1` otherwise.
 *
 * @param arrayName name of the array to create
 * @param elementType data type of the elements in source array
 * @param numElements code representing the number of elements the array should contain
 * @param additionalErrorMessage string to include in the error message; it is spliced into a
 *                               Java string literal as-is, so it must not contain characters
 *                               that would terminate the literal
 *
 * @return code representing the allocation of [[ArrayData]]
 */
def createArrayData(
    arrayName: String,
    elementType: DataType,
    numElements: String,
    additionalErrorMessage: String): String = {
  // -1 marks a non-primitive element type for ArrayData.allocateArrayData.
  val elementSize = if (CodeGenerator.isPrimitiveType(elementType)) {
    elementType.defaultSize
  } else {
    -1
  }
  s"""
     |ArrayData $arrayName = ArrayData.allocateArrayData(
     |  $elementSize, $numElements, "$additionalErrorMessage");
   """.stripMargin
}
1452+
/**
 * Generates assignment code for an [[ArrayData]].
 *
 * The element is read from `srcArray` at `srcArrayIndex` via `CodeGenerator.getValue` and
 * written to `dstArray` at `dstArrayIndex` via `CodeGenerator.setArrayElement`.
 *
 * @param dstArray name of the array to be assigned
 * @param elementType data type of the elements in destination and source arrays
 * @param srcArray name of the array to be read
 * @param dstArrayIndex an index variable to access each element of destination array
 * @param srcArrayIndex an index variable to access each element of source array
 * @param needNullCheck value which shows whether a nullcheck is required for the returning
 *                      assignment; when true, `srcArray.isNullAt(srcArrayIndex)` is passed
 *                      to the setter as the null condition
 *
 * @return code representing an assignment to each element of the [[ArrayData]], which requires
 *         a pair of destination and source loop index variables
 */
def createArrayAssignment(
    dstArray: String,
    elementType: DataType,
    srcArray: String,
    dstArrayIndex: String,
    srcArrayIndex: String,
    needNullCheck: Boolean): String = {
  CodeGenerator.setArrayElement(dstArray, elementType, dstArrayIndex,
    CodeGenerator.getValue(srcArray, elementType, srcArrayIndex),
    if (needNullCheck) Some(s"$srcArray.isNullAt($srcArrayIndex)") else None)
}
1478+
14931479 /**
14941480 * Returns the code to update a column in Row for a given DataType.
14951481 */
@@ -1558,6 +1544,34 @@ object CodeGenerator extends Logging {
15581544 }
15591545 }
15601546
/**
 * Generates code of setter for an [[ArrayData]].
 *
 * For primitive element types the type-specific `set<PrimitiveTypeName>` method is emitted;
 * for all other types the generic `update` method is emitted.
 *
 * @param array name of the array to write to
 * @param elementType data type of the array elements
 * @param i code representing the index of the element to set
 * @param value code representing the value to store
 * @param isNull optional code for a boolean null condition. It is only honoured for primitive
 *               element types, where it guards a call to `setNullAt`; for non-primitive
 *               element types it is ignored and `update` is emitted unconditionally.
 * @return the generated Java statement(s)
 */
def setArrayElement(
    array: String,
    elementType: DataType,
    i: String,
    value: String,
    isNull: Option[String] = None): String = {
  val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
  val setFunc = if (isPrimitiveType) {
    s"set${CodeGenerator.primitiveTypeName(elementType)}"
  } else {
    "update"
  }
  isNull match {
    case Some(nullCheck) if isPrimitiveType =>
      s"""
         |if ($nullCheck) {
         |  $array.setNullAt($i);
         |} else {
         |  $array.$setFunc($i, $value);
         |}
       """.stripMargin
    case _ =>
      s"$array.$setFunc($i, $value);"
  }
}
1574+
15611575 /**
15621576 * Returns the specialized code to set a given value in a column vector for a given `DataType`
15631577 * that could potentially be nullable.
0 commit comments