@@ -50,6 +50,8 @@ class ArrowUDFType:
 
     GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
 
+    GROUPED_AGG_ITER = PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF
+
 
 def arrow_udf(f=None, returnType=None, functionType=None):
     """
@@ -301,6 +303,66 @@ def calculate(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
         Therefore, mutating the input arrays is not allowed and will cause incorrect results.
         For the same reason, users should also not rely on the index of the input arrays.
 
+    * Iterator of Arrays to Scalar
+        `Iterator[pyarrow.Array]` -> `Any`
+
+        The function takes an iterator of `pyarrow.Array` and returns a scalar value. This is
+        useful for grouped aggregations where the UDF can process all batches for a group
+        iteratively, which is more memory-efficient than loading all data at once. The returned
+        scalar can be a python primitive type, a numpy data type, or a `pyarrow.Scalar` instance.
+
+        >>> import pyarrow as pa
+        >>> from typing import Iterator
+        >>> @arrow_udf("double")
+        ... def arrow_mean(it: Iterator[pa.Array]) -> float:
+        ...     sum_val = 0.0
+        ...     cnt = 0
+        ...     for v in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         sum_val += pa.compute.sum(v).as_py()
+        ...         cnt += len(v)
+        ...     return sum_val / cnt
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+        >>> df.groupby("id").agg(arrow_mean(df['v'])).show()  # doctest: +SKIP
+        +---+-------------+
+        | id|arrow_mean(v)|
+        +---+-------------+
+        |  1|          1.5|
+        |  2|          6.0|
+        +---+-------------+
+
+    * Iterator of Multiple Arrays to Scalar
+        `Iterator[Tuple[pyarrow.Array, ...]]` -> `Any`
+
+        The function takes an iterator of a tuple of multiple `pyarrow.Array` and returns a
+        scalar value. This is useful for grouped aggregations with multiple input columns.
+
+        >>> from typing import Iterator, Tuple
+        >>> import numpy as np
+        >>> @arrow_udf("double")
+        ... def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
+        ...     weighted_sum = 0.0
+        ...     weight = 0.0
+        ...     for v, w in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         assert isinstance(w, pa.Array)
+        ...         weighted_sum += np.dot(v, w)
+        ...         weight += pa.compute.sum(w).as_py()
+        ...     return weighted_sum / weight
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+        ...     ("id", "v", "w"))
+        >>> df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()  # doctest: +SKIP
+        +---+-------------------------+
+        | id|arrow_weighted_mean(v, w)|
+        +---+-------------------------+
+        |  1|       1.6666666666666667|
+        |  2|        7.166666666666667|
+        +---+-------------------------+
+
     Notes
     -----
     The user-defined functions do not support conditional expressions or short circuiting
@@ -720,6 +782,7 @@ def vectorized_udf(
         PythonEvalType.SQL_SCALAR_ARROW_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
         None,
     ]:  # None means it should infer the type from type hints.
         raise PySparkTypeError(
@@ -768,6 +831,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
         PythonEvalType.SQL_SCALAR_ARROW_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
     ]:
         warnings.warn(
             "It is preferred to specify type hints for "
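For a quick end-to-end picture of what this change enables, here is a minimal usage sketch. It assumes the new iterator form can also be requested explicitly through `functionType` using the `GROUPED_AGG_ITER` constant added above, mirroring how the existing `GROUPED_AGG` constant is used; the import paths for `arrow_udf` and `ArrowUDFType`, the active `spark` session, and the `arrow_max` function are all assumptions for illustration, not confirmed by this diff.

# Sketch only: explicit functionType registration for the new iterator-of-batches
# grouped aggregation. Import paths and the SparkSession `spark` are assumed.
from typing import Iterator

import pyarrow as pa
import pyarrow.compute as pc

from pyspark.sql.functions import arrow_udf              # assumed import path
from pyspark.sql.pandas.functions import ArrowUDFType    # assumed import path


def arrow_max(batches: Iterator[pa.Array]) -> float:
    # Running maximum over the group's record batches; only one batch is
    # materialized at a time, which is the point of the iterator form.
    result = float("-inf")
    for batch in batches:
        result = max(result, pc.max(batch).as_py())
    return result


# Same effect as decorating arrow_max with @arrow_udf("double") and relying on
# the Iterator[pa.Array] type hint, but with the eval type spelled out.
arrow_max_udf = arrow_udf(arrow_max, "double", ArrowUDFType.GROUPED_AGG_ITER)

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
df.groupby("id").agg(arrow_max_udf(df["v"])).show()

With type hints on the function, the explicit `functionType` argument should be unnecessary, as the doctest examples added in the docstring show.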