Commit 9eec72f

feat: introduce iterator API for Arrow grouped agg UDF
1 parent 05b0543 commit 9eec72f

10 files changed, +311 -3 lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 2 additions & 0 deletions
@@ -74,6 +74,7 @@ private[spark] object PythonEvalType {
   val SQL_SCALAR_ARROW_ITER_UDF = 251
   val SQL_GROUPED_AGG_ARROW_UDF = 252
   val SQL_WINDOW_AGG_ARROW_UDF = 253
+  val SQL_GROUPED_AGG_ARROW_ITER_UDF = 254
 
   val SQL_TABLE_UDF = 300
   val SQL_ARROW_TABLE_UDF = 301
@@ -111,6 +112,7 @@ private[spark] object PythonEvalType {
     case SQL_SCALAR_ARROW_ITER_UDF => "SQL_SCALAR_ARROW_ITER_UDF"
     case SQL_GROUPED_AGG_ARROW_UDF => "SQL_GROUPED_AGG_ARROW_UDF"
     case SQL_WINDOW_AGG_ARROW_UDF => "SQL_WINDOW_AGG_ARROW_UDF"
+    case SQL_GROUPED_AGG_ARROW_ITER_UDF => "SQL_GROUPED_AGG_ARROW_ITER_UDF"
   }
 }
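The new constant must stay in lockstep with its Python-side twin in python/pyspark/util.py (changed later in this commit), since the worker protocol identifies UDF kinds by this integer. A minimal sanity check, assuming a PySpark build that contains this commit:

from pyspark.util import PythonEvalType

# 254 is the value introduced by this commit; the JVM side must agree on it.
assert PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF == 252
assert PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF == 254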

python/pyspark/sql/pandas/_typing/__init__.pyi

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ ArrowScalarUDFType = Literal[250]
 ArrowScalarIterUDFType = Literal[251]
 ArrowGroupedAggUDFType = Literal[252]
 ArrowWindowAggUDFType = Literal[253]
+ArrowGroupedAggIterUDFType = Literal[254]
 
 class ArrowVariadicScalarToScalarFunction(Protocol):
     def __call__(self, *_: pyarrow.Array) -> pyarrow.Array: ...
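For reference, Literal[254] makes the alias denote exactly the value 254 to a type checker, so signatures can be narrowed to one eval type. A small illustration; the expects_iter_agg function below is hypothetical, not part of the stub:

from typing import Literal

ArrowGroupedAggIterUDFType = Literal[254]

def expects_iter_agg(eval_type: ArrowGroupedAggIterUDFType) -> None: ...

expects_iter_agg(254)  # accepted: matches the literal
expects_iter_agg(252)  # runs, but flagged by mypy: not Literal[254]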

python/pyspark/sql/pandas/functions.py

Lines changed: 64 additions & 0 deletions
@@ -50,6 +50,8 @@ class ArrowUDFType:
 
     GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
 
+    GROUPED_AGG_ITER = PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF
+
 
 def arrow_udf(f=None, returnType=None, functionType=None):
     """
@@ -301,6 +303,66 @@ def calculate(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
         Therefore, mutating the input arrays is not allowed and will cause incorrect results.
         For the same reason, users should also not rely on the index of the input arrays.
 
+    * Iterator of Arrays to Scalar
+        `Iterator[pyarrow.Array]` -> `Any`
+
+        The function takes an iterator of `pyarrow.Array` and returns a scalar value. This is
+        useful for grouped aggregations where the UDF can process all batches for a group
+        iteratively, which is more memory-efficient than loading all data at once. The
+        returned scalar can be a Python primitive type, a numpy data type, or a
+        `pyarrow.Scalar` instance.
+
+        >>> import pyarrow as pa
+        >>> from typing import Iterator
+        >>> @arrow_udf("double")
+        ... def arrow_mean(it: Iterator[pa.Array]) -> float:
+        ...     sum_val = 0.0
+        ...     cnt = 0
+        ...     for v in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         sum_val += pa.compute.sum(v).as_py()
+        ...         cnt += len(v)
+        ...     return sum_val / cnt
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+        >>> df.groupby("id").agg(arrow_mean(df['v'])).show()  # doctest: +SKIP
+        +---+-------------+
+        | id|arrow_mean(v)|
+        +---+-------------+
+        |  1|          1.5|
+        |  2|          6.0|
+        +---+-------------+
+
+    * Iterator of Multiple Arrays to Scalar
+        `Iterator[Tuple[pyarrow.Array, ...]]` -> `Any`
+
+        The function takes an iterator of tuples of multiple `pyarrow.Array` and returns a
+        scalar value. This is useful for grouped aggregations over multiple input columns.
+
+        >>> from typing import Iterator, Tuple
+        >>> import numpy as np
+        >>> @arrow_udf("double")
+        ... def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
+        ...     weighted_sum = 0.0
+        ...     weight = 0.0
+        ...     for v, w in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         assert isinstance(w, pa.Array)
+        ...         weighted_sum += np.dot(v, w)
+        ...         weight += pa.compute.sum(w).as_py()
+        ...     return weighted_sum / weight
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+        ...     ("id", "v", "w"))
+        >>> df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()  # doctest: +SKIP
+        +---+-------------------------+
+        | id|arrow_weighted_mean(v, w)|
+        +---+-------------------------+
+        |  1|       1.6666666666666667|
+        |  2|        7.166666666666667|
+        +---+-------------------------+
+
     Notes
     -----
     The user-defined functions do not support conditional expressions or short circuiting
@@ -720,6 +782,7 @@ def vectorized_udf(
         PythonEvalType.SQL_SCALAR_ARROW_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
         None,
     ]:  # None means it should infer the type from type hints.
         raise PySparkTypeError(
@@ -768,6 +831,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
         PythonEvalType.SQL_SCALAR_ARROW_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
     ]:
         warnings.warn(
             "It is preferred to specify type hints for "

python/pyspark/sql/pandas/serializers.py

Lines changed: 66 additions & 0 deletions
@@ -1200,6 +1200,72 @@ def __repr__(self):
         return "ArrowStreamAggArrowUDFSerializer"
 
 
+# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
+class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_cast,
+    ):
+        super().__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=False,
+            arrow_cast=True,
+        )
+        self._timezone = timezone
+        self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
+        self._arrow_cast = arrow_cast
+
+    def load_stream(self, stream):
+        """
+        Yield column iterators instead of concatenating batches. Each group yields
+        a structure where indexing by column offset gives an iterator of arrays.
+        """
+        import pyarrow as pa
+
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                batches = list(ArrowStreamSerializer.load_stream(self, stream))
+                # Create a structure that can be indexed by column offset to get
+                # column iterators. The mapper will do a[offset] to get each
+                # column's iterator.
+                if len(batches) > 0:
+                    num_cols = batches[0].num_columns
+
+                    # A custom class that can be indexed to get column iterators
+                    class ColumnIterators:
+                        def __init__(self, batches, num_cols):
+                            self._batches = batches
+                            self._num_cols = num_cols
+
+                        def __getitem__(self, col_idx):
+                            return (batch.column(col_idx) for batch in self._batches)
+
+                        def __len__(self):
+                            return self._num_cols
+
+                    yield ColumnIterators(batches, num_cols)
+                else:
+                    # Empty group
+                    yield []
+
+            elif dataframes_in_group != 0:
+                raise PySparkValueError(
+                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
+                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
+                )
+
+    def __repr__(self):
+        return "ArrowStreamAggArrowIterUDFSerializer"
+
+
 class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
     def __init__(
         self,
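The ColumnIterators trick is what lets the UDF mapper keep the existing a[offset] access pattern while each column arrives lazily, one batch at a time. A self-contained sketch of the same pattern with plain pyarrow (no Spark stream involved):

import pyarrow as pa

batches = [
    pa.RecordBatch.from_pydict({"v": [1.0, 2.0], "w": [10.0, 20.0]}),
    pa.RecordBatch.from_pydict({"v": [3.0], "w": [30.0]}),
]

class ColumnIterators:
    def __init__(self, batches):
        self._batches = batches

    def __getitem__(self, col_idx):
        # A fresh generator per column; batches are consumed lazily.
        return (batch.column(col_idx) for batch in self._batches)

    def __len__(self):
        return self._batches[0].num_columns if self._batches else 0

cols = ColumnIterators(batches)
assert sum(pa.compute.sum(chunk).as_py() for chunk in cols[0]) == 6.0   # column "v"
assert sum(pa.compute.sum(chunk).as_py() for chunk in cols[1]) == 60.0  # column "w"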

python/pyspark/sql/pandas/typehints.py

Lines changed: 47 additions & 1 deletion
@@ -29,6 +29,7 @@
     ArrowScalarUDFType,
     ArrowScalarIterUDFType,
     ArrowGroupedAggUDFType,
+    ArrowGroupedAggIterUDFType,
     ArrowGroupedMapIterUDFType,
     ArrowGroupedMapUDFType,
     ArrowGroupedMapFunction,
@@ -156,7 +157,14 @@ def infer_pandas_eval_type(
 
 def infer_arrow_eval_type(
     sig: Signature, type_hints: Dict[str, Any]
-) -> Optional[Union["ArrowScalarUDFType", "ArrowScalarIterUDFType", "ArrowGroupedAggUDFType"]]:
+) -> Optional[
+    Union[
+        "ArrowScalarUDFType",
+        "ArrowScalarIterUDFType",
+        "ArrowGroupedAggUDFType",
+        "ArrowGroupedAggIterUDFType",
+    ]
+]:
     """
     Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
     :class:`inspect.Signature` instance and type hints.
@@ -226,6 +234,41 @@ def infer_arrow_eval_type(
     if is_iterator_array:
         return ArrowUDFType.SCALAR_ITER
 
+    # Iterator[Tuple[pa.Array, ...]] -> Any
+    is_iterator_tuple_array_agg = (
+        len(parameters_sig) == 1
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[0],
+            parameter_check_func=lambda a: check_tuple_annotation(  # Tuple
+                a,
+                parameter_check_func=lambda ta: (ta == Ellipsis or ta == pa.Array),
+            ),
+        )
+        and (
+            return_annotation != pa.Array
+            and not check_iterator_annotation(return_annotation)
+            and not check_tuple_annotation(return_annotation)
+        )
+    )
+    if is_iterator_tuple_array_agg:
+        return ArrowUDFType.GROUPED_AGG_ITER
+
+    # Iterator[pa.Array] -> Any
+    is_iterator_array_agg = (
+        len(parameters_sig) == 1
+        and check_iterator_annotation(
+            parameters_sig[0],
+            parameter_check_func=lambda a: a == pa.Array,
+        )
+        and (
+            return_annotation != pa.Array
+            and not check_iterator_annotation(return_annotation)
+            and not check_tuple_annotation(return_annotation)
+        )
+    )
+    if is_iterator_array_agg:
+        return ArrowUDFType.GROUPED_AGG_ITER
+
     # pa.Array, ... -> Any
     is_array_agg = all(a == pa.Array for a in parameters_sig) and (
         return_annotation != pa.Array
@@ -249,6 +292,7 @@ def infer_eval_type(
     "ArrowScalarUDFType",
     "ArrowScalarIterUDFType",
     "ArrowGroupedAggUDFType",
+    "ArrowGroupedAggIterUDFType",
 ]:
     """
     Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
@@ -264,6 +308,7 @@ def infer_eval_type(
             "ArrowScalarUDFType",
             "ArrowScalarIterUDFType",
             "ArrowGroupedAggUDFType",
+            "ArrowGroupedAggIterUDFType",
         ]
     ] = None
     if kind == "pandas":
@@ -295,6 +340,7 @@ def infer_eval_type_for_udf(  # type: ignore[no-untyped-def]
         "ArrowScalarUDFType",
         "ArrowScalarIterUDFType",
        "ArrowGroupedAggUDFType",
+        "ArrowGroupedAggIterUDFType",
     ]
 ]:
     argspec = getfullargspec(f)
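The two new branches fire only when the sole parameter is an iterator (of arrays, or of tuples of arrays) and the return annotation is neither an array, an iterator, nor a tuple. A quick way to see the inference in action; a sketch assuming a build with this commit:

import inspect
from typing import Iterator, get_type_hints

import pyarrow as pa
from pyspark.sql.pandas.typehints import infer_arrow_eval_type

def agg(it: Iterator[pa.Array]) -> float:
    ...

# The Iterator[pa.Array] -> Any branch returns ArrowUDFType.GROUPED_AGG_ITER,
# i.e. PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF == 254.
assert infer_arrow_eval_type(inspect.signature(agg), get_type_hints(agg)) == 254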

python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py

Lines changed: 84 additions & 0 deletions
@@ -1059,6 +1059,90 @@ def my_grouped_agg_arrow_udf(x):
             ],
         )
 
+    def test_iterator_grouped_agg_single_column(self):
+        """
+        Test the iterator API for grouped aggregation with a single column.
+        """
+        import pyarrow as pa
+        from typing import Iterator
+
+        @arrow_udf("double")
+        def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
+            sum_val = 0.0
+            cnt = 0
+            for v in it:
+                assert isinstance(v, pa.Array)
+                sum_val += pa.compute.sum(v).as_py()
+                cnt += len(v)
+            return sum_val / cnt if cnt > 0 else 0.0
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        result = df.groupby("id").agg(arrow_mean_iter(df["v"]).alias("mean")).sort("id")
+        expected = df.groupby("id").agg(sf.mean(df["v"]).alias("mean")).sort("id").collect()
+
+        self.assertEqual(expected, result.collect())
+
+    @unittest.skipIf(not have_numpy, numpy_requirement_message)
+    def test_iterator_grouped_agg_multiple_columns(self):
+        """
+        Test the iterator API for grouped aggregation with multiple columns.
+        """
+        import pyarrow as pa
+        import numpy as np
+        from typing import Iterator, Tuple
+
+        @arrow_udf("double")
+        def arrow_weighted_mean_iter(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
+            weighted_sum = 0.0
+            weight = 0.0
+            for v, w in it:
+                assert isinstance(v, pa.Array)
+                assert isinstance(w, pa.Array)
+                weighted_sum += np.dot(v, w)
+                weight += pa.compute.sum(w).as_py()
+            return weighted_sum / weight if weight > 0 else 0.0
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+            ("id", "v", "w"),
+        )
+
+        result = (
+            df.groupby("id")
+            .agg(arrow_weighted_mean_iter(df["v"], df["w"]).alias("wm"))
+            .sort("id")
+            .collect()
+        )
+
+        # Expected weighted means:
+        # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0 = 1.6666666666666667
+        # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0)
+        #   = 43.0 / 6.0 = 7.166666666666667
+        expected = [(1, 5.0 / 3.0), (2, 43.0 / 6.0)]
+
+        self.assertEqual(len(result), len(expected))
+        for r, (exp_id, exp_wm) in zip(result, expected):
+            self.assertEqual(r["id"], exp_id)
+            self.assertAlmostEqual(r["wm"], exp_wm, places=5)
+
+    def test_iterator_grouped_agg_eval_type(self):
+        """
+        Test that the eval type is correctly inferred for iterator grouped agg UDFs.
+        """
+        import pyarrow as pa
+        from typing import Iterator
+
+        @arrow_udf("double")
+        def arrow_sum_iter(it: Iterator[pa.Array]) -> float:
+            total = 0.0
+            for v in it:
+                total += pa.compute.sum(v).as_py()
+            return total
+
+        self.assertEqual(arrow_sum_iter.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF)
+
 
 class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
     pass

python/pyspark/util.py

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,7 @@
         ArrowScalarUDFType,
         ArrowScalarIterUDFType,
         ArrowGroupedAggUDFType,
+        ArrowGroupedAggIterUDFType,
         ArrowWindowAggUDFType,
     )
     from pyspark.sql._typing import (
@@ -660,6 +661,7 @@ class PythonEvalType:
     SQL_SCALAR_ARROW_ITER_UDF: "ArrowScalarIterUDFType" = 251
     SQL_GROUPED_AGG_ARROW_UDF: "ArrowGroupedAggUDFType" = 252
     SQL_WINDOW_AGG_ARROW_UDF: "ArrowWindowAggUDFType" = 253
+    SQL_GROUPED_AGG_ARROW_ITER_UDF: "ArrowGroupedAggIterUDFType" = 254
 
     SQL_TABLE_UDF: "SQLTableUDFType" = 300
     SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301
