Skip to content

Commit 7fec8b9

Browse files
committed
temp
1 parent 7db6c82 commit 7fec8b9

File tree

8 files changed

+97
-20
lines changed

8 files changed

+97
-20
lines changed

python/pyspark/sql/tests/test_udf.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,29 @@ def add1(x):
13791379
result = empty_df.select(add1("id"))
13801380
self.assertEqual(result.collect(), [])
13811381

1382+
def test_udf_with_collated_string_types(self):
1383+
@udf("string collate fr")
1384+
def my_udf(input_val):
1385+
return "%s - %s" % (type(input_val), input_val)
1386+
1387+
string_types = [
1388+
StringType(),
1389+
StringType("UTF8_BINARY"),
1390+
StringType("UTF8_LCASE"),
1391+
StringType("UNICODE"),
1392+
]
1393+
data = [("hello",)]
1394+
expected = "<class 'str'> - hello"
1395+
1396+
for string_type in string_types:
1397+
schema = StructType([StructField("input_col", string_type, True)])
1398+
df = self.spark.createDataFrame(data, schema=schema)
1399+
df_result = df.select(my_udf(df.input_col).alias("result"))
1400+
row = df_result.collect()[0][0]
1401+
self.assertEqual(row, expected)
1402+
result_type = df_result.schema["result"].dataType
1403+
self.assertEqual(result_type, StringType("fr"))
1404+
13821405

13831406
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
13841407
@classmethod

python/pyspark/sql/tests/test_udtf.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3490,6 +3490,41 @@ def eval(self):
34903490
udtf(TestUDTF, returnType=ret_type)().collect()
34913491

34923492

3493+
def test_udtf_with_collated_string_types(self):
3494+
@udtf(
3495+
"out1 string, out2 string collate UTF8_BINARY, out3 string collate UTF8_LCASE,"
3496+
" out4 string collate UNICODE"
3497+
)
3498+
class MyUDTF:
3499+
def eval(self, v1, v2, v3, v4):
3500+
yield (v1 + "1", v2 + "2", v3 + "3", v4 + "4")
3501+
3502+
schema = StructType(
3503+
[
3504+
StructField("col1", StringType(), True),
3505+
StructField("col2", StringType("UTF8_BINARY"), True),
3506+
StructField("col3", StringType("UTF8_LCASE"), True),
3507+
StructField("col4", StringType("UNICODE"), True),
3508+
]
3509+
)
3510+
df = self.spark.createDataFrame([("hello",) * 4], schema=schema)
3511+
3512+
df_out = df.select(MyUDTF(df.col1, df.col2, df.col3, df.col4).alias("out"))
3513+
result_df = df_out.select("out.*")
3514+
3515+
expected_row = ("hello1", "hello2", "hello3", "hello4")
3516+
self.assertEqual(result_df.collect()[0], expected_row)
3517+
3518+
expected_output_types = [
3519+
StringType(),
3520+
StringType("UTF8_BINARY"),
3521+
StringType("UTF8_LCASE"),
3522+
StringType("UNICODE"),
3523+
]
3524+
for idx, field in enumerate(result_df.schema.fields):
3525+
self.assertEqual(field.dataType, expected_output_types[idx])
3526+
3527+
34933528
class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
34943529
@classmethod
34953530
def setUpClass(cls):

sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -448,18 +448,35 @@ object DataType {
448448
}
449449

450450
/**
451-
* Check if `from` is equal to `to` type except for collations, which are checked to be
452-
* compatible so that data of type `from` can be interpreted as of type `to`.
451+
* Returns true if `from` equals `to` when differences in StringType collation are ignored.
452+
* Nested collations are ignored only if `checkComplexTypes` is true (the default).
453453
*/
454-
private[sql] def equalsIgnoreCompatibleCollation(from: DataType, to: DataType): Boolean = {
455-
(from, to) match {
456-
// String types with possibly different collations are compatible.
457-
case (a: StringType, b: StringType) => a.constraint == b.constraint
454+
private[sql] def equalsIgnoreCompatibleCollation(
455+
from: DataType, to: DataType, checkComplexTypes: Boolean = true): Boolean = {
456+
def transform: PartialFunction[DataType, DataType] = {
457+
case dt @ (_: CharType | _: VarcharType) => dt
458+
case _: StringType => StringType
459+
}
458460

459-
case (fromDataType, toDataType) => fromDataType == toDataType
461+
if (checkComplexTypes) {
462+
from.transformRecursively(transform) == to.transformRecursively(transform)
463+
} else {
464+
(from, to) match {
465+
case (a: StringType, b: StringType) => a.constraint == b.constraint
466+
467+
case (fromDataType, toDataType) => fromDataType == toDataType
468+
}
460469
}
461470
}
462471

472+
private[sql] def equalsIgnoreCompatibleCollation(
473+
from: Seq[DataType], to: Seq[DataType]): Boolean = {
474+
from.length == to.length &&
475+
from.zip(to).forall { case (fromDataType, toDataType) =>
476+
equalsIgnoreCompatibleCollation(fromDataType, toDataType)
477+
}
478+
}
479+
463480
/**
464481
* Returns true if the two data types share the same "shape", i.e. the types are the same, but
465482
* the field names don't need to be the same.

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ class DataTypeSuite extends SparkFunSuite {
880880
checkEqualsIgnoreCompatibleCollation(
881881
ArrayType(StringType),
882882
ArrayType(StringType("UTF8_LCASE")),
883-
expected = false
883+
expected = true
884884
)
885885
checkEqualsIgnoreCompatibleCollation(
886886
ArrayType(StringType),
@@ -890,7 +890,7 @@ class DataTypeSuite extends SparkFunSuite {
890890
checkEqualsIgnoreCompatibleCollation(
891891
ArrayType(ArrayType(StringType)),
892892
ArrayType(ArrayType(StringType("UTF8_LCASE"))),
893-
expected = false
893+
expected = true
894894
)
895895
checkEqualsIgnoreCompatibleCollation(
896896
ArrayType(ArrayType(StringType)),
@@ -915,12 +915,12 @@ class DataTypeSuite extends SparkFunSuite {
915915
checkEqualsIgnoreCompatibleCollation(
916916
MapType(StringType, StringType),
917917
MapType(StringType, StringType("UTF8_LCASE")),
918-
expected = false
918+
expected = true
919919
)
920920
checkEqualsIgnoreCompatibleCollation(
921921
MapType(StringType("UTF8_LCASE"), StringType),
922922
MapType(StringType, StringType),
923-
expected = false
923+
expected = true
924924
)
925925
checkEqualsIgnoreCompatibleCollation(
926926
MapType(StringType("UTF8_LCASE"), StringType),
@@ -945,7 +945,7 @@ class DataTypeSuite extends SparkFunSuite {
945945
checkEqualsIgnoreCompatibleCollation(
946946
MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
947947
MapType(StringType("UTF8_LCASE"), ArrayType(StringType("UTF8_LCASE"))),
948-
expected = false
948+
expected = true
949949
)
950950
checkEqualsIgnoreCompatibleCollation(
951951
MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
@@ -970,7 +970,7 @@ class DataTypeSuite extends SparkFunSuite {
970970
checkEqualsIgnoreCompatibleCollation(
971971
MapType(ArrayType(StringType), IntegerType),
972972
MapType(ArrayType(StringType("UTF8_LCASE")), IntegerType),
973-
expected = false
973+
expected = true
974974
)
975975
checkEqualsIgnoreCompatibleCollation(
976976
MapType(ArrayType(StringType("UTF8_LCASE")), IntegerType),
@@ -1000,7 +1000,7 @@ class DataTypeSuite extends SparkFunSuite {
10001000
checkEqualsIgnoreCompatibleCollation(
10011001
StructType(StructField("a", StringType) :: Nil),
10021002
StructType(StructField("a", StringType("UTF8_LCASE")) :: Nil),
1003-
expected = false
1003+
expected = true
10041004
)
10051005
checkEqualsIgnoreCompatibleCollation(
10061006
StructType(StructField("a", StringType) :: Nil),
@@ -1025,7 +1025,7 @@ class DataTypeSuite extends SparkFunSuite {
10251025
checkEqualsIgnoreCompatibleCollation(
10261026
StructType(StructField("a", ArrayType(StringType)) :: Nil),
10271027
StructType(StructField("a", ArrayType(StringType("UTF8_LCASE"))) :: Nil),
1028-
expected = false
1028+
expected = true
10291029
)
10301030
checkEqualsIgnoreCompatibleCollation(
10311031
StructType(StructField("a", ArrayType(StringType)) :: Nil),
@@ -1050,7 +1050,7 @@ class DataTypeSuite extends SparkFunSuite {
10501050
checkEqualsIgnoreCompatibleCollation(
10511051
StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),
10521052
StructType(StructField("a", MapType(StringType("UTF8_LCASE"), IntegerType)) :: Nil),
1053-
expected = false
1053+
expected = true
10541054
)
10551055
checkEqualsIgnoreCompatibleCollation(
10561056
StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),

sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ case class AlterTableChangeColumnCommand(
465465
// when altering column. Only a change in the collation of a top-level string type is allowed;
466466
// nested types are compared strictly, collation included.
467467
private def canEvolveType(from: StructField, to: StructField): Boolean = {
468-
DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType)
468+
DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType, checkComplexTypes = false)
469469
}
470470
}
471471

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.SparkPlan
2828
import org.apache.spark.sql.execution.metric.SQLMetric
2929
import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
3030
import org.apache.spark.sql.types.{StructType, UserDefinedType}
31+
import org.apache.spark.sql.types.DataType.equalsIgnoreCompatibleCollation
3132

3233
/**
3334
* Groups an iterator into batches.
@@ -128,7 +129,7 @@ class ArrowEvalPythonEvaluatorFactory(
128129

129130
columnarBatchIter.flatMap { batch =>
130131
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
131-
if (outputTypes != actualDataTypes) {
132+
if (!equalsIgnoreCompatibleCollation(outputTypes, actualDataTypes)) {
132133
throw QueryExecutionErrors.arrowDataTypeMismatchError(
133134
"pandas_udf()", outputTypes, actualDataTypes)
134135
}

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
2626
import org.apache.spark.sql.execution.SparkPlan
2727
import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
2828
import org.apache.spark.sql.types.{StructType, UserDefinedType}
29+
import org.apache.spark.sql.types.DataType.equalsIgnoreCompatibleCollation
2930
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
3031

3132
/**
@@ -84,7 +85,7 @@ case class ArrowEvalPythonUDTFExec(
8485

8586
val actualDataTypes = (0 until flattenedBatch.numCols()).map(
8687
i => flattenedBatch.column(i).dataType())
87-
if (outputTypes != actualDataTypes) {
88+
if (!equalsIgnoreCompatibleCollation(outputTypes, actualDataTypes)) {
8889
throw QueryExecutionErrors.arrowDataTypeMismatchError(
8990
"Python UDTF", outputTypes, actualDataTypes)
9091
}

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ object EvaluatePython {
7979

8080
case (d: Decimal, _) => d.toJavaBigDecimal
8181

82-
case (s: UTF8String, StringType) => s.toString
82+
case (s: UTF8String, _: StringType) => s.toString
8383

8484
case (other, _) => other
8585
}

0 commit comments

Comments
 (0)