Skip to content
114 changes: 113 additions & 1 deletion python/sedonadb/tests/functions/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,126 @@
# under the License.

import pytest
import shapely
from sedonadb.testing import PostGIS, SedonaDB


# Aggregate functions don't have a suffix in PostGIS
def agg_fn_suffix(eng):
"""Return the appropriate suffix for the aggregate function for the given engine."""
return "" if isinstance(eng, PostGIS) else "_Agg"


# ST_Envelope is not an aggregate function in PostGIS but we can check
# behaviour using ST_Envelope(ST_Collect(...))
def call_st_envelope_agg(eng, arg):
if isinstance(eng, PostGIS):
return f"ST_Envelope(ST_Collect({arg}))"
else:
return f"ST_Envelope_Agg({arg})"


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
def test_st_envelope_agg_points(eng):
eng = eng.create_or_skip()

eng.assert_query_result(
f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
VALUES
('POINT (1 2)'),
('POINT (3 4)'),
(NULL)
) AS t(geom)""",
"POLYGON ((1 2, 1 4, 3 4, 3 2, 1 2))",
)


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
def test_st_envelope_agg_all_null(eng):
eng = eng.create_or_skip()

eng.assert_query_result(
f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
VALUES
(NULL),
(NULL),
(NULL)
) AS t(geom)""",
None,
)


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
def test_st_envelope_agg_zero_input(eng):
eng = eng.create_or_skip()

eng.assert_query_result(
f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} AS empty FROM (
VALUES
('POINT (1 2)')
) AS t(geom) WHERE false""",
None,
)


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
def test_st_envelope_agg_single_point(eng):
eng = eng.create_or_skip()

eng.assert_query_result(
f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
VALUES ('POINT (5 5)')
) AS t(geom)""",
"POINT (5 5)",
)


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
def test_st_envelope_agg_collinear_points(eng):
eng = eng.create_or_skip()

eng.assert_query_result(
f"""SELECT {call_st_envelope_agg(eng, "ST_GeomFromText(geom)")} FROM (
VALUES
('POINT (0 0)'),
('POINT (0 1)'),
('POINT (0 2)')
) AS t(geom)""",
"LINESTRING (0 0, 0 2)",
)


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
def test_st_envelope_agg_many_groups(eng, con):
eng = eng.create_or_skip()
num_groups = 1000

df_points = con.sql("""
SELECT id, geometry FROM sd_random_geometry('{"target_rows": 100000, "seed": 9728}')
""")
eng.create_table_arrow("df_points", df_points.to_arrow_table())

result = eng.execute_and_collect(
f"""
SELECT
(id % {num_groups})::INTEGER AS id_mod,
{call_st_envelope_agg(eng, "geometry")} AS envelope
FROM df_points
GROUP BY id_mod
ORDER BY id_mod
""",
)

df_points_geopandas = df_points.to_pandas()
expected = (
df_points_geopandas.groupby(df_points_geopandas["id"] % num_groups)["geometry"]
.apply(lambda group: shapely.box(*group.total_bounds))
.reset_index(name="envelope")
.rename(columns={"id": "id_mod"})
)

eng.assert_result(result, expected)


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
def test_st_collect_points(eng):
eng = eng.create_or_skip()
Expand Down
61 changes: 51 additions & 10 deletions rust/sedona-expr/src/aggregate_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use arrow_schema::{DataType, FieldRef};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, Volatility,
};
use sedona_common::sedona_internal_err;
use sedona_schema::datatypes::SedonaType;
Expand Down Expand Up @@ -102,6 +102,18 @@ impl SedonaAggregateUDF {
&self.kernels
}

fn accumulator_arg_types(args: &AccumulatorArgs) -> Result<Vec<SedonaType>> {
let arg_fields = args
.exprs
.iter()
.map(|expr| expr.return_field(args.schema))
.collect::<Result<Vec<_>>>()?;
arg_fields
.iter()
.map(|field| SedonaType::from_storage_field(field))
.collect()
}

fn dispatch_impl(&self, args: &[SedonaType]) -> Result<(&dyn SedonaAccumulator, SedonaType)> {
// Resolve kernels in reverse so that more recently added ones are resolved first
for kernel in self.kernels.iter().rev() {
Expand Down Expand Up @@ -154,16 +166,27 @@ impl AggregateUDFImpl for SedonaAggregateUDF {
sedona_internal_err!("return_type() should not be called (use return_field())")
}

fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
if let Ok(arg_types) = Self::accumulator_arg_types(&args) {
if let Ok((accumulator, _)) = self.dispatch_impl(&arg_types) {
return accumulator.groups_accumulator_supported(&arg_types);
}
}

false
}

fn create_groups_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
let arg_types = Self::accumulator_arg_types(&args)?;
let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
accumulator.groups_accumulator(&arg_types, &output_type)
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let arg_fields = acc_args
.exprs
.iter()
.map(|expr| expr.return_field(acc_args.schema))
.collect::<Result<Vec<_>>>()?;
let arg_types = arg_fields
.iter()
.map(|field| SedonaType::from_storage_field(field))
.collect::<Result<Vec<_>>>()?;
let arg_types = Self::accumulator_arg_types(&acc_args)?;
let (accumulator, output_type) = self.dispatch_impl(&arg_types)?;
accumulator.accumulator(&arg_types, &output_type)
}
Expand All @@ -190,6 +213,24 @@ pub trait SedonaAccumulator: Debug {
output_type: &SedonaType,
) -> Result<Box<dyn Accumulator>>;

/// Given input data types, check if this implementation supports GroupsAccumulator
fn groups_accumulator_supported(&self, _args: &[SedonaType]) -> bool {
false
}

/// Given input data types, resolve a [GroupsAccumulator]
///
/// A GroupsAccumulator is an important optimization for aggregating many small groups,
/// particularly when such an aggregation is cheap. See the DataFusion documentation
/// for details.
fn groups_accumulator(
&self,
_args: &[SedonaType],
_output_type: &SedonaType,
) -> Result<Box<dyn GroupsAccumulator>> {
sedona_internal_err!("groups_accumulator not supported for {self:?}")
}

/// The fields representing the underlying serialized state of the Accumulator
fn state_fields(&self, args: &[SedonaType]) -> Result<Vec<FieldRef>>;
}
Expand Down
Loading