Skip to content

Commit dbbfd61

Browse files
committed
[SPARK-49044] ValidateExternalType should return child in error
1 parent fb8d01a commit dbbfd61

File tree

6 files changed

+118
-9
lines changed

6 files changed

+118
-9
lines changed

common/utils/src/main/resources/error/error-conditions.json

+6
Original file line numberDiff line numberDiff line change
@@ -2177,6 +2177,12 @@
21772177
],
21782178
"sqlState" : "42001"
21792179
},
2180+
"INVALID_EXTERNAL_TYPE" : {
2181+
"message" : [
2182+
"The external type <externalType> is not valid for the type <type> at the expression <expr>."
2183+
],
2184+
"sqlState" : "42K0N"
2185+
},
21802186
"INVALID_EXTRACT_BASE_FIELD_TYPE" : {
21812187
"message" : [
21822188
"Can't extract a value from <base>. Need a complex type [STRUCT, ARRAY, MAP] but got <other>."

common/utils/src/main/resources/error/error-states.json

+6
Original file line numberDiff line numberDiff line change
@@ -4625,6 +4625,12 @@
46254625
"standard": "N",
46264626
"usedBy": ["Spark"]
46274627
},
4628+
"42K0N": {
4629+
"description": "Invalid external type.",
4630+
"origin": "Spark",
4631+
"standard": "N",
4632+
"usedBy": ["Spark"]
4633+
},
46284634
"42KD0": {
46294635
"description": "Ambiguous name reference.",
46304636
"origin": "Databricks",

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

+11-7
Original file line numberDiff line numberDiff line change
@@ -2023,8 +2023,6 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
20232023

20242024
override val dataType: DataType = externalDataType
20252025

2026-
private lazy val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"
2027-
20282026
private lazy val checkType: (Any) => Boolean = expected match {
20292027
case _: DecimalType =>
20302028
(value: Any) => {
@@ -2057,14 +2055,12 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
20572055
if (checkType(input)) {
20582056
input
20592057
} else {
2060-
throw new RuntimeException(s"${input.getClass.getName}$errMsg")
2058+
throw QueryExecutionErrors.invalidExternalTypeError(
2059+
input.getClass.getName, expected, child)
20612060
}
20622061
}
20632062

20642063
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
2065-
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
2066-
// because errMsgField is used only when the type doesn't match.
2067-
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
20682064
val input = child.genCode(ctx)
20692065
val obj = input.value
20702066
def genCheckTypes(classes: Seq[Class[_]]): String = {
@@ -2090,14 +2086,22 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
20902086
s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
20912087
}
20922088

2089+
    // Use unnamed references that don't create local fields here to reduce the number of fields
2090+
    // because these references are used only when the type doesn't match.
2091+
val expectedTypeField = ctx.addReferenceObj(
2092+
"expectedTypeField", expected)
2093+
val childExpressionMsgField = ctx.addReferenceObj(
2094+
"childExpressionMsgField", child)
2095+
20932096
val code = code"""
20942097
${input.code}
20952098
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
20962099
if (!${input.isNull}) {
20972100
if ($typeCheck) {
20982101
${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj;
20992102
} else {
2100-
throw new RuntimeException($obj.getClass().getName() + $errMsgField);
2103+
throw QueryExecutionErrors.invalidExternalTypeError(
2104+
$obj.getClass().getName(), $expectedTypeField, $childExpressionMsgField);
21012105
}
21022106
}
21032107

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala

+14
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,20 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
478478
)
479479
}
480480

481+
/**
 * Builds the error raised when a value's external (JVM) type does not match the
 * Catalyst type required by the schema, e.g. a `java.lang.Integer` where `DOUBLE`
 * is expected.
 *
 * @param actualType      fully-qualified JVM class name of the offending value
 * @param expectedType    Catalyst data type required by the schema
 * @param childExpression the expression that produced the invalid value, included in
 *                        the message so the error points at the exact field (SPARK-49044)
 * @return a [[SparkRuntimeException]] with error class `INVALID_EXTERNAL_TYPE`
 */
def invalidExternalTypeError(
    actualType: String,
    expectedType: DataType,
    childExpression: Expression): SparkRuntimeException = {
  val params = Map(
    "externalType" -> actualType,
    "type" -> toSQLType(expectedType),
    "expr" -> toSQLExpr(childExpression))
  new SparkRuntimeException(
    errorClass = "INVALID_EXTERNAL_TYPE",
    messageParameters = params)
}
494+
481495
def notOverrideExpectedMethodsError(
482496
className: String, m1: String, m2: String): SparkRuntimeException = {
483497
new SparkRuntimeException(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

+44-2
Original file line numberDiff line numberDiff line change
@@ -547,13 +547,55 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
547547
checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input))))
548548
}
549549

550-
checkExceptionInExpression[RuntimeException](
550+
checkExceptionInExpression[SparkRuntimeException](
551551
ValidateExternalType(
552552
GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
553553
DoubleType,
554554
DoubleType),
555555
InternalRow.fromSeq(Seq(Row(1))),
556-
"java.lang.Integer is not a valid external type for schema of double")
556+
"The external type java.lang.Integer is not valid for the type \"DOUBLE\"")
557+
}
558+
559+
test("SPARK-49044 ValidateExternalType should return child in error") {
560+
val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
561+
Seq(
562+
(true, BooleanType),
563+
(2.toByte, ByteType),
564+
(5.toShort, ShortType),
565+
(23, IntegerType),
566+
(61L, LongType),
567+
(1.0f, FloatType),
568+
(10.0, DoubleType),
569+
("abcd".getBytes, BinaryType),
570+
("abcd", StringType),
571+
(BigDecimal.valueOf(10), DecimalType.IntDecimal),
572+
(IntervalUtils.stringToInterval(UTF8String.fromString("interval 3 day")),
573+
CalendarIntervalType),
574+
(java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
575+
(Array(3, 2, 1), ArrayType(IntegerType))
576+
).foreach { case (input, dt) =>
577+
val enc = RowEncoder.encoderForDataType(dt, lenient = false)
578+
val validateType = ValidateExternalType(
579+
GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
580+
dt,
581+
EncoderUtils.lenientExternalDataTypeFor(enc))
582+
checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input))))
583+
}
584+
585+
checkErrorInExpression[SparkRuntimeException](
586+
expression = ValidateExternalType(
587+
GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
588+
DoubleType,
589+
DoubleType),
590+
inputRow = InternalRow.fromSeq(Seq(Row(1))),
591+
errorClass = "INVALID_EXTERNAL_TYPE",
592+
parameters = Map[String, String](
593+
"externalType" -> "java.lang.Integer",
594+
"type" -> "\"DOUBLE\"",
595+
"expr" -> ("\"getexternalrowfield(input[0, org.apache.spark.sql.Row, true], " +
596+
"0, c0)\"")
597+
)
598+
)
557599
}
558600

559601
private def javaMapSerializerFor(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.catalyst.expressions

import scala.jdk.CollectionConverters._

import org.scalatest.matchers.should.Matchers

import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{StringType, StructType}

/**
 * End-to-end check that the INVALID_EXTERNAL_TYPE error raised while encoding a
 * DataFrame row names the child expression that produced the bad value (SPARK-49044).
 */
class ValidateExternalTypeSuite extends SparkFunSuite
  with SharedSparkSession with Matchers {

  test("SPARK-49044 ValidateExternalType should return child in error") {
    // A byte array supplied where the schema declares a STRING field.
    val rows: Seq[Row] = Seq(Row("".toCharArray.map(_.toByte)))
    val schema: StructType = new StructType()
      .add("f3", StringType)

    val exception = intercept[SparkRuntimeException] {
      sqlContext.createDataFrame(sparkContext.parallelize(rows), schema).show()
    }

    val cause = exception.getCause.asInstanceOf[SparkRuntimeException]
    assert(cause.getErrorClass == "INVALID_EXTERNAL_TYPE")

    // The "expr" parameter must point at the exact field being validated.
    val expected = Map(
      "externalType" -> "[B",
      "type" -> "\"STRING\"",
      "expr" -> ("\"getexternalrowfield(assertnotnull(" +
        "input[0, org.apache.spark.sql.Row, true]), 0, f3)\"")
    ).asJava
    cause.getMessageParameters shouldBe expected
  }
}

0 commit comments

Comments
 (0)