diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6dddd9e6646c3..ac78b1f946d02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,17 +27,11 @@ import org.apache.spark.SparkException.internalError import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.KnownNotContainsNull import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} -import org.apache.spark.sql.catalyst.trees.TreePattern.{ - ARRAYS_ZIP, - CONCAT, - MAP_FROM_ENTRIES, - TreePattern -} +import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, MAP_FROM_ENTRIES, TreePattern} import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalIntegralType} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -4211,15 +4205,17 @@ case class ArrayDistinct(child: Expression) val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new SQLOpenHashSet[Any]() val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, - (value: Any) => - if (!hs.contains(value)) { + (value: Any) => { + val normalized = SQLOpenHashSet.normalizeZero(value) + if (!hs.contains(normalized)) { if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( prettyName, arrayBuffer.size) } - arrayBuffer += value - hs.add(value) - }, + arrayBuffer += normalized + hs.add(normalized) + } + }, (valueNaN: Any) => arrayBuffer += valueNaN) val withNullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, (value: Any) => withNaNCheckFunc(value), @@ -4231,6 +4227,8 @@ case class ArrayDistinct(child: Expression) } new GenericArrayData(arrayBuffer) } else { + // Note: For complex types, deduplication correctly identifies -0.0 and 0.0 as equal + // via ordering.equiv. However, nested -0.0 values are not normalized to 0.0. (data: ArrayData) => { val array = data.toArray[AnyRef](elementType) val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef] @@ -4276,6 +4274,7 @@ case class ArrayDistinct(child: Expression) val hashSet = ctx.freshName("hashSet") val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + val normalizedValue = ctx.freshName("normalizedValue") // Only need to track null element index when array's element is nullable. 
val declareNullTrackVariables = if (resultArrayElementNullable) { @@ -4286,14 +4285,17 @@ case class ArrayDistinct(child: Expression) "" } + val normalizeCode = SQLOpenHashSet.normalizeZeroCode(elementType, value) + val body = s""" - |if (!$hashSet.contains($hsValueCast$value)) { + |$jt $normalizedValue = $normalizeCode; + |if (!$hashSet.contains($hsValueCast$normalizedValue)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; | } - | $hashSet.add$hsPostFix($hsValueCast$value); - | $builder.$$plus$$eq($value); + | $hashSet.add$hsPostFix($hsValueCast$normalizedValue); + | $builder.$$plus$$eq($normalizedValue); |} """.stripMargin @@ -4385,15 +4387,17 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new SQLOpenHashSet[Any]() val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, - (value: Any) => - if (!hs.contains(value)) { + (value: Any) => { + val normalized = SQLOpenHashSet.normalizeZero(value) + if (!hs.contains(normalized)) { if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( prettyName, arrayBuffer.size) } - arrayBuffer += value - hs.add(value) - }, + arrayBuffer += normalized + hs.add(normalized) + } + }, (valueNaN: Any) => arrayBuffer += valueNaN) val withNullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, (value: Any) => withNaNCheckFunc(value), @@ -4408,6 +4412,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi } new GenericArrayData(arrayBuffer) } else { + // Note: For complex types, deduplication correctly identifies -0.0 and 0.0 as equal + // via ordering.equiv. However, nested -0.0 values are not normalized to 0.0. 
(array1, array2) => val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] var alreadyIncludeNull = false @@ -4468,15 +4474,19 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi val hashSet = ctx.freshName("hashSet") val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + val normalizedValue = ctx.freshName("normalizedValue") + + val normalizeCode = SQLOpenHashSet.normalizeZeroCode(elementType, value) val body = s""" - |if (!$hashSet.contains($hsValueCast$value)) { + |$jt $normalizedValue = $normalizeCode; + |if (!$hashSet.contains($hsValueCast$normalizedValue)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; | } - | $hashSet.add$hsPostFix($hsValueCast$value); - | $builder.$$plus$$eq($value); + | $hashSet.add$hsPostFix($hsValueCast$normalizedValue); + | $builder.$$plus$$eq($normalizedValue); |} """.stripMargin @@ -4539,7 +4549,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi } /** - * Returns an array of the elements in the intersect of x and y, without duplicates + * Returns an array of the elements in the intersection of x and y, without duplicates */ @ExpressionDescription( usage = """ @@ -4571,18 +4581,20 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina val hsResult = new SQLOpenHashSet[Any] val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, - (value: Any) => hs.add(value), + (value: Any) => hs.add(SQLOpenHashSet.normalizeZero(value)), (valueNaN: Any) => {} ) val withArray2NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, (value: Any) => withArray2NaNCheckFunc(value), () => {} ) val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hsResult, - (value: Any) => - if (hs.contains(value) && !hsResult.contains(value)) { - arrayBuffer += value - hsResult.add(value) - }, + (value: Any) => { + val normalized = SQLOpenHashSet.normalizeZero(value) + if (hs.contains(normalized) && !hsResult.contains(normalized)) { + arrayBuffer += normalized + hsResult.add(normalized) + } + }, (valueNaN: Any) => if (hs.containsNaN()) { arrayBuffer += valueNaN @@ -4610,6 +4622,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina new GenericArrayData(Array.emptyObjectArray) } } else { + // Note: For complex types, deduplication correctly identifies -0.0 and 0.0 as equal + // via ordering.equiv. However, nested -0.0 values are not normalized to 0.0. 
(array1, array2) => if (array1.numElements() != 0 && array2.numElements() != 0) { val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] @@ -4684,12 +4698,16 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina val hashSetResult = ctx.freshName("hashSetResult") val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + val normalizedValue = ctx.freshName("normalizedValue") + + val normalizeCode = SQLOpenHashSet.normalizeZeroCode(elementType, value) val withArray2NaNCheckCodeGenerator = (array: String, index: String) => - s"$jt $value = ${genGetValue(array, index)};" + + s"""$jt $value = ${genGetValue(array, index)}; + |$jt $normalizedValue = $normalizeCode;""".stripMargin + SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, - s"$hashSet.add$hsPostFix($hsValueCast$value);", + s"$hashSet.add$hsPostFix($hsValueCast$normalizedValue);", (valueNaN: String) => "") val writeArray2ToHashSet = SQLOpenHashSet.withNullCheckCode( @@ -4698,13 +4716,14 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina val body = s""" - |if ($hashSet.contains($hsValueCast$value) && - | !$hashSetResult.contains($hsValueCast$value)) { + |$jt $normalizedValue = $normalizeCode; + |if ($hashSet.contains($hsValueCast$normalizedValue) && + | !$hashSetResult.contains($hsValueCast$normalizedValue)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; | } - | $hashSetResult.add$hsPostFix($hsValueCast$value); - | $builder.$$plus$$eq($value); + | $hashSetResult.add$hsPostFix($hsValueCast$normalizedValue); + | $builder.$$plus$$eq($normalizedValue); |} """.stripMargin @@ -4771,7 +4790,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina } /** - * Returns an array of the elements in the intersect of x and y, without duplicates + * Returns an array of the elements in x but not in y, without duplicates */ @ExpressionDescription( usage = """ @@ -4801,18 +4820,20 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL val hs = new SQLOpenHashSet[Any] val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, - (value: Any) => hs.add(value), + (value: Any) => hs.add(SQLOpenHashSet.normalizeZero(value)), (valueNaN: Any) => {}) val withArray2NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, (value: Any) => withArray2NaNCheckFunc(value), () => {} ) val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, - (value: Any) => - if (!hs.contains(value)) { - arrayBuffer += value - hs.add(value) - }, + (value: Any) => { + val normalized = SQLOpenHashSet.normalizeZero(value) + if (!hs.contains(normalized)) { + arrayBuffer += normalized + hs.add(normalized) + } + }, (valueNaN: Any) => arrayBuffer += valueNaN) val withArray1NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, (value: Any) => withArray1NaNCheckFunc(value), @@ -4830,6 +4851,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL } new GenericArrayData(arrayBuffer) } else { + // Note: For complex types, deduplication correctly identifies -0.0 and 0.0 as equal + // via ordering.equiv. However, nested -0.0 values are not normalized to 0.0. 
      (array1, array2) =>
        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
        var scannedNullElements = false
@@ -4900,12 +4923,16 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
     val hashSet = ctx.freshName("hashSet")
     val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
     val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
+    val normalizedValue = ctx.freshName("normalizedValue")
+
+    val normalizeCode = SQLOpenHashSet.normalizeZeroCode(elementType, value)

     val withArray2NaNCheckCodeGenerator =
       (array: String, index: String) =>
-        s"$jt $value = ${genGetValue(array, i)};" +
+        s"""$jt $value = ${genGetValue(array, i)};
+           |$jt $normalizedValue = $normalizeCode;""".stripMargin +
           SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
-            s"$hashSet.add$hsPostFix($hsValueCast$value);",
+            s"$hashSet.add$hsPostFix($hsValueCast$normalizedValue);",
             (valueNaN: Any) => "")

     val writeArray2ToHashSet = SQLOpenHashSet.withNullCheckCode(
@@ -4914,12 +4941,13 @@
     val body =
       s"""
-         |if (!$hashSet.contains($hsValueCast$value)) {
+         |$jt $normalizedValue = $normalizeCode;
+         |if (!$hashSet.contains($hsValueCast$normalizedValue)) {
          |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
          |    break;
          |  }
-         |  $hashSet.add$hsPostFix($hsValueCast$value);
-         |  $builder.$$plus$$eq($value);
+         |  $hashSet.add$hsPostFix($hsValueCast$normalizedValue);
+         |  $builder.$$plus$$eq($normalizedValue);
          |}
       """.stripMargin

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
index 10023c2ecca5d..0925f46ddffa3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types.{DataType, DoubleType, FloatType}
 import org.apache.spark.util.collection.OpenHashSet

-// A wrap of OpenHashSet that can handle null, Double.NaN and Float.NaN w.r.t. the SQL semantic.
+// A wrapper around OpenHashSet that can handle null, Double.NaN, Float.NaN and negative zero
+// w.r.t. the SQL semantic.
 @Private
 class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
     initialCapacity: Int,
@@ -161,4 +162,28 @@ object SQLOpenHashSet {
       """.stripMargin
     }.getOrElse(handleNotNaN)
   }
+
+  /**
+   * Normalizes negative zero to positive zero.
+   * IEEE 754 defines -0.0 == 0.0, but the two values have different binary representations and
+   * thus different hash codes, which breaks hash-based collections used in SQL operations.
+   */
+  def normalizeZero(value: Any): Any = value match {
+    case d: java.lang.Double if d == -0.0d => 0.0d
+    case f: java.lang.Float if f == -0.0f => 0.0f
+    case _ => value
+  }
+
+  /**
+   * Generates code to normalize negative zero to positive zero.
+   */
+  def normalizeZeroCode(dataType: DataType, valueName: String): String = {
+    dataType match {
+      case DoubleType =>
+        s"($valueName == -0.0d ? 0.0d : $valueName)"
+      case FloatType =>
+        s"($valueName == -0.0f ? 0.0f : $valueName)"
+      case _ => valueName
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 1907ec7c23aa6..9f4229b743368 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -3111,4 +3111,38 @@ class CollectionExpressionsSuite
       a, Literal(5), Literal.create("q", StringType)),
       Seq("b", "a", "c", null, "q")
     )
   }
+
+  test("SPARK-54918: array operations should normalize -0.0 to 0.0") {
+    // Double
+    val ad = Literal.create(Seq(0.0d, -0.0d, 1.0d), ArrayType(DoubleType))
+    checkEvaluation(ArrayDistinct(ad), Seq(0.0d, 1.0d))
+
+    val au1 = Literal.create(Seq(0.0d), ArrayType(DoubleType))
+    val au2 = Literal.create(Seq(-0.0d), ArrayType(DoubleType))
+    checkEvaluation(ArrayUnion(au1, au2), Seq(0.0d))
+
+    val ai1 = Literal.create(Seq(0.0d, 1.0d), ArrayType(DoubleType))
+    val ai2 = Literal.create(Seq(-0.0d, 2.0d), ArrayType(DoubleType))
+    checkEvaluation(ArrayIntersect(ai1, ai2), Seq(0.0d))
+
+    val ae1 = Literal.create(Seq(0.0d, 1.0d), ArrayType(DoubleType))
+    val ae2 = Literal.create(Seq(-0.0d), ArrayType(DoubleType))
+    checkEvaluation(ArrayExcept(ae1, ae2), Seq(1.0d))
+
+    // Float
+    val adf = Literal.create(Seq(0.0f, -0.0f, 1.0f), ArrayType(FloatType))
+    checkEvaluation(ArrayDistinct(adf), Seq(0.0f, 1.0f))
+
+    val au1f = Literal.create(Seq(0.0f), ArrayType(FloatType))
+    val au2f = Literal.create(Seq(-0.0f), ArrayType(FloatType))
+    checkEvaluation(ArrayUnion(au1f, au2f), Seq(0.0f))
+
+    val ai1f = Literal.create(Seq(0.0f, 1.0f), ArrayType(FloatType))
+    val ai2f = Literal.create(Seq(-0.0f, 2.0f), ArrayType(FloatType))
+    checkEvaluation(ArrayIntersect(ai1f, ai2f), Seq(0.0f))
+
+    val ae1f = Literal.create(Seq(0.0f, 1.0f), ArrayType(FloatType))
+    val ae2f = Literal.create(Seq(-0.0f), ArrayType(FloatType))
+    checkEvaluation(ArrayExcept(ae1f, ae2f), Seq(1.0f))
+  }
 }
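
Reviewer note, not part of the patch: the minimal, standalone sketch below shows why the normalization is needed. The object name `NegativeZeroDemo` is hypothetical, and `java.util.HashSet` is used as an illustrative stand-in for `OpenHashSet` (which, to my understanding, hashes doubles via their raw bits and so runs into the same mismatch): the two zeros compare equal as primitives, but their bit patterns differ, so bit-based hashing and boxed equality treat them as distinct elements.

```scala
// Standalone sketch; normalizeZero mirrors the helper added to SQLOpenHashSet.
object NegativeZeroDemo {
  def normalizeZero(value: Any): Any = value match {
    case d: java.lang.Double if d == -0.0d => 0.0d
    case f: java.lang.Float if f == -0.0f => 0.0f
    case _ => value
  }

  def main(args: Array[String]): Unit = {
    // IEEE 754: the primitive comparison treats the two zeros as equal...
    println(-0.0d == 0.0d)  // true

    // ...but their bit patterns differ, so bit-based hashing and the boxed
    // java.lang.Double.equals treat them as different values.
    println(java.lang.Double.doubleToLongBits(-0.0d) ==
      java.lang.Double.doubleToLongBits(0.0d))  // false

    // Hence a hash set keyed on boxed equality keeps both zeros, which is
    // exactly what array_distinct/array_union/... must avoid under SQL semantics.
    val raw = new java.util.HashSet[Any]()
    raw.add(0.0d)
    raw.add(-0.0d)
    println(raw.size)  // 2

    // Normalizing before insertion collapses them into a single element.
    val normalized = new java.util.HashSet[Any]()
    normalized.add(normalizeZero(0.0d))
    normalized.add(normalizeZero(-0.0d))
    println(normalized.size)  // 1
  }
}
```

One subtlety: the guard `d == -0.0d` in `normalizeZero` also matches positive zero, because Scala compares boxed numerics by value. That is harmless, since both zeros map to the same `0.0d`; only the sign bit of negative zero is effectively dropped.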