Skip to content

Commit 2584dab

Browse files
committed
temp
1 parent 9a452f8 commit 2584dab

File tree

3 files changed

+48
-1
lines changed

3 files changed

+48
-1
lines changed

python/pyspark/sql/tests/test_udf.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,22 @@ def add1(x):
13791379
result = empty_df.select(add1("id"))
13801380
self.assertEqual(result.collect(), [])
13811381

1382+
def test_udf_with_collated_string_types(self):
    """Check that a Python UDF sees collated string columns as plain ``str``.

    For each string type (default, UTF8_BINARY, UTF8_LCASE, UNICODE) a
    one-row DataFrame is built and passed through a UDF that reports the
    Python type and value of its argument. The reported text must be the
    same for every string type.
    """

    @udf("string")
    def my_udf(input_val):
        # Include the runtime type so the assertion catches a wrong Python
        # type as well as a wrong value.
        return f"{type(input_val)} - {input_val}"

    string_types = [
        StringType(),
        StringType("UTF8_BINARY"),
        StringType("UTF8_LCASE"),
        StringType("UNICODE"),
    ]
    data = [("hello",)]
    expected = "<class 'str'> - hello"

    for string_type in string_types:
        # subTest labels a failure with the offending type and lets the
        # remaining types still run instead of stopping at the first one.
        with self.subTest(string_type=string_type):
            schema = StructType([StructField("input_col", string_type, True)])
            df = self.spark.createDataFrame(data, schema=schema)
            row = df.select(my_udf(df.input_col)).collect()[0][0]
            self.assertEqual(row, expected)
1396+
1397+
13821398

13831399
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
13841400
@classmethod

python/pyspark/sql/tests/test_udtf.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3436,6 +3436,37 @@ def eval(self):
34363436
with self.assertRaisesRegex(PythonException, "UDTF_ARROW_TYPE_CONVERSION_ERROR"):
34373437
udtf(TestUDTF, returnType=ret_type)().collect()
34383438

3439+
def test_udtf_with_collated_string_types(self):
    """Run a four-argument UDTF over string columns with different string types.

    The input columns use the default string type, UTF8_BINARY, UTF8_LCASE
    and UNICODE. The UDTF appends the argument position to each value, and
    the first collected row is compared with the expected suffixed values.
    """

    class MyUDTF:
        def eval(self, val1, val2, val3, val4):
            # Suffix each argument with its position so every output column
            # can be traced back to the input column it came from.
            yield (val1 + "1", val2 + "2", val3 + "3", val4 + "4")

    output_schema = StructType(
        [
            StructField("out1", StringType()),
            StructField("out2", StringType()),
            StructField("out3", StringType()),
            StructField("out4", StringType()),
        ]
    )
    my_udtf = udtf(MyUDTF, returnType=output_schema)

    input_schema = StructType(
        [
            StructField("col1", StringType(), True),
            StructField("col2", StringType("UTF8_BINARY"), True),
            StructField("col3", StringType("UTF8_LCASE"), True),
            StructField("col4", StringType("UNICODE"), True),
        ]
    )
    source_df = self.spark.createDataFrame(
        [("hello", "hello", "hello", "hello")], schema=input_schema
    )

    udtf_call = my_udtf(source_df.col1, source_df.col2, source_df.col3, source_df.col4)
    first_row = source_df.select(udtf_call).collect()[0]

    self.assertEqual(first_row, ("hello1", "hello2", "hello3", "hello4"))
3468+
3469+
34393470

34403471
class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
34413472
@classmethod

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ object EvaluatePython {
7979

8080
case (d: Decimal, _) => d.toJavaBigDecimal
8181

82-
case (s: UTF8String, StringType) => s.toString
82+
case (s: UTF8String, _: StringType) => s.toString
8383

8484
case (other, _) => other
8585
}

0 commit comments

Comments
 (0)