Skip to content

Commit 9a452f8

Browse files
committed
[SPARK-52945][SQL][TESTS] Split CastSuiteBase#checkInvalidCastFromNumericType into three methods and guarantee assertions are valid
### What changes were proposed in this pull request?

Due to the absence of `assert` statements, the `CastSuiteBase#checkInvalidCastFromNumericType` method previously performed no assertion checks: each `==` comparison was evaluated and its result discarded. Additionally, because the expected validation result for the target type varies significantly across `EvalMode` contexts, this PR refactors the method into three specialized methods to ensure robust assertion enforcement:

- `checkInvalidCastFromNumericTypeToDateType`
- `checkInvalidCastFromNumericTypeToTimestampNTZType`
- `checkInvalidCastFromNumericTypeToBinaryType`

### Why are the changes needed?

To address the missing assertion validation in `CastSuiteBase`.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- Pass GitHub Actions

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #51668 from LuciferYang/SPARK-52945.

Authored-by: yangjie01 <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
1 parent 2817654 commit 9a452f8

File tree

3 files changed

+77
-59
lines changed

3 files changed

+77
-59
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala

Lines changed: 38 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -545,61 +545,42 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
545545
checkCast("0", false)
546546
}
547547

548-
protected def checkInvalidCastFromNumericType(to: DataType): Unit = {
549-
cast(1.toByte, to).checkInputDataTypes() ==
550-
DataTypeMismatch(
551-
errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
552-
messageParameters = Map(
553-
"srcType" -> toSQLType(Literal(1.toByte).dataType),
554-
"targetType" -> toSQLType(to),
555-
"functionNames" -> "`DATE_FROM_UNIX_DATE`"
556-
)
557-
)
558-
cast(1.toShort, to).checkInputDataTypes() ==
559-
DataTypeMismatch(
560-
errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
561-
messageParameters = Map(
562-
"srcType" -> toSQLType(Literal(1.toShort).dataType),
563-
"targetType" -> toSQLType(to),
564-
"functionNames" -> "`DATE_FROM_UNIX_DATE`"
565-
)
566-
)
567-
cast(1, to).checkInputDataTypes() ==
568-
DataTypeMismatch(
569-
errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
570-
messageParameters = Map(
571-
"srcType" -> toSQLType(Literal(1).dataType),
572-
"targetType" -> toSQLType(to),
573-
"functionNames" -> "`DATE_FROM_UNIX_DATE`"
574-
)
575-
)
576-
cast(1L, to).checkInputDataTypes() ==
577-
DataTypeMismatch(
578-
errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
579-
messageParameters = Map(
580-
"srcType" -> toSQLType(Literal(1L).dataType),
581-
"targetType" -> toSQLType(to),
582-
"functionNames" -> "`DATE_FROM_UNIX_DATE`"
583-
)
584-
)
585-
cast(1.0.toFloat, to).checkInputDataTypes() ==
586-
DataTypeMismatch(
587-
errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
588-
messageParameters = Map(
589-
"srcType" -> toSQLType(Literal(1.0.toFloat).dataType),
590-
"targetType" -> toSQLType(to),
591-
"functionNames" -> "`DATE_FROM_UNIX_DATE`"
592-
)
593-
)
594-
cast(1.0, to).checkInputDataTypes() ==
595-
DataTypeMismatch(
596-
errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
597-
messageParameters = Map(
598-
"srcType" -> toSQLType(Literal(1.0).dataType),
599-
"targetType" -> toSQLType(to),
600-
"functionNames" -> "`DATE_FROM_UNIX_DATE`"
601-
)
602-
)
548+
protected def createCastMismatch(
549+
srcType: DataType,
550+
targetType: DataType,
551+
errorSubClass: String,
552+
extraParams: Map[String, String] = Map.empty): DataTypeMismatch = {
553+
val baseParams = Map(
554+
"srcType" -> toSQLType(srcType),
555+
"targetType" -> toSQLType(targetType)
556+
)
557+
DataTypeMismatch(errorSubClass, baseParams ++ extraParams)
558+
}
559+
560+
protected def checkInvalidCastFromNumericTypeToDateType(): Unit = {
561+
val errorSubClass = if (evalMode == EvalMode.LEGACY) {
562+
"CAST_WITHOUT_SUGGESTION"
563+
} else {
564+
"CAST_WITH_FUNC_SUGGESTION"
565+
}
566+
val funcParams = if (evalMode == EvalMode.LEGACY) {
567+
Map.empty[String, String]
568+
} else {
569+
Map("functionNames" -> "`DATE_FROM_UNIX_DATE`")
570+
}
571+
Seq(1.toByte, 1.toShort, 1, 1L, 1.0.toFloat, 1.0).foreach { testValue =>
572+
val expectedError =
573+
createCastMismatch(Literal(testValue).dataType, DateType, errorSubClass, funcParams)
574+
assert(cast(testValue, DateType).checkInputDataTypes() == expectedError)
575+
}
576+
}
577+
protected def checkInvalidCastFromNumericTypeToTimestampNTZType(): Unit = {
578+
// All numeric types: `CAST_WITHOUT_SUGGESTION`
579+
Seq(1.toByte, 1.toShort, 1, 1L, 1.0.toFloat, 1.0).foreach { testValue =>
580+
val expectedError =
581+
createCastMismatch(Literal(testValue).dataType, TimestampNTZType, "CAST_WITHOUT_SUGGESTION")
582+
assert(cast(testValue, TimestampNTZType).checkInputDataTypes() == expectedError)
583+
}
603584
}
604585

605586
test("SPARK-16729 type checking for casting to date type") {
@@ -614,7 +595,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
614595
)
615596
)
616597
)
617-
checkInvalidCastFromNumericType(DateType)
598+
checkInvalidCastFromNumericTypeToDateType()
618599
}
619600

620601
test("SPARK-20302 cast with same structure") {
@@ -998,7 +979,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
998979

999980
test("disallow type conversions between Numeric types and Timestamp without time zone type") {
1000981
import DataTypeTestUtils.numericTypes
1001-
checkInvalidCastFromNumericType(TimestampNTZType)
982+
checkInvalidCastFromNumericTypeToTimestampNTZType()
1002983
verifyCastFailure(
1003984
cast(Literal(0L), TimestampNTZType),
1004985
DataTypeMismatch(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
2929
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
3030
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
3131
import org.apache.spark.sql.errors.QueryErrorsBase
32+
import org.apache.spark.sql.internal.SQLConf
3233
import org.apache.spark.sql.types._
3334
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
3435

@@ -39,6 +40,33 @@ class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase {
3940

4041
override def evalMode: EvalMode.Value = EvalMode.ANSI
4142

43+
protected def checkInvalidCastFromNumericTypeToBinaryType(): Unit = {
44+
def checkNumericTypeCast(
45+
testValue: Any,
46+
srcType: DataType,
47+
to: DataType,
48+
expectedErrorClass: String,
49+
extraParams: Map[String, String] = Map.empty): Unit = {
50+
val expectedError = createCastMismatch(srcType, to, expectedErrorClass, extraParams)
51+
assert(cast(testValue, to).checkInputDataTypes() == expectedError)
52+
}
53+
54+
// Integer types: suggest config change
55+
val configParams = Map(
56+
"config" -> toSQLConf(SQLConf.ANSI_ENABLED.key),
57+
"configVal" -> toSQLValue("false", StringType)
58+
)
59+
checkNumericTypeCast(1.toByte, ByteType, BinaryType, "CAST_WITH_CONF_SUGGESTION", configParams)
60+
checkNumericTypeCast(
61+
1.toShort, ShortType, BinaryType, "CAST_WITH_CONF_SUGGESTION", configParams)
62+
checkNumericTypeCast(1, IntegerType, BinaryType, "CAST_WITH_CONF_SUGGESTION", configParams)
63+
checkNumericTypeCast(1L, LongType, BinaryType, "CAST_WITH_CONF_SUGGESTION", configParams)
64+
65+
// Floating types: no suggestion
66+
checkNumericTypeCast(1.0.toFloat, FloatType, BinaryType, "CAST_WITHOUT_SUGGESTION")
67+
checkNumericTypeCast(1.0, DoubleType, BinaryType, "CAST_WITHOUT_SUGGESTION")
68+
}
69+
4270
private def isTryCast = evalMode == EvalMode.TRY
4371

4472
private def testIntMaxAndMin(dt: DataType): Unit = {
@@ -142,7 +170,7 @@ class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase {
142170

143171
test("ANSI mode: disallow type conversions between Numeric types and Date type") {
144172
import DataTypeTestUtils.numericTypes
145-
checkInvalidCastFromNumericType(DateType)
173+
checkInvalidCastFromNumericTypeToDateType()
146174
verifyCastFailure(
147175
cast(Literal(0L), DateType),
148176
DataTypeMismatch(
@@ -168,7 +196,7 @@ class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase {
168196

169197
test("ANSI mode: disallow type conversions between Numeric types and Binary type") {
170198
import DataTypeTestUtils.numericTypes
171-
checkInvalidCastFromNumericType(BinaryType)
199+
checkInvalidCastFromNumericTypeToBinaryType()
172200
val binaryLiteral = Literal(new Array[Byte](1.toByte), BinaryType)
173201
numericTypes.foreach { numericType =>
174202
assert(cast(binaryLiteral, numericType).checkInputDataTypes() ==

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ class TryCastSuite extends CastWithAnsiOnSuite {
6161
checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
6262
}
6363

64+
override protected def checkInvalidCastFromNumericTypeToBinaryType(): Unit = {
65+
// All numeric types: `CAST_WITHOUT_SUGGESTION`
66+
Seq(1.toByte, 1.toShort, 1, 1L, 1.0.toFloat, 1.0).foreach { testValue =>
67+
val expectedError =
68+
createCastMismatch(Literal(testValue).dataType, BinaryType, "CAST_WITHOUT_SUGGESTION")
69+
assert(cast(testValue, BinaryType).checkInputDataTypes() == expectedError)
70+
}
71+
}
72+
6473
test("print string") {
6574
assert(cast(Literal("1"), IntegerType).toString == "try_cast(1 as int)")
6675
assert(cast(Literal("1"), IntegerType).sql == "TRY_CAST('1' AS INT)")

0 commit comments

Comments
 (0)