Skip to content

Commit a1f4c93

Browse files
committed
temp
1 parent 9a452f8 commit a1f4c93

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

python/pyspark/sql/tests/test_udf.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,29 @@ def add1(x):
13791379
result = empty_df.select(add1("id"))
13801380
self.assertEqual(result.collect(), [])
13811381

1382+
def test_udf_with_collated_string_types(self):
1383+
@udf("string collate fr")
1384+
def my_udf(input_val):
1385+
return "%s - %s" % (type(input_val), input_val)
1386+
1387+
string_types = [
1388+
StringType(),
1389+
StringType("UTF8_BINARY"),
1390+
StringType("UTF8_LCASE"),
1391+
StringType("UNICODE"),
1392+
]
1393+
data = [("hello",)]
1394+
expected = "<class 'str'> - hello"
1395+
1396+
for string_type in string_types:
1397+
schema = StructType([StructField("input_col", string_type, True)])
1398+
df = self.spark.createDataFrame(data, schema=schema)
1399+
df_result = df.select(my_udf(df.input_col).alias("result"))
1400+
row = df_result.collect()[0][0]
1401+
self.assertEqual(row, expected)
1402+
result_type = df_result.schema["result"].dataType
1403+
self.assertEqual(result_type, StringType("fr"))
1404+
13821405

13831406
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
13841407
@classmethod

python/pyspark/sql/tests/test_udtf.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3437,6 +3437,40 @@ def eval(self):
34373437
udtf(TestUDTF, returnType=ret_type)().collect()
34383438

34393439

3440+
def test_udtf_with_collated_string_types(self):
3441+
@udtf(
3442+
"out1 string, out2 string collate UTF8_BINARY, out3 string collate UTF8_LCASE, out4 string collate UNICODE"
3443+
)
3444+
class MyUDTF:
3445+
def eval(self, v1, v2, v3, v4):
3446+
yield (v1 + "1", v2 + "2", v3 + "3", v4 + "4")
3447+
3448+
schema = StructType(
3449+
[
3450+
StructField("col1", StringType(), True),
3451+
StructField("col2", StringType("UTF8_BINARY"), True),
3452+
StructField("col3", StringType("UTF8_LCASE"), True),
3453+
StructField("col4", StringType("UNICODE"), True),
3454+
]
3455+
)
3456+
df = self.spark.createDataFrame([("hello",) * 4], schema=schema)
3457+
3458+
df_out = df.select(MyUDTF(df.col1, df.col2, df.col3, df.col4).alias("out"))
3459+
result_df = df_out.select("out.*")
3460+
3461+
expected_row = ("hello1", "hello2", "hello3", "hello4")
3462+
self.assertEqual(result_df.collect()[0], expected_row)
3463+
3464+
expected_output_types = [
3465+
StringType(),
3466+
StringType("UTF8_BINARY"),
3467+
StringType("UTF8_LCASE"),
3468+
StringType("UNICODE"),
3469+
]
3470+
for idx, field in enumerate(result_df.schema.fields):
3471+
self.assertEqual(field.dataType, expected_output_types[idx])
3472+
3473+
34403474
class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
34413475
@classmethod
34423476
def setUpClass(cls):

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ object EvaluatePython {
7979

8080
case (d: Decimal, _) => d.toJavaBigDecimal
8181

82-
case (s: UTF8String, StringType) => s.toString
82+
case (s: UTF8String, _: StringType) => s.toString
8383

8484
case (other, _) => other
8585
}

0 commit comments

Comments
 (0)