Commit e5fa93c

temp
1 parent 0cd2f90 commit e5fa93c

7 files changed: +90 -5 lines changed

python/pyspark/core/context.py

Lines changed: 12 additions & 1 deletion

@@ -2649,9 +2649,20 @@ def _to_ddl(self, struct: "StructType") -> str:
 
     def _parse_ddl(self, ddl: str) -> "DataType":
         from pyspark.sql.types import _parse_datatype_json_string
+        from pyspark.sql.utils import ParseException
 
         assert self._jvm is not None
-        return _parse_datatype_json_string(self._jvm.PythonSQLUtils.ddlToJson(ddl))
+        try:
+            # This hack exists because of collated strings. E.g., if the return type is
+            # just `STRING COLLATE FR`, `ddlToJson()` would return just `STRING`, losing
+            # the information about collation. Collation metadata is stored on the
+            # nearest ancestor `StructField`, which is why we wrap the return type in a
+            # `StructType`.
+            wrapped_jvm_returnType = self._jvm.PythonSQLUtils.parseDataType(f"struct<ddl: {ddl}>")
+            wrapped_returnType = _parse_datatype_json_string(wrapped_jvm_returnType.json())
+            return wrapped_returnType["ddl"].dataType
+        except ParseException:
+            return _parse_datatype_json_string(self._jvm.PythonSQLUtils.ddlToJson(ddl))
 
 
 def _test() -> None:
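
The wrapping trick above can be observed directly on the PySpark type objects. A minimal sketch (assuming a Spark build with collation support; _parse_datatype_json_string is the same private helper the patch itself uses):

    from pyspark.sql.types import StringType, StructField, StructType, _parse_datatype_json_string

    # A bare collated string serializes to plain "string" -- the collation is lost.
    StringType("UNICODE").json()  # '"string"'

    # Wrapped in a StructType, the collation is recorded in the StructField's
    # metadata, so a JSON round-trip preserves it.
    wrapped = StructType([StructField("ddl", StringType("UNICODE"))])
    restored = _parse_datatype_json_string(wrapped.json())
    assert restored["ddl"].dataType == StringType("UNICODE")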

python/pyspark/sql/tests/test_udf.py

Lines changed: 23 additions & 0 deletions

@@ -1379,6 +1379,29 @@ def add1(x):
         result = empty_df.select(add1("id"))
         self.assertEqual(result.collect(), [])
 
+    def test_udf_with_collated_string_types(self):
+        @udf("string collate fr")
+        def my_udf(input_val):
+            return "%s - %s" % (type(input_val), input_val)
+
+        string_types = [
+            StringType(),
+            StringType("UTF8_BINARY"),
+            StringType("UTF8_LCASE"),
+            StringType("UNICODE"),
+        ]
+        data = [("hello",)]
+        expected = "<class 'str'> - hello"
+
+        for string_type in string_types:
+            schema = StructType([StructField("input_col", string_type, True)])
+            df = self.spark.createDataFrame(data, schema=schema)
+            df_result = df.select(my_udf(df.input_col).alias("result"))
+            row = df_result.collect()[0][0]
+            self.assertEqual(row, expected)
+            result_type = df_result.schema["result"].dataType
+            self.assertEqual(result_type, StringType("fr"))
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod

python/pyspark/sql/tests/test_udtf.py

Lines changed: 35 additions & 0 deletions

@@ -3490,6 +3490,41 @@ def eval(self):
         udtf(TestUDTF, returnType=ret_type)().collect()
 
 
+    def test_udtf_with_collated_string_types(self):
+        @udtf(
+            "out1 string, out2 string collate UTF8_BINARY, out3 string collate UTF8_LCASE,"
+            " out4 string collate UNICODE"
+        )
+        class MyUDTF:
+            def eval(self, v1, v2, v3, v4):
+                yield (v1 + "1", v2 + "2", v3 + "3", v4 + "4")
+
+        schema = StructType(
+            [
+                StructField("col1", StringType(), True),
+                StructField("col2", StringType("UTF8_BINARY"), True),
+                StructField("col3", StringType("UTF8_LCASE"), True),
+                StructField("col4", StringType("UNICODE"), True),
+            ]
+        )
+        df = self.spark.createDataFrame([("hello",) * 4], schema=schema)
+
+        df_out = df.select(MyUDTF(df.col1, df.col2, df.col3, df.col4).alias("out"))
+        result_df = df_out.select("out.*")
+
+        expected_row = ("hello1", "hello2", "hello3", "hello4")
+        self.assertEqual(result_df.collect()[0], expected_row)
+
+        expected_output_types = [
+            StringType(),
+            StringType("UTF8_BINARY"),
+            StringType("UTF8_LCASE"),
+            StringType("UNICODE"),
+        ]
+        for idx, field in enumerate(result_df.schema.fields):
+            self.assertEqual(field.dataType, expected_output_types[idx])
+
+
 class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
     def setUpClass(cls):

python/pyspark/sql/udf.py

Lines changed: 9 additions & 1 deletion

@@ -382,12 +382,20 @@ def _judf(self) -> "JavaObject":
 
     def _create_judf(self, func: Callable[..., Any]) -> "JavaObject":
        from pyspark.sql import SparkSession
+        from pyspark.sql.types import StructField
 
         spark = SparkSession._getActiveSessionOrCreate()
         sc = spark.sparkContext
 
         wrapped_func = _wrap_function(sc, func, self.returnType)
-        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
+        # This hack exists because of collated strings. E.g., if the return type is
+        # just `StringType("FR")`, `json()` would return just `STRING`, losing the
+        # information about collation. Collation metadata is stored on the nearest
+        # ancestor `StructField`, which is why we wrap the return type in a
+        # `StructType`.
+        wrapped_returnType = StructType([StructField("returnType", self.returnType)])
+        wrapped_jvm_returnType = spark._jsparkSession.parseDataType(wrapped_returnType.json())
+        jdt = wrapped_jvm_returnType.fields()[0].dataType()
         assert sc._jvm is not None
         judf = getattr(sc._jvm, "org.apache.spark.sql.execution.python.UserDefinedPythonFunction")(
             self._name, wrapped_func, jdt, self.evalType, self.deterministic
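
The same wrap/unwrap round-trip through the JVM parser, as a hedged sketch (requires an active SparkSession; spark._jsparkSession.parseDataType and the fields()/dataType() accessors are the internal calls the patch itself relies on, and the printed output is an assumption):

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StringType, StructField, StructType

    spark = SparkSession.builder.getOrCreate()

    return_type = StringType("UNICODE")
    wrapped = StructType([StructField("returnType", return_type)])
    jvm_struct = spark._jsparkSession.parseDataType(wrapped.json())
    # The field's metadata carries the collation across the Python/JVM boundary.
    jdt = jvm_struct.fields()[0].dataType()
    print(jdt.sql())  # expected to print something like: STRING COLLATE UNICODE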

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 5 additions & 1 deletion

@@ -27,7 +27,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
-import org.apache.spark.sql.types.{StructType, UserDefinedType}
+import org.apache.spark.sql.types.{StringType, StructType, UserDefinedType}
 
 /**
  * Grouped a iterator into batches.
@@ -109,6 +109,10 @@ class ArrowEvalPythonEvaluatorFactory(
 
    val outputTypes = output.drop(childOutput.length).map(_.dataType.transformRecursively {
      case udt: UserDefinedType[_] => udt.sqlType
+     // Replace every StringType with the default StringType (the companion object) to
+     // ignore collations. Python doesn't know about collations and will always return
+     // non-collated strings.
+     case _: StringType => StringType
    })
 
    val batchIter = Iterator(iter)
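
For context, the reason the executor must erase collations here: Arrow, which carries data between the JVM and the Python worker in these eval paths, has no collation concept at all. A small illustration (not part of the patch; assumes pyarrow is installed):

    import pyarrow as pa

    # Whatever the Spark-side collation, the worker only ever sees and produces
    # plain Arrow strings, so output type verification has to expect the
    # default StringType.
    arr = pa.array(["hello", "HELLO"])
    print(arr.type)  # string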

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala

Lines changed: 5 additions & 1 deletion

@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
-import org.apache.spark.sql.types.{StructType, UserDefinedType}
+import org.apache.spark.sql.types.{StringType, StructType, UserDefinedType}
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
@@ -63,6 +63,10 @@ case class ArrowEvalPythonUDTFExec(
 
    val outputTypes = resultAttrs.map(_.dataType.transformRecursively {
      case udt: UserDefinedType[_] => udt.sqlType
+     // Replace every StringType with the default StringType (the companion object) to
+     // ignore collations. Python doesn't know about collations and will always return
+     // non-collated strings.
+     case _: StringType => StringType
    })
 
    val columnarBatchIter = new ArrowPythonUDTFRunner(

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ object EvaluatePython {
 
     case (d: Decimal, _) => d.toJavaBigDecimal
 
-    case (s: UTF8String, StringType) => s.toString
+    case (s: UTF8String, _: StringType) => s.toString
 
     case (other, _) => other
   }
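
Note the subtlety this one-line change fixes: the Scala pattern `StringType` matches the companion object, i.e. only the default collation, while the type pattern `_: StringType` matches a string of any collation. The PySpark analogue of that distinction, as a sketch:

    from pyspark.sql.types import StringType

    s = StringType("UTF8_LCASE")
    print(s == StringType())          # False: equality is collation-sensitive
    print(isinstance(s, StringType))  # True: any collation is still a StringType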
