[SPARK-52235][SQL] Add implicit cast to DefaultValue V2 Expressions passed to DSV2 #50959

Closed · wants to merge 5 commits
Changes from 3 commits
@@ -78,6 +78,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
UnpivotCoercion ::
WidenSetOperationTypes ::
ProcedureArgumentCoercion ::
DefaultValueExpressionCoercion ::
new AnsiCombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
@@ -50,6 +50,7 @@ object TypeCoercion extends TypeCoercionBase {
UnpivotCoercion ::
WidenSetOperationTypes ::
ProcedureArgumentCoercion ::
DefaultValueExpressionCoercion ::
new CombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
@@ -32,16 +32,22 @@ import org.apache.spark.sql.catalyst.expressions.{
WindowSpecDefinition
}
import org.apache.spark.sql.catalyst.plans.logical.{
AddColumns,
AlterColumns,
Call,
CreateTable,
Except,
Intersect,
LogicalPlan,
Project,
ReplaceTable,
Union,
Unpivot
}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.types.DataType

@@ -81,6 +87,71 @@ abstract class TypeCoercionBase extends TypeCoercionHelper {
}
}

/**
* A type coercion rule that implicitly casts default value expression in DDL statements
* to expected types.
*/
object DefaultValueExpressionCoercion extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case createTable @ CreateTable(_, cols, _, _, _) if createTable.resolved &&
cols.exists(_.defaultValue.isDefined) =>
val newCols = cols.map { c =>
c.copy(defaultValue = c.defaultValue.map(d =>
d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
d.child,
c.dataType,
"CREATE TABLE",
c.name,
d.originalSQL,
castWiderOnlyLiterals = false))))
}
createTable.copy(columns = newCols)

case replaceTable @ ReplaceTable(_, cols, _, _, _) if replaceTable.resolved &&
cols.exists(_.defaultValue.isDefined) =>
val newCols = cols.map { c =>
c.copy(defaultValue = c.defaultValue.map(d =>
d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
d.child,
c.dataType,
"REPLACE TABLE",
c.name,
d.originalSQL,
castWiderOnlyLiterals = false))))
}
replaceTable.copy(columns = newCols)

case addColumns @ AddColumns(_, cols) if addColumns.resolved &&
cols.exists(_.default.isDefined) =>
val newCols = cols.map { c =>
c.copy(default = c.default.map(d =>
d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
d.child,
c.dataType,
"ALTER TABLE ADD COLUMNS",
c.colName,
d.originalSQL,
castWiderOnlyLiterals = false))))
}
addColumns.copy(columnsToAdd = newCols)

case alterColumns @ AlterColumns(_, specs) if alterColumns.resolved &&
specs.exists(_.newDefaultExpression.isDefined) =>
val newSpecs = specs.map { c =>
val dataType = c.column.asInstanceOf[ResolvedFieldName].field.dataType
c.copy(newDefaultExpression = c.newDefaultExpression.map(d =>
d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
d.child,
dataType,
"ALTER TABLE ALTER COLUMN",
c.column.name.quoted,
d.originalSQL,
castWiderOnlyLiterals = false))))
}
alterColumns.copy(specs = newSpecs)
}
}
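
For illustration, here is a minimal sketch of DDL this rule affects (the catalog, namespace, table, and column names are illustrative, not taken from this PR; it assumes a session with a V2 catalog registered as "testcat"). With the rule in place, the analyzed default value child carries an implicit cast to the column type, so the V2 expression later handed to the connector is already typed correctly:

// Illustrative only. The DOUBLE column's defaults are written as integer expressions;
// after DefaultValueExpressionCoercion the default children become Cast(1, DoubleType)
// and Cast((1 + 1), DoubleType) rather than the raw integer-typed expressions.
spark.sql("CREATE TABLE testcat.ns.tbl (col1 DOUBLE DEFAULT 1)")
spark.sql("ALTER TABLE testcat.ns.tbl ALTER COLUMN col1 SET DEFAULT (1 + 1)")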

/**
* Widens the data types of the [[Unpivot]] values.
*/
@@ -432,20 +432,20 @@ object ResolveDefaultColumns extends QueryErrorsBase
targetType: DataType,
colName: String): Option[Expression] = {
expr match {
case l: Literal if !Seq(targetType, l.dataType).exists(_ match {
case e if e.foldable && !Seq(targetType, e.dataType).exists(_ match {
case _: BooleanType | _: ArrayType | _: StructType | _: MapType => true
case _ => false
}) =>
val casted = Cast(l, targetType, Some(conf.sessionLocalTimeZone), evalMode = EvalMode.TRY)
val casted = Cast(e, targetType, Some(conf.sessionLocalTimeZone), evalMode = EvalMode.TRY)
try {
Option(casted.eval(EmptyRow)).map(Literal(_, targetType))
} catch {
case e @ ( _: SparkThrowable | _: RuntimeException) =>
logWarning(log"Failed to cast default value '${MDC(COLUMN_DEFAULT_VALUE, l)}' " +
case ex @ ( _: SparkThrowable | _: RuntimeException) =>
[Review comment — cloud-fan (Contributor), May 29, 2025]: to be safe, can we match all exceptions here? If anything gets wrong we just fail and say the type is incompatible.
[Author reply (Contributor)]: makes sense, thanks
logWarning(log"Failed to cast default value '${MDC(COLUMN_DEFAULT_VALUE, e)}' " +
log"for column ${MDC(COLUMN_NAME, colName)} " +
log"from ${MDC(COLUMN_DATA_TYPE_SOURCE, l.dataType)} " +
log"from ${MDC(COLUMN_DATA_TYPE_SOURCE, e.dataType)} " +
log"to ${MDC(COLUMN_DATA_TYPE_TARGET, targetType)} " +
log"due to ${MDC(ERROR, e.getMessage)}", e)
log"due to ${MDC(ERROR, ex.getMessage)}", ex)
None
}
case _ => None
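
A simplified, self-contained sketch of what the changed branch above now does (the helper name here is hypothetical, and the real code also threads through the original SQL text and structured logging): any foldable default expression, not only a literal, is TRY-cast to the target type and folded into a literal of that type, with None returned when the cast yields null or evaluation throws.

import org.apache.spark.sql.catalyst.expressions.{Add, Cast, EvalMode, Expression, Literal}
import org.apache.spark.sql.types.{DataType, DoubleType}

// Hypothetical illustration of the folding behavior, not the actual Spark helper.
def foldDefaultValue(expr: Expression, targetType: DataType): Option[Literal] = {
  if (!expr.foldable) {
    None
  } else {
    // TRY eval mode: an invalid conversion evaluates to null instead of throwing.
    val casted = Cast(expr, targetType, None, evalMode = EvalMode.TRY)
    try {
      Option(casted.eval()).map(Literal(_, targetType))
    } catch {
      // Treat any failure as "no usable default literal", per the review discussion above.
      case _: Exception => None
    }
  }
}

// (1 + 1) is foldable, so a DOUBLE target folds it to Literal(2.0, DoubleType).
val folded = foldDefaultValue(Add(Literal(1), Literal(1)), DoubleType)
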
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
import org.apache.spark.sql.execution.ExplainUtils.stripAQEPlan
import org.apache.spark.sql.execution.datasources.v2.{AlterTableExec, CreateTableExec, DataSourceV2Relation, ReplaceTableExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, CalendarIntervalType, IntegerType, StringType}
import org.apache.spark.sql.types.{BooleanType, CalendarIntervalType, DoubleType, IntegerType, StringType, TimestampType}
import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.unsafe.types.UTF8String

@@ -498,43 +498,32 @@ class DataSourceV2DataFrameSuite
|""".stripMargin)

val alterExecCol1 = executeAndKeepPhysicalPlan[AlterTableExec] {
sql(s"ALTER TABLE $tableName ALTER COLUMN salary SET DEFAULT (123 + 56)")
}
checkDefaultValue(
alterExecCol1.changes.collect {
case u: UpdateColumnDefaultValue => u
}.head,
new DefaultValue(
"(123 + 56)",
new GeneralScalarExpression(
"+",
Array(LiteralValue(123, IntegerType), LiteralValue(56, IntegerType)))))

val alterExecCol2 = executeAndKeepPhysicalPlan[AlterTableExec] {
sql(s"ALTER TABLE $tableName ALTER COLUMN dep SET DEFAULT ('r' || 'l')")
}
checkDefaultValue(
alterExecCol2.changes.collect {
case u: UpdateColumnDefaultValue => u
}.head,
new DefaultValue(
"('r' || 'l')",
new GeneralScalarExpression(
"CONCAT",
Array(
LiteralValue(UTF8String.fromString("r"), StringType),
LiteralValue(UTF8String.fromString("l"), StringType)))))

val alterExecCol3 = executeAndKeepPhysicalPlan[AlterTableExec] {
sql(s"ALTER TABLE $tableName ALTER COLUMN active SET DEFAULT CAST(0 AS BOOLEAN)")
sql(
[Review comment — Author (Contributor)]: Some test cleanup; I didn't realize multiple ALTER COLUMN clauses can be in the same statement, so the test was unnecessarily long.
s"""
|ALTER TABLE $tableName ALTER COLUMN
| salary SET DEFAULT (123 + 56),
| dep SET DEFAULT ('r' || 'l'),
| active SET DEFAULT CAST(0 AS BOOLEAN)
|""".stripMargin)
}
checkDefaultValue(
alterExecCol3.changes.collect {
case u: UpdateColumnDefaultValue => u
}.head,
new DefaultValue(
"CAST(0 AS BOOLEAN)",
new V2Cast(LiteralValue(0, IntegerType), IntegerType, BooleanType)))
checkDefaultValues(
alterExecCol1.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
Array(
new DefaultValue(
"(123 + 56)",
new GeneralScalarExpression(
"+",
Array(LiteralValue(123, IntegerType), LiteralValue(56, IntegerType)))),
new DefaultValue(
"('r' || 'l')",
new GeneralScalarExpression(
"CONCAT",
Array(
LiteralValue(UTF8String.fromString("r"), StringType),
LiteralValue(UTF8String.fromString("l"), StringType)))),
new DefaultValue(
"CAST(0 AS BOOLEAN)",
new V2Cast(LiteralValue(0, IntegerType), IntegerType, BooleanType))))
}
}
}
@@ -666,13 +655,9 @@ class DataSourceV2DataFrameSuite
sql(s"ALTER TABLE $tableName ALTER COLUMN cat SET DEFAULT current_catalog()")
}

checkDefaultValue(
alterExec.changes.collect {
case u: UpdateColumnDefaultValue => u
}.head,
new DefaultValue(
"current_catalog()",
null /* No V2 Expression */))
checkDefaultValues(
alterExec.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
Array(new DefaultValue("current_catalog()", null /* No V2 Expression */)))

val df1 = Seq(1).toDF("dummy")
df1.writeTo(tableName).append()
@@ -683,6 +668,109 @@
}
}

test("create/replace table default value expression should have a cast") {
[Review comment — LuciferYang (Contributor), Jun 3, 2025]: Tests for "create/replace table default value expression should have a cast" and "alter table default value expression should have a cast" will fail in non-ANSI mode.

We can execute the following command to reproduce the issue locally:

SPARK_ANSI_SQL_MODE=false build/sbt clean "sql/testOnly org.apache.spark.sql.connector.DataSourceV2DataFrameSuite"

[info] - create/replace table default value expression should have a cast *** FAILED *** (25 milliseconds)
[info]   ColumnDefaultValue{sql=(1 + 1), expression=null, value=2.0} did not equal ColumnDefaultValue{sql=(1 + 1), expression=CAST(1 + 1 AS double), value=2.0} Default value mismatch for column 'col3': expected ColumnDefaultValue{sql=(1 + 1), expression=CAST(1 + 1 AS double), value=2.0} but found ColumnDefaultValue{sql=(1 + 1), expression=null, value=2.0} (DataSourceV2DataFrameSuite.scala:870)
[info]   org.scalatest.exceptions.TestFailedException:
...

[info] - alter table default value expression should have a cast *** FAILED *** (18 milliseconds)
[info]   DefaultValue{sql=(1 + 1), expression=null} did not equal DefaultValue{sql=(1 + 1), expression=CAST(1 + 1 AS double)} Default value mismatch for column 'org.apache.spark.sql.connector.catalog.TableChange$UpdateColumnDefaultValue@99efd8cb': expected DefaultValue{sql=(1 + 1), expression=CAST(1 + 1 AS double)} but found DefaultValue{sql=(1 + 1), expression=null} (DataSourceV2DataFrameSuite.scala:898)
[info]   org.scalatest.exceptions.TestFailedException:

Do you have time to take a look? @szehon-ho
also cc @cloud-fan
val tableName = "testcat.ns1.ns2.tbl"
withTable(tableName) {

val createExec = executeAndKeepPhysicalPlan[CreateTableExec] {
sql(
s"""
|CREATE TABLE $tableName (
| col1 int,
| col2 timestamp DEFAULT '2018-11-17 13:33:33',
| col3 double DEFAULT 1)
|""".stripMargin)
}
checkDefaultValues(
createExec.columns,
Array(
null,
new ColumnDefaultValue(
"'2018-11-17 13:33:33'",
new LiteralValue(1542490413000000L, TimestampType),
new LiteralValue(1542490413000000L, TimestampType)),
new ColumnDefaultValue(
"1",
new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType),
LiteralValue(1.0, DoubleType))))

val replaceExec = executeAndKeepPhysicalPlan[ReplaceTableExec] {
sql(
s"""
|REPLACE TABLE $tableName (
| col1 int,
| col2 timestamp DEFAULT '2022-02-23 05:55:55',
| col3 double DEFAULT (1 + 1))
|""".stripMargin)
}
checkDefaultValues(
replaceExec.columns,
Array(
null,
new ColumnDefaultValue(
"'2022-02-23 05:55:55'",
LiteralValue(1645624555000000L, TimestampType),
LiteralValue(1645624555000000L, TimestampType)),
new ColumnDefaultValue(
"(1 + 1)",
new V2Cast(
new GeneralScalarExpression("+", Array(LiteralValue(1, IntegerType),
LiteralValue(1, IntegerType))),
IntegerType,
DoubleType),
LiteralValue(2.0, DoubleType))))
}
}

test("alter table default value expression should have a cast") {
val tableName = "testcat.ns1.ns2.tbl"
withTable(tableName) {

sql(s"CREATE TABLE $tableName (col1 int) using foo")
val alterExec = executeAndKeepPhysicalPlan[AlterTableExec] {
sql(
s"""
|ALTER TABLE $tableName ADD COLUMNS (
| col2 timestamp DEFAULT '2018-11-17 13:33:33',
| col3 double DEFAULT 1)
|""".stripMargin)
}

checkDefaultValues(
alterExec.changes.map(_.asInstanceOf[AddColumn]).toArray,
Array(
new ColumnDefaultValue(
"'2018-11-17 13:33:33'",
LiteralValue(1542490413000000L, TimestampType),
LiteralValue(1542490413000000L, TimestampType)),
new ColumnDefaultValue(
"1",
new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType),
LiteralValue(1.0, DoubleType))))

val alterCol1 = executeAndKeepPhysicalPlan[AlterTableExec] {
sql(
s"""
|ALTER TABLE $tableName ALTER COLUMN
| col2 SET DEFAULT '2022-02-23 05:55:55',
| col3 SET DEFAULT (1 + 1)
|""".stripMargin)
}
checkDefaultValues(
alterCol1.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
Array(
new DefaultValue("'2022-02-23 05:55:55'",
LiteralValue(1645624555000000L, TimestampType)),
new DefaultValue(
"(1 + 1)",
new V2Cast(
new GeneralScalarExpression("+", Array(LiteralValue(1, IntegerType),
LiteralValue(1, IntegerType))),
IntegerType,
DoubleType))))
}
}

private def executeAndKeepPhysicalPlan[T <: SparkPlan](func: => Unit): T = {
val qe = withQueryExecutionsCaptured(spark) {
func
@@ -718,13 +806,18 @@ class DataSourceV2DataFrameSuite
}
}

private def checkDefaultValue(
column: UpdateColumnDefaultValue,
expectedDefault: DefaultValue): Unit = {
assert(
column.newCurrentDefault() == expectedDefault,
s"Default value mismatch for column '${column.toString}': " +
s"expected $expectedDefault but found ${column.newCurrentDefault()}")
private def checkDefaultValues(
[Review comment — Contributor]: I guess you could define the new checkDefaultValues using checkDefaultValue to minimize changes.
[Author reply (Contributor)]: I think I prefer to use the new method in all the places, as it makes the tests significantly shorter (no need to check each change individually).
columns: Array[UpdateColumnDefaultValue],
expectedDefaultValues: Array[DefaultValue]): Unit = {
assert(columns.length == expectedDefaultValues.length)

columns.zip(expectedDefaultValues).foreach {
case (column, expectedDefault) =>
assert(
column.newCurrentDefault() == expectedDefault,
s"Default value mismatch for column '${column.toString}': " +
s"expected $expectedDefault but found ${column.newCurrentDefault}")
}
}

private def checkDropDefaultValue(