Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,29 @@ def add1(x):
result = empty_df.select(add1("id"))
self.assertEqual(result.collect(), [])

def test_udf_with_collated_string_types(self):
    """Check that a UDF declared to return a collated string accepts any input collation.

    The UDF's return type is ``string collate fr``. For each input collation the
    value handed to Python must be a plain ``str``, and the resulting column must
    carry the declared ``fr`` collation.
    """

    @udf("string collate fr")
    def my_udf(input_val):
        return "%s - %s" % (type(input_val), input_val)

    string_types = [
        StringType(),
        StringType("UTF8_BINARY"),
        StringType("UTF8_LCASE"),
        StringType("UNICODE"),
    ]
    data = [("hello",)]
    expected = "<class 'str'> - hello"

    for string_type in string_types:
        # subTest names the failing collation and keeps checking the remaining ones.
        with self.subTest(string_type=string_type):
            schema = StructType([StructField("input_col", string_type, True)])
            df = self.spark.createDataFrame(data, schema=schema)
            df_result = df.select(my_udf(df.input_col).alias("result"))

            # The UDF output value must not depend on the input collation.
            row = df_result.collect()[0][0]
            self.assertEqual(row, expected)

            # The result column keeps the collation declared on the UDF.
            result_type = df_result.schema["result"].dataType
            self.assertEqual(result_type, StringType("fr"))


class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
@classmethod
Expand Down
35 changes: 35 additions & 0 deletions python/pyspark/sql/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3490,6 +3490,41 @@ def eval(self):
udtf(TestUDTF, returnType=ret_type)().collect()


def test_udtf_with_collated_string_types(self):
Copy link
Contributor

@zhengruifeng zhengruifeng Aug 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ilicmarkodb the indent here is wrong, this test is actually skipped. It should be put into a Mixin class like BaseUDTFTestsMixin

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#52001

I opened a PR to fix this. I’ll finish it and tag you for review once the CI is green.

@udtf(
"out1 string, out2 string collate UTF8_BINARY, out3 string collate UTF8_LCASE,"
" out4 string collate UNICODE"
)
class MyUDTF:
def eval(self, v1, v2, v3, v4):
yield (v1 + "1", v2 + "2", v3 + "3", v4 + "4")

schema = StructType(
[
StructField("col1", StringType(), True),
StructField("col2", StringType("UTF8_BINARY"), True),
StructField("col3", StringType("UTF8_LCASE"), True),
StructField("col4", StringType("UNICODE"), True),
]
)
df = self.spark.createDataFrame([("hello",) * 4], schema=schema)

df_out = df.select(MyUDTF(df.col1, df.col2, df.col3, df.col4).alias("out"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this query work? I guess it should be a lateralJoin?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn’t. I just didn’t realize that, since the test wasn't executed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PTAL #51688

result_df = df_out.select("out.*")

expected_row = ("hello1", "hello2", "hello3", "hello4")
self.assertEqual(result_df.collect()[0], expected_row)

expected_output_types = [
StringType(),
StringType("UTF8_BINARY"),
StringType("UTF8_LCASE"),
StringType("UNICODE"),
]
for idx, field in enumerate(result_df.schema.fields):
self.assertEqual(field.dataType, expected_output_types[idx])


class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
Expand Down
34 changes: 27 additions & 7 deletions sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -448,15 +448,35 @@ object DataType {
}

/**
* Check if `from` is equal to `to` type except for collations, which are checked to be
* compatible so that data of type `from` can be interpreted as of type `to`.
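* Compares two data types for equality, ignoring differences in StringType collation. If
* `checkComplexTypes` is true, collations of StringTypes nested inside complex types are
* ignored as well; otherwise only top-level StringTypes are compared this way.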
*/
private[sql] def equalsIgnoreCompatibleCollation(from: DataType, to: DataType): Boolean = {
(from, to) match {
// String types with possibly different collations are compatible.
case (a: StringType, b: StringType) => a.constraint == b.constraint
private[sql] def equalsIgnoreCompatibleCollation(
from: DataType,
to: DataType,
checkComplexTypes: Boolean = true): Boolean = {
def transform: PartialFunction[DataType, DataType] = {
case dt @ (_: CharType | _: VarcharType) => dt
case _: StringType => StringType
}

case (fromDataType, toDataType) => fromDataType == toDataType
if (checkComplexTypes) {
from.transformRecursively(transform) == to.transformRecursively(transform)
} else {
(from, to) match {
case (a: StringType, b: StringType) => a.constraint == b.constraint

case (fromDataType, toDataType) => fromDataType == toDataType
}
}
}

private[sql] def equalsIgnoreCompatibleCollation(
from: Seq[DataType],
to: Seq[DataType]): Boolean = {
from.length == to.length &&
from.zip(to).forall { case (fromDataType, toDataType) =>
equalsIgnoreCompatibleCollation(fromDataType, toDataType)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ class DataTypeSuite extends SparkFunSuite {
checkEqualsIgnoreCompatibleCollation(
ArrayType(StringType),
ArrayType(StringType("UTF8_LCASE")),
expected = false
expected = true
)
checkEqualsIgnoreCompatibleCollation(
ArrayType(StringType),
Expand All @@ -890,7 +890,7 @@ class DataTypeSuite extends SparkFunSuite {
checkEqualsIgnoreCompatibleCollation(
ArrayType(ArrayType(StringType)),
ArrayType(ArrayType(StringType("UTF8_LCASE"))),
expected = false
expected = true
)
checkEqualsIgnoreCompatibleCollation(
ArrayType(ArrayType(StringType)),
Expand All @@ -915,12 +915,12 @@ class DataTypeSuite extends SparkFunSuite {
checkEqualsIgnoreCompatibleCollation(
MapType(StringType, StringType),
MapType(StringType, StringType("UTF8_LCASE")),
expected = false
expected = true
)
checkEqualsIgnoreCompatibleCollation(
MapType(StringType("UTF8_LCASE"), StringType),
MapType(StringType, StringType),
expected = false
expected = true
)
checkEqualsIgnoreCompatibleCollation(
MapType(StringType("UTF8_LCASE"), StringType),
Expand All @@ -945,7 +945,7 @@ class DataTypeSuite extends SparkFunSuite {
checkEqualsIgnoreCompatibleCollation(
MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
MapType(StringType("UTF8_LCASE"), ArrayType(StringType("UTF8_LCASE"))),
expected = false
expected = true
)
checkEqualsIgnoreCompatibleCollation(
MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
Expand All @@ -970,7 +970,7 @@ class DataTypeSuite extends SparkFunSuite {
checkEqualsIgnoreCompatibleCollation(
MapType(ArrayType(StringType), IntegerType),
MapType(ArrayType(StringType("UTF8_LCASE")), IntegerType),
expected = false
expected = true
)
checkEqualsIgnoreCompatibleCollation(
MapType(ArrayType(StringType("UTF8_LCASE")), IntegerType),
Expand Down Expand Up @@ -1000,7 +1000,7 @@ class DataTypeSuite extends SparkFunSuite {
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", StringType) :: Nil),
StructType(StructField("a", StringType("UTF8_LCASE")) :: Nil),
expected = false
expected = true
)
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", StringType) :: Nil),
Expand All @@ -1025,7 +1025,7 @@ class DataTypeSuite extends SparkFunSuite {
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", ArrayType(StringType)) :: Nil),
StructType(StructField("a", ArrayType(StringType("UTF8_LCASE"))) :: Nil),
expected = false
expected = true
)
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", ArrayType(StringType)) :: Nil),
Expand All @@ -1050,7 +1050,7 @@ class DataTypeSuite extends SparkFunSuite {
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),
StructType(StructField("a", MapType(StringType("UTF8_LCASE"), IntegerType)) :: Nil),
expected = false
expected = true
)
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ case class AlterTableChangeColumnCommand(
// when altering column. Only changes in collation of data type or its nested types (recursively)
// are allowed.
private def canEvolveType(from: StructField, to: StructField): Boolean = {
DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType)
DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType, checkComplexTypes = false)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.types.{StructType, UserDefinedType}
import org.apache.spark.sql.types.DataType.equalsIgnoreCompatibleCollation

/**
* Groups an iterator into batches.
Expand Down Expand Up @@ -128,7 +129,7 @@ class ArrowEvalPythonEvaluatorFactory(

columnarBatchIter.flatMap { batch =>
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
if (outputTypes != actualDataTypes) {
if (!equalsIgnoreCompatibleCollation(outputTypes, actualDataTypes)) {
throw QueryExecutionErrors.arrowDataTypeMismatchError(
"pandas_udf()", outputTypes, actualDataTypes)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.types.{StructType, UserDefinedType}
import org.apache.spark.sql.types.DataType.equalsIgnoreCompatibleCollation
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}

/**
Expand Down Expand Up @@ -84,7 +85,7 @@ case class ArrowEvalPythonUDTFExec(

val actualDataTypes = (0 until flattenedBatch.numCols()).map(
i => flattenedBatch.column(i).dataType())
if (outputTypes != actualDataTypes) {
if (!equalsIgnoreCompatibleCollation(outputTypes, actualDataTypes)) {
throw QueryExecutionErrors.arrowDataTypeMismatchError(
"Python UDTF", outputTypes, actualDataTypes)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ object EvaluatePython {

case (d: Decimal, _) => d.toJavaBigDecimal

case (s: UTF8String, StringType) => s.toString
case (s: UTF8String, _: StringType) => s.toString

case (other, _) => other
}
Expand Down