Skip to content

Commit

Permalink
[SPARK-36753][SQL] ArrayExcept handle duplicated Double.NaN and Float…
Browse files Browse the repository at this point in the history
….NaN

### What changes were proposed in this pull request?
For query
```
select array_except(array(cast('nan' as double), 1d), array(cast('nan' as double)))
```
This returns [NaN, 1d], but it should return [1d].
This issue is caused by the fact that `OpenHashSet` cannot handle `Double.NaN` and `Float.NaN` either.
This PR fixes the issue using the same approach as #33955.

### Why are the changes needed?
Fix bug

### Does this PR introduce _any_ user-facing change?
Yes. `ArrayExcept` now treats equal `NaN` values as duplicates, so `NaN` elements present in the second array are correctly excluded from the result.

### How was this patch tested?
Added UT

Closes #33994 from AngersZhuuuu/SPARK-36753.

Authored-by: Angerszhuuuu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit a7cbe69)
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
AngersZhuuuu authored and cloud-fan committed Sep 22, 2021
1 parent fc0b85f commit 2ff038a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
import org.apache.spark.util.collection.OpenHashSet

/**
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
Expand Down Expand Up @@ -4109,32 +4108,38 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
@transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
if (TypeUtils.typeWithProperEquals(elementType)) {
(array1, array2) =>
val hs = new OpenHashSet[Any]
var notFoundNullElement = true
val hs = new SQLOpenHashSet[Any]
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) => hs.add(value),
(valueNaN: Any) => {})
val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) =>
if (!hs.contains(value)) {
arrayBuffer += value
hs.add(value)
},
(valueNaN: Any) => arrayBuffer += valueNaN)
var i = 0
while (i < array2.numElements()) {
if (array2.isNullAt(i)) {
notFoundNullElement = false
hs.addNull()
} else {
val elem = array2.get(i, elementType)
hs.add(elem)
withArray2NaNCheckFunc(elem)
}
i += 1
}
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
i = 0
while (i < array1.numElements()) {
if (array1.isNullAt(i)) {
if (notFoundNullElement) {
if (!hs.containsNull()) {
arrayBuffer += null
notFoundNullElement = false
hs.addNull()
}
} else {
val elem = array1.get(i, elementType)
if (!hs.contains(elem)) {
arrayBuffer += elem
hs.add(elem)
}
withArray1NaNCheckFunc(elem)
}
i += 1
}
Expand Down Expand Up @@ -4203,10 +4208,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
val ptName = CodeGenerator.primitiveTypeName(jt)

nullSafeCodeGen(ctx, ev, (array1, array2) => {
val notFoundNullElement = ctx.freshName("notFoundNullElement")
val nullElementIndex = ctx.freshName("nullElementIndex")
val builder = ctx.freshName("builder")
val openHashSet = classOf[OpenHashSet[_]].getName
val openHashSet = classOf[SQLOpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
val hashSet = ctx.freshName("hashSet")
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
Expand All @@ -4217,7 +4221,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array2.isNullAt($i)) {
| $notFoundNullElement = false;
| $hashSet.addNull();
|} else {
| $body
|}
Expand All @@ -4235,18 +4239,18 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
}

val writeArray2ToHashSet = withArray2NullCheck(
s"""
|$jt $value = ${genGetValue(array2, i)};
|$hashSet.add$hsPostFix($hsValueCast$value);
""".stripMargin)
s"$jt $value = ${genGetValue(array2, i)};" +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
s"$hashSet.add$hsPostFix($hsValueCast$value);",
(valueNaN: Any) => ""))

def withArray1NullAssignment(body: String) =
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array1.isNullAt($i)) {
| if ($notFoundNullElement) {
| if (!$hashSet.containsNull()) {
| $hashSet.addNull();
| $nullElementIndex = $size;
| $notFoundNullElement = false;
| $size++;
| $builder.$$plus$$eq($nullValueHolder);
| }
Expand All @@ -4258,22 +4262,29 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
body
}

val processArray1 = withArray1NullAssignment(
val body =
s"""
|$jt $value = ${genGetValue(array1, i)};
|if (!$hashSet.contains($hsValueCast$value)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| break;
| }
| $hashSet.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
|}
""".stripMargin)
""".stripMargin

val processArray1 = withArray1NullAssignment(
s"$jt $value = ${genGetValue(array1, i)};" +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
(valueNaN: String) =>
s"""
|$size++;
|$builder.$$plus$$eq($valueNaN);
""".stripMargin))

// Only need to track null element index when array1's element is nullable.
val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|boolean $notFoundNullElement = true;
|int $nullElementIndex = -1;
""".stripMargin
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2310,6 +2310,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq(Float.NaN, null, 1f))
}

test("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan") {
checkEvaluation(ArrayExcept(
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))),
Seq(1d))
checkEvaluation(ArrayExcept(
Literal.create(Seq(null, Double.NaN, null, 1d), ArrayType(DoubleType)),
Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType))),
Seq(1d))
checkEvaluation(ArrayExcept(
Literal.apply(Array(Float.NaN, 1f)), Literal.apply(Array(Float.NaN))),
Seq(1f))
checkEvaluation(ArrayExcept(
Literal.create(Seq(null, Float.NaN, null, 1f), ArrayType(FloatType)),
Literal.create(Seq(Float.NaN, null), ArrayType(FloatType))),
Seq(1f))
}

test("SPARK-36754: ArrayIntersect should handle duplicated Double.NaN and Float.Nan") {
checkEvaluation(ArrayIntersect(
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN, 1d, 2d))),
Expand Down

0 comments on commit 2ff038a

Please sign in to comment.