Skip to content

Commit 2f1bee5

Browse files
committed
temp
1 parent 9a452f8 commit 2f1bee5

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

python/pyspark/sql/tests/test_udf.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,26 @@ def add1(x):
13791379
result = empty_df.select(add1("id"))
13801380
self.assertEqual(result.collect(), [])
13811381

1382+
def test_udf_with_collated_string_types(self):
    """Check that a Python UDF sees collated string columns as plain ``str``.

    For each ``StringType`` variant (default, ``UTF8_BINARY``,
    ``UTF8_LCASE``, ``UNICODE``) a one-row DataFrame is built, the UDF is
    applied to its only column, and the result is compared against the
    string the UDF produces for a Python ``str`` input.
    """

    @udf("string")
    def my_udf(input_val):
        # Report both the Python type and the value the worker received.
        return f"{type(input_val)} - {input_val}"

    string_types = [
        StringType(),
        StringType("UTF8_BINARY"),
        StringType("UTF8_LCASE"),
        StringType("UNICODE"),
    ]
    data = [("hello",)]
    expected = "<class 'str'> - hello"

    for string_type in string_types:
        # subTest names the failing collation and lets the remaining
        # collations run even if one of them fails.
        with self.subTest(string_type=string_type):
            schema = StructType([StructField("input_col", string_type, True)])
            df = self.spark.createDataFrame(data, schema=schema)
            # collect()[0][0] is the single cell of the single result row.
            result = df.select(my_udf(df.input_col)).collect()[0][0]
            self.assertEqual(result, expected)
1401+
13821402

13831403
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
13841404
@classmethod

python/pyspark/sql/tests/test_udtf.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3437,6 +3437,41 @@ def eval(self):
34373437
udtf(TestUDTF, returnType=ret_type)().collect()
34383438

34393439

3440+
def test_udtf_with_collated_string_types(self):
    """Check that a Python UDTF accepts string arguments of several collations.

    A one-row DataFrame with four string columns (default, ``UTF8_BINARY``,
    ``UTF8_LCASE``, ``UNICODE``) is passed to a UDTF that appends a distinct
    digit to each argument; the single output row is compared with the
    expected concatenations.
    """

    class MyUDTF:
        def eval(self, val1, val2, val3, val4):
            yield (val1 + "1", val2 + "2", val3 + "3", val4 + "4")

    # Output schema: four string columns with the default string type.
    output_schema = StructType(
        [StructField(name, StringType()) for name in ("out1", "out2", "out3", "out4")]
    )
    my_udtf = udtf(MyUDTF, returnType=output_schema)

    # Input schema: one nullable column per string type under test.
    input_columns = (
        ("col1", StringType()),
        ("col2", StringType("UTF8_BINARY")),
        ("col3", StringType("UTF8_LCASE")),
        ("col4", StringType("UNICODE")),
    )
    schema = StructType(
        [StructField(name, string_type, True) for name, string_type in input_columns]
    )

    df = self.spark.createDataFrame(
        [("hello", "hello", "hello", "hello")], schema=schema
    )

    projected = df.select(my_udtf(df.col1, df.col2, df.col3, df.col4))
    result_row = projected.collect()[0]

    self.assertEqual(result_row, ("hello1", "hello2", "hello3", "hello4"))
3473+
3474+
34403475
class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
34413476
@classmethod
34423477
def setUpClass(cls):

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ object EvaluatePython {
7979

8080
case (d: Decimal, _) => d.toJavaBigDecimal
8181

82-
case (s: UTF8String, StringType) => s.toString
82+
case (s: UTF8String, _: StringType) => s.toString
8383

8484
case (other, _) => other
8585
}

0 commit comments

Comments
 (0)