Skip to content

Commit 686d844

Browse files
committed
[SPARK-53592][PYTHON] Make @udf support vectorized UDF
<!-- Thanks for sending a pull request! Here are some tips for you: 1. If this is your first time, please read our contributor guidelines: https://spark.apache.org/contributing.html 2. Ensure you have added or run the appropriate tests for your PR: https://spark.apache.org/developer-tools.html 3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][SPARK-XXXX] Your PR title ...'. 4. Be sure to keep the PR description updated to reflect all changes. 5. Please write your PR title to summarize what this PR proposes. 6. If possible, provide a concise example to reproduce the issue for a faster review. 7. If you want to add a new configuration, please read the guideline first for naming configurations in 'core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala'. 8. If you want to add or modify an error type or message, please read the guideline first in 'common/utils/src/main/resources/error/README.md'. --> ### What changes were proposed in this pull request? Make `@udf` support vectorized UDFs. ### Why are the changes needed? To promote vectorized UDFs. ### Does this PR introduce _any_ user-facing change? Yes. `udf` will now try to infer the eval type from the function's type hints. For example, ```python @udf(returnType=LongType()) def pd_add1(ser: pd.Series) -> pd.Series: assert isinstance(ser, pd.Series) return ser + 1 ``` The inferred eval type is `PythonEvalType.SQL_SCALAR_PANDAS_UDF`. ### How was this patch tested? Added unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #52323 from zhengruifeng/unify_udf. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 589141e commit 686d844

File tree

6 files changed

+544
-8
lines changed

6 files changed

+544
-8
lines changed

dev/sparktestsupport/modules.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ def __hash__(self):
587587
"pyspark.sql.tests.test_udf",
588588
"pyspark.sql.tests.test_udf_combinations",
589589
"pyspark.sql.tests.test_udf_profiler",
590+
"pyspark.sql.tests.test_unified_udf",
590591
"pyspark.sql.tests.test_udtf",
591592
"pyspark.sql.tests.test_tvf",
592593
"pyspark.sql.tests.test_utils",
@@ -1107,6 +1108,7 @@ def __hash__(self):
11071108
"pyspark.sql.tests.connect.test_parity_udf",
11081109
"pyspark.sql.tests.connect.test_parity_udf_combinations",
11091110
"pyspark.sql.tests.connect.test_parity_udf_profiler",
1111+
"pyspark.sql.tests.connect.test_parity_unified_udf",
11101112
"pyspark.sql.tests.connect.test_parity_memory_profiler",
11111113
"pyspark.sql.tests.connect.test_parity_udtf",
11121114
"pyspark.sql.tests.connect.test_parity_tvf",

python/pyspark/sql/connect/udf.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,7 @@ def _create_py_udf(
7777
else:
7878
is_arrow_enabled = useArrow
7979

80-
eval_type: int = PythonEvalType.SQL_BATCHED_UDF
81-
8280
if is_arrow_enabled:
83-
eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
8481
try:
8582
require_minimum_pandas_version()
8683
require_minimum_pyarrow_version()
@@ -92,6 +89,25 @@ def _create_py_udf(
9289
RuntimeWarning,
9390
)
9491

92+
eval_type: Optional[int] = None
93+
if useArrow is None:
94+
# If the user doesn't explicitly set useArrow
95+
from pyspark.sql.pandas.typehints import infer_eval_type_from_func
96+
97+
try:
98+
# Try to infer the eval type from type hints
99+
eval_type = infer_eval_type_from_func(f)
100+
except Exception:
101+
warnings.warn("Cannot infer the eval type from type hints. ", UserWarning)
102+
103+
if eval_type is None:
104+
if is_arrow_enabled:
105+
# Arrow optimized Python UDF
106+
eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
107+
else:
108+
# Fallback to Regular Python UDF
109+
eval_type = PythonEvalType.SQL_BATCHED_UDF
110+
95111
return _create_udf(f, returnType, eval_type)
96112

97113

python/pyspark/sql/pandas/typehints.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
# limitations under the License.
1616
#
1717
from inspect import Signature
18-
from typing import Any, Callable, Dict, Optional, Union, TYPE_CHECKING
18+
from typing import Any, Callable, Dict, Optional, Union, TYPE_CHECKING, get_type_hints
19+
from inspect import getfullargspec, signature
1920

2021
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
2122
from pyspark.errors import PySparkNotImplementedError, PySparkValueError
@@ -277,6 +278,29 @@ def infer_eval_type(
277278
return eval_type
278279

279280

281+
def infer_eval_type_from_func(  # type: ignore[no-untyped-def]
    f,
) -> Optional[
    Union[
        "PandasScalarUDFType",
        "PandasScalarIterUDFType",
        "PandasGroupedAggUDFType",
        "ArrowScalarUDFType",
        "ArrowScalarIterUDFType",
        "ArrowGroupedAggUDFType",
    ]
]:
    """
    Infer the UDF eval type of ``f`` from its annotations.

    Parameters
    ----------
    f : function
        The user function to inspect.

    Returns
    -------
    int or None
        ``None`` when ``f`` carries no annotations at all; otherwise whatever
        :func:`infer_eval_type` returns for the signature of ``f``.
    """
    # Nothing to infer from when the function is completely un-annotated.
    if not getfullargspec(f).annotations:
        return None

    try:
        resolved_hints = get_type_hints(f)
    except NameError:
        # An annotation refers to a name that cannot be resolved here;
        # hand an empty mapping to infer_eval_type instead of failing.
        resolved_hints = {}
    return infer_eval_type(signature(f), resolved_hints)
303+
280304
def check_tuple_annotation(
281305
annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = None
282306
) -> bool:
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import unittest
19+
20+
from pyspark.sql.tests.test_unified_udf import UnifiedUDFTestsMixin
21+
from pyspark.testing.connectutils import ReusedConnectTestCase
22+
23+
24+
class UnifiedUDFParityTests(UnifiedUDFTestsMixin, ReusedConnectTestCase):
    """Runs the test cases from ``UnifiedUDFTestsMixin`` on ``ReusedConnectTestCase``."""

    @classmethod
    def setUpClass(cls):
        # Let the Connect base class do its setup first, then switch the
        # Arrow Python UDF conf off on the resulting session.
        ReusedConnectTestCase.setUpClass()
        arrow_udf_conf = "spark.sql.execution.pythonUDF.arrow.enabled"
        cls.spark.conf.set(arrow_udf_conf, "false")
29+
30+
31+
if __name__ == "__main__":
    from pyspark.sql.tests.connect.test_parity_unified_udf import *  # noqa: F401

    # Default to unittest's own runner; use the XML runner only when the
    # optional ``xmlrunner`` package can be imported.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is optional; keep testRunner as None.
        pass
    unittest.main(testRunner=testRunner, verbosity=2)

0 commit comments

Comments
 (0)