[SPARK-53615][PYTHON] Introduce iterator API for Arrow grouped aggregation UDF #53035
base: master
Conversation
iteratively, which is more memory-efficient than loading all data at once. The returned
scalar can be a python primitive type, a numpy data type, or a `pyarrow.Scalar` instance.
>>> import pandas as pd
pandas is not used?
...
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
>>> df.groupby("id").agg(arrow_mean(df['v'])).show()  # doctest: +SKIP
do not skip the doctests
>>> df = spark.createDataFrame(
...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
...     ("id", "v", "w"))
>>> df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()  # doctest: +SKIP
ditto
| 1| 1.6666666666666667|
| 2| 7.166666666666667|
Suggested change:
- | 1| 1.6666666666666667|
- | 2| 7.166666666666667|
+ | 1| 1.6666666666666...|
+ | 2| 7.166666666666...|
I suggest not comparing the exact values since they may vary due to env/version changes.
# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
we should consolidate it with ArrowStreamAggArrowUDFSerializer: make ArrowStreamAggArrowUDFSerializer output the iterator and adjust the wrapper of SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF
shall we do it as a follow-up?
if is_iterator_array:
    return ArrowUDFType.SCALAR_ITER
# Iterator[Tuple[pa.Array, ...]] -> Any
let's move the new inference after pa.Array, ... -> Any
moved
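For context on what the new inference has to distinguish, here is a minimal, self-contained sketch (the helper name is ours, not Spark's actual inference code) of how `Iterator[pa.Array]` and `Iterator[Tuple[pa.Array, ...]]` hints can be told apart from the plain `pa.Array` variants with standard typing introspection:

```python
import collections.abc
from typing import Iterator, Tuple, get_args, get_origin

import pyarrow as pa


def is_iterator_of_arrow_arrays(hint) -> bool:
    """True for Iterator[pa.Array] and Iterator[Tuple[pa.Array, ...]] hints."""
    if get_origin(hint) is not collections.abc.Iterator:
        return False
    (inner,) = get_args(hint)
    if inner is pa.Array:
        return True
    # Multi-column form: Iterator[Tuple[pa.Array, ...]]
    return get_origin(inner) is tuple and all(
        a is pa.Array or a is Ellipsis for a in get_args(inner)
    )


print(is_iterator_of_arrow_arrays(Iterator[pa.Array]))              # True
print(is_iterator_of_arrow_arrays(Iterator[Tuple[pa.Array, ...]]))  # True
print(is_iterator_of_arrow_arrays(pa.Array))                        # False
```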
dataframes_in_group = read_int(stream)

if dataframes_in_group == 1:
    batches = list(ArrowStreamSerializer.load_stream(self, stream))
we should not load all batches in a group; this new API is designed to process each group incrementally
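To make the incremental intent concrete, here is a small self-contained illustration in plain pyarrow (no Spark worker plumbing; the function name is ours): the aggregation consumes one batch at a time and keeps only running state, so the whole group never has to sit in memory at once.

```python
from typing import Iterator

import pyarrow as pa
import pyarrow.compute as pc


def running_mean(batches: Iterator[pa.Array]) -> float:
    """Aggregate one Arrow batch at a time, keeping only a running sum/count."""
    total, count = 0.0, 0
    for arr in batches:
        total += pc.sum(arr).as_py() or 0.0
        count += len(arr) - arr.null_count
    return total / count if count else float("nan")


# Simulate one group's values arriving as several small batches.
group = iter([pa.array([1.0, 2.0]), pa.array([3.0]), pa.array([5.0, 10.0])])
print(running_mean(group))  # 4.2
```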
also cc @Kimahriman, I guess you might also be interested in this PR
What changes were proposed in this pull request?
This PR introduces an iterator API for Arrow grouped aggregation UDFs in PySpark. It adds support for two new UDF patterns:
- Iterator[pa.Array] -> Any for single column aggregations
- Iterator[Tuple[pa.Array, ...]] -> Any for multiple column aggregations

The implementation adds a new Python eval type SQL_GROUPED_AGG_ARROW_ITER_UDF with corresponding support in type inference, worker serialization, and Scala execution planning.

Why are the changes needed?
The current Arrow grouped aggregation API requires loading all data for a group into memory at once, which can be problematic for groups with large amounts of data. The iterator API allows processing data in batches, providing:

- better memory efficiency for groups with large amounts of data
- consistency with the existing scalar Arrow iterator UDF API (SQL_SCALAR_ARROW_ITER_UDF)

Does this PR introduce any user-facing change?
Yes. This PR adds a new API pattern for Arrow grouped aggregation UDFs:
Single column aggregation:
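A sketch of what the single-column pattern could look like, reusing the sample data from the doctests in the diff. It assumes the existing arrow_udf decorator from pyspark.sql.functions and an active SparkSession named spark; treat it as illustrative rather than the final documented example.

```python
from typing import Iterator

import pyarrow as pa
import pyarrow.compute as pc

from pyspark.sql.functions import arrow_udf  # assumed import path


@arrow_udf("double")
def arrow_mean(it: Iterator[pa.Array]) -> float:
    # Batches of the group's "v" column arrive one pa.Array at a time.
    total, count = 0.0, 0
    for v in it:
        total += pc.sum(v).as_py() or 0.0
        count += len(v) - v.null_count
    return total / count


df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
df.groupby("id").agg(arrow_mean(df["v"])).show()
```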
Multiple column aggregation:
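Similarly, a sketch of the multi-column pattern with the weighted-mean data from the diff, under the same assumptions (arrow_udf decorator, active spark session):

```python
from typing import Iterator, Tuple

import pyarrow as pa
import pyarrow.compute as pc

from pyspark.sql.functions import arrow_udf  # assumed import path


@arrow_udf("double")
def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
    # Each item is a tuple of aligned batches, one pa.Array per input column.
    weighted_sum, total_weight = 0.0, 0.0
    for v, w in it:
        weighted_sum += pc.sum(pc.multiply(v, w)).as_py() or 0.0
        total_weight += pc.sum(w).as_py() or 0.0
    return weighted_sum / total_weight


df = spark.createDataFrame(
    [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
    ("id", "v", "w"))
df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()
```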
How was this patch tested?
Added comprehensive unit tests in python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py:

- test_iterator_grouped_agg_single_column() - tests single column iterator aggregation with Iterator[pa.Array]
- test_iterator_grouped_agg_multiple_columns() - tests multiple column iterator aggregation with Iterator[Tuple[pa.Array, pa.Array]]
- test_iterator_grouped_agg_eval_type() - verifies correct eval type inference from type hints

Was this patch authored or co-authored using generative AI tooling?
Co-Generated-by: Cursor with Claude Sonnet 4.5