Conversation

@Yicong-Huang (Contributor)

What changes were proposed in this pull request?

This PR introduces an iterator API for Arrow grouped aggregation UDFs in PySpark. It adds support for two new UDF patterns:

  • Iterator[pa.Array] -> Any for single column aggregations
  • Iterator[Tuple[pa.Array, ...]] -> Any for multiple column aggregations

The implementation adds a new Python eval type SQL_GROUPED_AGG_ARROW_ITER_UDF with corresponding support in type inference, worker serialization, and Scala execution planning.

Why are the changes needed?

The current Arrow grouped aggregation API requires loading all data for a group into memory at once, which can be problematic for groups with large amounts of data. The iterator API allows processing data in batches, providing:

  1. Memory Efficiency: Processes data incrementally rather than loading the entire group into memory
  2. Consistency: Aligns with existing iterator APIs (e.g., SQL_SCALAR_ARROW_ITER_UDF)
  3. Flexibility: Allows initialization of expensive state once per group while processing batches iteratively (see the sketch below)
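
For example, point 3 means a UDF can build heavyweight state (a model, a connection pool, etc.) once and then consume the group's batches lazily. A minimal sketch, where load_model() and model.score() are hypothetical placeholders for whatever expensive setup the user needs, not PySpark APIs:

import pyarrow as pa
from typing import Iterator
from pyspark.sql.functions import arrow_udf

@arrow_udf("double")
def arrow_scored_sum(it: Iterator[pa.Array]) -> float:
    # Hypothetical expensive initialization, done once per group rather than once per batch.
    model = load_model()  # placeholder for user-defined setup
    total = 0.0
    for batch in it:  # the group's values arrive incrementally as pa.Array batches
        total += float(model.score(batch.to_numpy()).sum())
    return total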

Does this PR introduce any user-facing change?

Yes. This PR adds a new API pattern for Arrow grouped aggregation UDFs:

Single column aggregation:

import pyarrow as pa
from typing import Iterator
from pyspark.sql.functions import arrow_udf

@arrow_udf("double")
def arrow_mean(it: Iterator[pa.Array]) -> float:
    sum_val = 0.0
    cnt = 0
    for v in it:
        sum_val += pa.compute.sum(v).as_py()
        cnt += len(v)
    return sum_val / cnt if cnt > 0 else 0.0

df.groupby("id").agg(arrow_mean(df['v'])).show()

Multiple column aggregation:

import pyarrow as pa
import numpy as np
from typing import Iterator, Tuple
from pyspark.sql.functions import arrow_udf

@arrow_udf("double")
def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
    weighted_sum = 0.0
    weight = 0.0
    for v, w in it:
        weighted_sum += np.dot(v.to_numpy(), w.to_numpy())
        weight += pa.compute.sum(w).as_py()
    return weighted_sum / weight if weight > 0 else 0.0

df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()

How was this patch tested?

Added comprehensive unit tests in python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py:

  1. test_iterator_grouped_agg_single_column() - Tests single column iterator aggregation with Iterator[pa.Array]
  2. test_iterator_grouped_agg_multiple_columns() - Tests multiple column iterator aggregation with Iterator[Tuple[pa.Array, pa.Array]]
  3. test_iterator_grouped_agg_eval_type() - Verifies correct eval type inference from type hints
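
For the third test, a rough sketch of what the eval type check can look like. It assumes the new constant is exposed as PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF, that arrow_udf exposes evalType the way other PySpark UDFs do, and that the import path for PythonEvalType is pyspark.util:

import pyarrow as pa
from typing import Iterator
from pyspark.sql.functions import arrow_udf
from pyspark.util import PythonEvalType

@arrow_udf("long")
def arrow_count(it: Iterator[pa.Array]) -> int:
    # Consume the group's batches incrementally and count rows.
    return sum(len(v) for v in it)

# The Iterator[pa.Array] -> Any type hints should infer the new eval type.
assert arrow_count.evalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF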

Was this patch authored or co-authored using generative AI tooling?

Co-Generated-by: Cursor with Claude Sonnet 4.5

@Yicong-Huang force-pushed the SPARK-53615/feat/arrow-grouped-agg-iterator-api branch 2 times, most recently from c09d7f2 to ade3730 on November 13, 2025 10:14
@Yicong-Huang changed the title from [SPARK-53615][PYTHON] Introduce iterator API for Arrow grouped aggregation UDF to [WIP][SPARK-53615][PYTHON] Introduce iterator API for Arrow grouped aggregation UDF on Nov 13, 2025
@Yicong-Huang force-pushed the SPARK-53615/feat/arrow-grouped-agg-iterator-api branch from ade3730 to fe79337 on November 13, 2025 17:50
@Yicong-Huang force-pushed the SPARK-53615/feat/arrow-grouped-agg-iterator-api branch from fe79337 to 9eec72f on November 13, 2025 18:09
@Yicong-Huang changed the title from [WIP][SPARK-53615][PYTHON] Introduce iterator API for Arrow grouped aggregation UDF to [SPARK-53615][PYTHON] Introduce iterator API for Arrow grouped aggregation UDF on Nov 13, 2025
iteratively, which is more memory-efficient than loading all data at once. The returned
scalar can be a python primitive type, a numpy data type, or a `pyarrow.Scalar` instance.
>>> import pandas as pd
Contributor

pandas is not used?

...
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
>>> df.groupby("id").agg(arrow_mean(df['v'])).show() # doctest: +SKIP
Contributor

do not skip the doctests

>>> df = spark.createDataFrame(
... [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
... ("id", "v", "w"))
>>> df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show() # doctest: +SKIP
Contributor

ditto

Comment on lines 362 to 363
| 1| 1.6666666666666667|
| 2| 7.166666666666667|
Contributor

Suggested change
-| 1| 1.6666666666666667|
-| 2| 7.166666666666667|
+| 1| 1.6666666666666...|
+| 2| 7.166666666666...|

Contributor

I suggest not comparing the exact values, since they may vary due to env/version changes.



# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
Contributor

we should consolidate it with ArrowStreamAggArrowUDFSerializer: make ArrowStreamAggArrowUDFSerializer output the iterator and adjust the wrapper of SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF

Contributor Author

shall we do it as a follow-up?

if is_iterator_array:
    return ArrowUDFType.SCALAR_ITER

# Iterator[Tuple[pa.Array, ...]] -> Any
Contributor

let's move the new inference after pa.Array, ... -> Any

Contributor Author

moved

dataframes_in_group = read_int(stream)

if dataframes_in_group == 1:
    batches = list(ArrowStreamSerializer.load_stream(self, stream))
Contributor

we should not load all batches of a group at once; this new API is designed to process each group incrementally
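
A rough illustration of the incremental shape being asked for (not the actual patch; it assumes the surrounding load_stream generator context and that each group's iterator is fully consumed before the next group is read from the stream):

if dataframes_in_group == 1:
    # Hand back the group's batch iterator lazily instead of materializing it with list().
    yield ArrowStreamSerializer.load_stream(self, stream)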

@zhengruifeng (Contributor)

also cc @Kimahriman, I guess you might also be interested in this PR
