@@ -27,17 +27,11 @@ import org.apache.spark.SparkException.internalError
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.KnownNotContainsNull
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{
ARRAYS_ZIP,
CONCAT,
MAP_FROM_ENTRIES,
TreePattern
}
import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, MAP_FROM_ENTRIES, TreePattern}
import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalIntegralType}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -4211,15 +4205,17 @@ case class ArrayDistinct(child: Expression)
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val hs = new SQLOpenHashSet[Any]()
val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) =>
if (!hs.contains(value)) {
(value: Any) => {
val normalized = SQLOpenHashSet.normalizeZero(value)
if (!hs.contains(normalized)) {
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError(
prettyName, arrayBuffer.size)
}
arrayBuffer += value
hs.add(value)
},
arrayBuffer += normalized
hs.add(normalized)
}
},
(valueNaN: Any) => arrayBuffer += valueNaN)
val withNullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs,
(value: Any) => withNaNCheckFunc(value),
@@ -4231,6 +4227,8 @@ case class ArrayDistinct(child: Expression)
}
new GenericArrayData(arrayBuffer)
} else {
// Note: For complex types, deduplication correctly identifies -0.0 and 0.0 as equal
// via ordering.equiv. However, nested -0.0 values are not normalized to 0.0.
(data: ArrayData) => {
val array = data.toArray[AnyRef](elementType)
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef]
@@ -4276,6 +4274,7 @@ case class ArrayDistinct(child: Expression)
val hashSet = ctx.freshName("hashSet")
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
val normalizedValue = ctx.freshName("normalizedValue")

// Only need to track null element index when array's element is nullable.
val declareNullTrackVariables = if (resultArrayElementNullable) {
@@ -4286,14 +4285,17 @@ case class ArrayDistinct(child: Expression)
""
}

val normalizeCode = SQLOpenHashSet.normalizeZeroCode(elementType, value)

val body =
s"""
|if (!$hashSet.contains($hsValueCast$value)) {
|$jt $normalizedValue = $normalizeCode;
|if (!$hashSet.contains($hsValueCast$normalizedValue)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| break;
| }
| $hashSet.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
| $hashSet.add$hsPostFix($hsValueCast$normalizedValue);
| $builder.$$plus$$eq($normalizedValue);
|}
""".stripMargin

@@ -4385,15 +4387,17 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val hs = new SQLOpenHashSet[Any]()
val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) =>
if (!hs.contains(value)) {
(value: Any) => {
val normalized = SQLOpenHashSet.normalizeZero(value)
if (!hs.contains(normalized)) {
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError(
prettyName, arrayBuffer.size)
}
arrayBuffer += value
hs.add(value)
},
arrayBuffer += normalized
hs.add(normalized)
}
},
(valueNaN: Any) => arrayBuffer += valueNaN)
val withNullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs,
(value: Any) => withNaNCheckFunc(value),
@@ -4408,6 +4412,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
}
new GenericArrayData(arrayBuffer)
} else {
// Note: For complex types, deduplication correctly identifies -0.0 and 0.0 as equal
// via ordering.equiv. However, nested -0.0 values are not normalized to 0.0.
(array1, array2) =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
var alreadyIncludeNull = false
@@ -4468,15 +4474,19 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
val hashSet = ctx.freshName("hashSet")
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
val normalizedValue = ctx.freshName("normalizedValue")

val normalizeCode = SQLOpenHashSet.normalizeZeroCode(elementType, value)

val body =
s"""
|if (!$hashSet.contains($hsValueCast$value)) {
|$jt $normalizedValue = $normalizeCode;
|if (!$hashSet.contains($hsValueCast$normalizedValue)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| break;
| }
| $hashSet.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
| $hashSet.add$hsPostFix($hsValueCast$normalizedValue);
| $builder.$$plus$$eq($normalizedValue);
|}
""".stripMargin

@@ -4539,7 +4549,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
}

/**
* Returns an array of the elements in the intersect of x and y, without duplicates
* Returns an array of the elements in the intersection of x and y, without duplicates
*/
@ExpressionDescription(
usage = """
@@ -4571,18 +4581,20 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
val hsResult = new SQLOpenHashSet[Any]
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) => hs.add(value),
(value: Any) => hs.add(SQLOpenHashSet.normalizeZero(value)),
(valueNaN: Any) => {} )
val withArray2NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs,
(value: Any) => withArray2NaNCheckFunc(value),
() => {}
)
val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hsResult,
(value: Any) =>
if (hs.contains(value) && !hsResult.contains(value)) {
arrayBuffer += value
hsResult.add(value)
},
(value: Any) => {
val normalized = SQLOpenHashSet.normalizeZero(value)
if (hs.contains(normalized) && !hsResult.contains(normalized)) {
arrayBuffer += normalized
hsResult.add(normalized)
}
},
(valueNaN: Any) =>
if (hs.containsNaN()) {
arrayBuffer += valueNaN
@@ -4610,6 +4622,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
new GenericArrayData(Array.emptyObjectArray)
}
} else {
// Note: For complex types, deduplication correctly identifies -0.0 and 0.0 as equal
// via ordering.equiv. However, nested -0.0 values are not normalized to 0.0.
(array1, array2) =>
if (array1.numElements() != 0 && array2.numElements() != 0) {
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
@@ -4684,12 +4698,16 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
val hashSetResult = ctx.freshName("hashSetResult")
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
val normalizedValue = ctx.freshName("normalizedValue")

val normalizeCode = SQLOpenHashSet.normalizeZeroCode(elementType, value)

val withArray2NaNCheckCodeGenerator =
(array: String, index: String) =>
s"$jt $value = ${genGetValue(array, index)};" +
s"""$jt $value = ${genGetValue(array, index)};
|$jt $normalizedValue = $normalizeCode;""".stripMargin +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
s"$hashSet.add$hsPostFix($hsValueCast$value);",
s"$hashSet.add$hsPostFix($hsValueCast$normalizedValue);",
(valueNaN: String) => "")

val writeArray2ToHashSet = SQLOpenHashSet.withNullCheckCode(
@@ -4698,13 +4716,14 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina

val body =
s"""
|if ($hashSet.contains($hsValueCast$value) &&
| !$hashSetResult.contains($hsValueCast$value)) {
|$jt $normalizedValue = $normalizeCode;
|if ($hashSet.contains($hsValueCast$normalizedValue) &&
| !$hashSetResult.contains($hsValueCast$normalizedValue)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| break;
| }
| $hashSetResult.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
| $hashSetResult.add$hsPostFix($hsValueCast$normalizedValue);
| $builder.$$plus$$eq($normalizedValue);
|}
""".stripMargin

Expand Down Expand Up @@ -4771,7 +4790,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
}

/**
* Returns an array of the elements in the intersect of x and y, without duplicates
* Returns an array of the elements in x but not in y, without duplicates
*/
@ExpressionDescription(
usage = """
@@ -4801,18 +4820,20 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
val hs = new SQLOpenHashSet[Any]
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) => hs.add(value),
(value: Any) => hs.add(SQLOpenHashSet.normalizeZero(value)),
(valueNaN: Any) => {})
val withArray2NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs,
(value: Any) => withArray2NaNCheckFunc(value),
() => {}
)
val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) =>
if (!hs.contains(value)) {
arrayBuffer += value
hs.add(value)
},
(value: Any) => {
val normalized = SQLOpenHashSet.normalizeZero(value)
if (!hs.contains(normalized)) {
arrayBuffer += normalized
hs.add(normalized)
}
},
(valueNaN: Any) => arrayBuffer += valueNaN)
val withArray1NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs,
(value: Any) => withArray1NaNCheckFunc(value),
@@ -4830,6 +4851,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
}
new GenericArrayData(arrayBuffer)
} else {
// Note: For complex types, deduplication correctly identifies -0.0 and 0.0 as equal
// via ordering.equiv. However, nested -0.0 values are not normalized to 0.0.
(array1, array2) =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
var scannedNullElements = false
@@ -4900,12 +4923,16 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
val hashSet = ctx.freshName("hashSet")
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
val normalizedValue = ctx.freshName("normalizedValue")

val normalizeCode = SQLOpenHashSet.normalizeZeroCode(elementType, value)

val withArray2NaNCheckCodeGenerator =
(array: String, index: String) =>
s"$jt $value = ${genGetValue(array, i)};" +
s"""$jt $value = ${genGetValue(array, i)};
|$jt $normalizedValue = $normalizeCode;""".stripMargin +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
s"$hashSet.add$hsPostFix($hsValueCast$value);",
s"$hashSet.add$hsPostFix($hsValueCast$normalizedValue);",
(valueNaN: Any) => "")

val writeArray2ToHashSet = SQLOpenHashSet.withNullCheckCode(
@@ -4914,12 +4941,13 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL

val body =
s"""
|if (!$hashSet.contains($hsValueCast$value)) {
|$jt $normalizedValue = $normalizeCode;
|if (!$hashSet.contains($hsValueCast$normalizedValue)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| break;
| }
| $hashSet.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
| $hashSet.add$hsPostFix($hsValueCast$normalizedValue);
| $builder.$$plus$$eq($normalizedValue);
|}
""".stripMargin

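The root cause the changes above address: IEEE 754 equality treats -0.0 and 0.0 as equal, but their bit patterns differ, and bit-pattern-based hashing (java.lang.Double.hashCode, and the specialized OpenHashSet backing SQLOpenHashSet) puts them in different buckets. A minimal standalone sketch of that behavior, using only JDK classes (an illustration, not part of the patch):

// Shows why hash-based deduplication kept both zeros before this patch.
object NegativeZeroDemo extends App {
  println(-0.0d == 0.0d)                              // true: IEEE 754 equality
  println(java.lang.Double.doubleToRawLongBits(-0.0d) ==
    java.lang.Double.doubleToRawLongBits(0.0d))       // false: different bit patterns
  println(java.lang.Double.hashCode(-0.0d) ==
    java.lang.Double.hashCode(0.0d))                  // false: hash follows the bits

  // java.util.HashSet relies on Double.equals/hashCode, which compare bit
  // patterns, so both zeros survive "deduplication" -- the pre-patch symptom.
  val s = new java.util.HashSet[java.lang.Double]()
  s.add(-0.0d)
  s.add(0.0d)
  println(s.size)                                     // 2
}

Normalizing both zeros to +0.0 before any hash-set lookup or insertion, in both the interpreted and codegen paths, is exactly what the diff above does.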
@@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{DataType, DoubleType, FloatType}
import org.apache.spark.util.collection.OpenHashSet

// A wrap of OpenHashSet that can handle null, Double.NaN and Float.NaN w.r.t. the SQL semantic.
// A wrapper around OpenHashSet that can handle null, Double.NaN, Float.NaN and negative zero
// w.r.t. SQL semantics.
@Private
class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
initialCapacity: Int,
@@ -161,4 +162,28 @@ object SQLOpenHashSet {
""".stripMargin
}.getOrElse(handleNotNaN)
}

/**
* Normalizes negative zero to positive zero.
* IEEE 754 defines -0.0 == 0.0, but the two values have different binary
* representations and thus different hash codes, so hash-based collections used
* in SQL operations would otherwise treat them as distinct elements.
*/
def normalizeZero(value: Any): Any = value match {
case d: java.lang.Double if d == -0.0d => 0.0d
case f: java.lang.Float if f == -0.0f => 0.0f
case _ => value
}

/**
* Generates code to normalize negative zero to positive zero.
*/
def normalizeZeroCode(dataType: DataType, valueName: String): String = {
dataType match {
case DoubleType =>
s"($valueName == -0.0d ? 0.0d : $valueName)"
case FloatType =>
s"($valueName == -0.0f ? 0.0f : $valueName)"
case _ => valueName
}
}
}
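A quick sanity-check sketch of the two new helpers' contracts (hedged: the import path is an assumption, and the assertions simply restate what the code above does):

import org.apache.spark.sql.catalyst.util.SQLOpenHashSet   // package path assumed
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType}

// Interpreted path: -0.0 becomes +0.0 at the bit level; other values pass through.
val d = SQLOpenHashSet.normalizeZero(-0.0d).asInstanceOf[Double]
assert(java.lang.Double.doubleToRawLongBits(d) ==
  java.lang.Double.doubleToRawLongBits(0.0d))
assert(SQLOpenHashSet.normalizeZero(42) == 42)

// Codegen path: a Java ternary for the floating-point types, the identity otherwise.
assert(SQLOpenHashSet.normalizeZeroCode(DoubleType, "v") == "(v == -0.0d ? 0.0d : v)")
assert(SQLOpenHashSet.normalizeZeroCode(FloatType, "v") == "(v == -0.0f ? 0.0f : v)")
assert(SQLOpenHashSet.normalizeZeroCode(IntegerType, "v") == "v")

Note that the guard value == -0.0d also matches +0.0 under IEEE 754 equality, so positive zero maps to itself and both branches yield a canonical +0.0.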
@@ -3111,4 +3111,38 @@ class CollectionExpressionsSuite
a, Literal(5), Literal.create("q", StringType)), Seq("b", "a", "c", null, "q")
)
}

test("SPARK-54918: array operations should normalize -0.0 to 0.0") {
// Double
val ad = Literal.create(Seq(0.0d, -0.0d, 1.0d), ArrayType(DoubleType))
checkEvaluation(ArrayDistinct(ad), Seq(0.0d, 1.0d))

val au1 = Literal.create(Seq(0.0d), ArrayType(DoubleType))
val au2 = Literal.create(Seq(-0.0d), ArrayType(DoubleType))
checkEvaluation(ArrayUnion(au1, au2), Seq(0.0d))

val ai1 = Literal.create(Seq(0.0d, 1.0d), ArrayType(DoubleType))
val ai2 = Literal.create(Seq(-0.0d, 2.0d), ArrayType(DoubleType))
checkEvaluation(ArrayIntersect(ai1, ai2), Seq(0.0d))

val ae1 = Literal.create(Seq(0.0d, 1.0d), ArrayType(DoubleType))
val ae2 = Literal.create(Seq(-0.0d), ArrayType(DoubleType))
checkEvaluation(ArrayExcept(ae1, ae2), Seq(1.0d))

// Float
val adf = Literal.create(Seq(0.0f, -0.0f, 1.0f), ArrayType(FloatType))
checkEvaluation(ArrayDistinct(adf), Seq(0.0f, 1.0f))

val au1f = Literal.create(Seq(0.0f), ArrayType(FloatType))
val au2f = Literal.create(Seq(-0.0f), ArrayType(FloatType))
checkEvaluation(ArrayUnion(au1f, au2f), Seq(0.0f))

val ai1f = Literal.create(Seq(0.0f, 1.0f), ArrayType(FloatType))
val ai2f = Literal.create(Seq(-0.0f, 2.0f), ArrayType(FloatType))
checkEvaluation(ArrayIntersect(ai1f, ai2f), Seq(0.0f))

val ae1f = Literal.create(Seq(0.0f, 1.0f), ArrayType(FloatType))
val ae2f = Literal.create(Seq(-0.0f), ArrayType(FloatType))
checkEvaluation(ArrayExcept(ae1f, ae2f), Seq(1.0f))
}
}
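For reference, the end-to-end behavior these tests pin down, written as a spark-shell sketch (a hedged illustration of the expected post-patch results, not captured output; before the patch, the hash-based paths could return both zeros):

spark.sql("SELECT array_distinct(array(0.0D, -0.0D, 1.0D))").collect()               // expected row: [0.0, 1.0]
spark.sql("SELECT array_union(array(0.0D), array(-0.0D))").collect()                 // expected row: [0.0]
spark.sql("SELECT array_intersect(array(0.0D, 1.0D), array(-0.0D, 2.0D))").collect() // expected row: [0.0]
spark.sql("SELECT array_except(array(0.0D, 1.0D), array(-0.0D))").collect()          // expected row: [1.0]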