Skip to content

Commit

Permalink
[SPARK-49044] ValidateExternalType should return child in error
Browse files Browse the repository at this point in the history
  • Loading branch information
mrk-andreev committed Sep 1, 2024
1 parent fb8d01a commit f66a4ad
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 8 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -2806,6 +2806,12 @@
],
"sqlState" : "42602"
},
"INVALID_EXTERNAL_TYPE" : {
"message" : [
"<actualType> is not a valid external type for schema of <expectedType> at <childExpression>"
],
"sqlState" : "42K0N"
},
"INVALID_SET_SYNTAX" : {
"message" : [
"Expected format is 'SET', 'SET key', or 'SET key=value'. If you want to include special characters in key, or include semicolon in value, please use backquotes, e.g., SET `key`=`value`."
Expand Down
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-states.json
Original file line number Diff line number Diff line change
Expand Up @@ -4625,6 +4625,12 @@
"standard": "N",
"usedBy": ["Spark"]
},
"42K0N": {
"description": "Invalid external type.",
"origin": "Spark",
"standard": "N",
"usedBy": ["Spark"]
},
"42KD0": {
"description": "Ambiguous name reference.",
"origin": "Databricks",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2023,8 +2023,6 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD

override val dataType: DataType = externalDataType

private lazy val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"

private lazy val checkType: (Any) => Boolean = expected match {
case _: DecimalType =>
(value: Any) => {
Expand Down Expand Up @@ -2057,14 +2055,12 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
if (checkType(input)) {
input
} else {
throw new RuntimeException(s"${input.getClass.getName}$errMsg")
throw QueryExecutionErrors.invalidExternalTypeError(
input.getClass.getName, expected.simpleString, child)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the type doesn't match.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val input = child.genCode(ctx)
val obj = input.value
def genCheckTypes(classes: Seq[Class[_]]): String = {
Expand All @@ -2090,14 +2086,22 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
}

// Use unnamed references that don't create local fields here to reduce the number of fields,
// because these error-message components are used only when the type doesn't match.
val expectedTypeMsgField = ctx.addReferenceObj(
"expectedTypeMsgField", expected.simpleString)
val childExpressionMsgField = ctx.addReferenceObj(
"childExpressionMsgField", child)

val code = code"""
${input.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${input.isNull}) {
if ($typeCheck) {
${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj;
} else {
throw new RuntimeException($obj.getClass().getName() + $errMsgField);
throw QueryExecutionErrors.invalidExternalTypeError(
$obj.getClass().getName(), $expectedTypeMsgField, $childExpressionMsgField);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,18 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
)
}

def invalidExternalTypeError(
    actualType: String, expectedType: String, childExpression: Expression): SparkRuntimeException = {
  // Raised when a value supplied from the external (JVM) side does not match the
  // Catalyst schema it is being validated against. The offending child expression is
  // included so the error points at the exact field that failed validation.
  val params = Map(
    "actualType" -> actualType,
    "expectedType" -> expectedType,
    "childExpression" -> toSQLExpr(childExpression))
  new SparkRuntimeException(errorClass = "INVALID_EXTERNAL_TYPE", messageParameters = params)
}

def notOverrideExpectedMethodsError(
className: String, m1: String, m2: String): SparkRuntimeException = {
new SparkRuntimeException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input))))
}

checkExceptionInExpression[RuntimeException](
checkExceptionInExpression[SparkRuntimeException](
ValidateExternalType(
GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
DoubleType,
Expand All @@ -556,6 +556,48 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"java.lang.Integer is not a valid external type for schema of double")
}

test("SPARK-49044 ValidateExternalType should return child in error") {
  // Bound reference standing in for the external Row under validation.
  val rowRef = BoundReference(0, ObjectType(classOf[Row]), nullable = true)

  // Values that match their declared data type must pass validation untouched.
  val matchingPairs = Seq(
    (true, BooleanType),
    (2.toByte, ByteType),
    (5.toShort, ShortType),
    (23, IntegerType),
    (61L, LongType),
    (1.0f, FloatType),
    (10.0, DoubleType),
    ("abcd".getBytes, BinaryType),
    ("abcd", StringType),
    (BigDecimal.valueOf(10), DecimalType.IntDecimal),
    (IntervalUtils.stringToInterval(UTF8String.fromString("interval 3 day")),
      CalendarIntervalType),
    (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
    (Array(3, 2, 1), ArrayType(IntegerType))
  )
  matchingPairs.foreach { case (value, dataType) =>
    val encoder = RowEncoder.encoderForDataType(dataType, lenient = false)
    val expr = ValidateExternalType(
      GetExternalRowField(rowRef, index = 0, fieldName = "c0"),
      dataType,
      EncoderUtils.lenientExternalDataTypeFor(encoder))
    checkObjectExprEvaluation(expr, value, InternalRow.fromSeq(Seq(Row(value))))
  }

  // A mismatched value must raise INVALID_EXTERNAL_TYPE and name the child expression
  // that produced the bad value, so users can locate the offending field.
  checkErrorInExpression[SparkRuntimeException](
    expression = ValidateExternalType(
      GetExternalRowField(rowRef, index = 0, fieldName = "c0"),
      DoubleType,
      DoubleType),
    inputRow = InternalRow.fromSeq(Seq(Row(1))),
    errorClass = "INVALID_EXTERNAL_TYPE",
    parameters = Map[String, String](
      "actualType" -> "java.lang.Integer",
      "expectedType" -> "double",
      "childExpression" -> ("\"getexternalrowfield(input[0, org.apache.spark.sql.Row, true], " +
        "0, c0)\"")
    )
  )
}

private def javaMapSerializerFor(
keyClazz: Class[_],
valueClazz: Class[_])(inputObject: Expression): Expression = {
Expand Down

0 comments on commit f66a4ad

Please sign in to comment.