Skip to content
62 changes: 62 additions & 0 deletions python/sedonadb/tests/test_sjoin.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,65 @@ def test_spatial_join(join_type, on):
sedonadb_results = eng_sedonadb.execute_and_collect(sql).to_pandas()
assert len(sedonadb_results) > 0
eng_postgis.assert_query_result(sql, sedonadb_results)


@pytest.mark.parametrize(
    "join_type", ["INNER JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN"]
)
@pytest.mark.parametrize(
    "on",
    [
        "ST_Intersects(sjoin_geog1.geog, sjoin_geog2.geog)",
        "ST_Distance(sjoin_geog1.geog, sjoin_geog2.geog) < 100000",
    ],
)
def test_spatial_join_geography(join_type, on):
    """Check that a geography spatial join in SedonaDB agrees with PostGIS.

    Two random geography tables (points and polygons) are generated by
    SedonaDB, loaded into both engines, joined with the parametrized join
    type and join condition, and the SedonaDB result is compared against
    the PostGIS result for the same SQL.
    """
    eng_sedonadb = SedonaDB.create_or_skip()
    eng_postgis = PostGIS.create_or_skip()

    def random_geography(generator_options):
        # Serialize the generator options and fetch 100 random rows whose
        # geometry has been converted to geography with SRID 4326.
        options = json.dumps(generator_options)
        return eng_sedonadb.execute_and_collect(
            f"SELECT id, ST_SetSRID(ST_GeogFromWKB(ST_AsBinary(geometry)), 4326) geog, dist FROM sd_random_geometry('{options}') LIMIT 100"
        )

    # The two bounding boxes straddle the antimeridian from opposite sides.
    # On a Euclidean plane they are disjoint, so a geometry join would not
    # match anything, whereas a geography join is expected to produce
    # non-empty results.
    west_most_bound = [-190, -10, -170, 10]
    east_most_bound = [170, -10, 190, 10]

    df_point = random_geography(
        {
            "geom_type": "Point",
            "num_parts_range": [2, 10],
            "vertices_per_linestring_range": [2, 10],
            "bounds": west_most_bound,
            "size_range": [0.1, 5],
            "seed": 42,
        }
    )
    df_polygon = random_geography(
        {
            "geom_type": "Polygon",
            "polygon_hole_rate": 0.5,
            "num_parts_range": [2, 10],
            "vertices_per_linestring_range": [2, 10],
            "bounds": east_most_bound,
            "size_range": [0.1, 5],
            "seed": 43,
        }
    )

    # Load identical inputs into both engines (SedonaDB first, then PostGIS).
    for engine in (eng_sedonadb, eng_postgis):
        engine.create_table_arrow("sjoin_geog1", df_point)
        engine.create_table_arrow("sjoin_geog2", df_polygon)

    sql = f"""
SELECT sjoin_geog1.id id0, sjoin_geog2.id id1
FROM sjoin_geog1 {join_type} sjoin_geog2
ON {on}
ORDER BY id0, id1
"""

    sedonadb_results = eng_sedonadb.execute_and_collect(sql).to_pandas()
    eng_postgis.assert_query_result(sql, sedonadb_results)
141 changes: 138 additions & 3 deletions rust/sedona-expr/src/spatial_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use datafusion_physical_expr::{
use geo_traits::Dimensions;
use sedona_common::sedona_internal_err;
use sedona_geometry::{bounding_box::BoundingBox, bounds::wkb_bounds_xy, interval::IntervalTrait};
use sedona_schema::datatypes::SedonaType;
use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};

use crate::{
statistics::GeoStatistics,
Expand Down Expand Up @@ -185,6 +185,9 @@ impl SpatialFilter {
match (&args[0], &args[1]) {
(ArgRef::Col(column), ArgRef::Lit(literal))
| (ArgRef::Lit(literal), ArgRef::Col(column)) => {
if !is_prunable_geospatial_literal(literal) {
return Ok(Some(Self::Unknown));
}
match literal_bounds(literal) {
Ok(literal_bounds) => {
Ok(Some(Self::Intersects(column.clone(), literal_bounds)))
Expand All @@ -204,6 +207,9 @@ impl SpatialFilter {
match (&args[0], &args[1]) {
(ArgRef::Col(column), ArgRef::Lit(literal)) => {
// column within/covered_by literal -> Intersects filter
if !is_prunable_geospatial_literal(literal) {
return Ok(Some(Self::Unknown));
}
match literal_bounds(literal) {
Ok(literal_bounds) => {
Ok(Some(Self::Intersects(column.clone(), literal_bounds)))
Expand All @@ -213,6 +219,9 @@ impl SpatialFilter {
}
(ArgRef::Lit(literal), ArgRef::Col(column)) => {
// literal within/covered_by column -> Covers filter
if !is_prunable_geospatial_literal(literal) {
return Ok(Some(Self::Unknown));
}
match literal_bounds(literal) {
Ok(literal_bounds) => {
Ok(Some(Self::Covers(column.clone(), literal_bounds)))
Expand All @@ -233,6 +242,9 @@ impl SpatialFilter {
(ArgRef::Col(column), ArgRef::Lit(literal)) => {
// column contains/covers literal -> Covers filter
// (column's bbox must fully cover literal's bbox)
if !is_prunable_geospatial_literal(literal) {
return Ok(Some(Self::Unknown));
}
match literal_bounds(literal) {
Ok(literal_bounds) => {
Ok(Some(Self::Covers(column.clone(), literal_bounds)))
Expand All @@ -243,6 +255,9 @@ impl SpatialFilter {
(ArgRef::Lit(literal), ArgRef::Col(column)) => {
// literal contains/covers column -> Intersects filter
// (if literal contains column, they must at least intersect)
if !is_prunable_geospatial_literal(literal) {
return Ok(Some(Self::Unknown));
}
match literal_bounds(literal) {
Ok(literal_bounds) => {
Ok(Some(Self::Intersects(column.clone(), literal_bounds)))
Expand Down Expand Up @@ -284,6 +299,9 @@ impl SpatialFilter {
match (&args[0], &args[1], &args[2]) {
(ArgRef::Col(column), ArgRef::Lit(literal), ArgRef::Lit(distance))
| (ArgRef::Lit(literal), ArgRef::Col(column), ArgRef::Lit(distance)) => {
if !is_prunable_geospatial_literal(literal) {
return Ok(Some(Self::Unknown));
}
match (
literal_bounds(literal),
distance.value().cast_to(&DataType::Float64)?,
Expand Down Expand Up @@ -314,6 +332,19 @@ enum ArgRef<'a> {
Other,
}

/// Decides whether a literal may be used to build a pruning filter.
///
/// The current spatial data pruning implementation does not correctly handle
/// geography data, so only literals whose type is matched by
/// `ArgMatcher::is_geometry()` are considered prunable.
///
/// Returns `false` when the literal's field or its `SedonaType` cannot be
/// resolved, and otherwise the result of the geometry matcher.
fn is_prunable_geospatial_literal(literal: &Literal) -> bool {
    literal
        .return_field(&Schema::empty())
        .ok()
        .and_then(|field| SedonaType::from_storage_field(&field).ok())
        .is_some_and(|sedona_type| ArgMatcher::is_geometry().match_type(&sedona_type))
}

fn literal_bounds(literal: &Literal) -> Result<BoundingBox> {
let literal_field = literal.return_field(&Schema::empty())?;
let sedona_type = SedonaType::from_storage_field(&literal_field)?;
Expand Down Expand Up @@ -348,12 +379,11 @@ fn parse_args(args: &[Arc<dyn PhysicalExpr>]) -> Vec<ArgRef<'_>> {

#[cfg(test)]
mod test {

use arrow_schema::{DataType, Field};
use datafusion_expr::{ScalarUDF, Signature, SimpleScalarUDF, Volatility};
use rstest::rstest;
use sedona_geometry::{bounding_box::BoundingBox, interval::Interval};
use sedona_schema::datatypes::WKB_GEOMETRY;
use sedona_schema::datatypes::{WKB_GEOGRAPHY, WKB_GEOMETRY};
use sedona_testing::create::create_scalar;

use super::*;
Expand Down Expand Up @@ -806,6 +836,111 @@ mod test {
));
}

/// Range predicates with a geography literal must not be turned into a
/// bounding-box pruning filter; they must yield `SpatialFilter::Unknown`.
///
/// Each function is checked with both argument orders, `(column, literal)`
/// and `(literal, column)`, because the filter construction handles the two
/// orders in separate match arms, each with its own geography guard.
#[rstest]
fn range_predicate_involving_geography_should_be_transformed_to_unknown(
    #[values(
        "st_intersects",
        "st_equals",
        "st_touches",
        "st_contains",
        "st_covers",
        "st_within",
        "st_covered_by",
        "st_coveredby"
    )]
    func_name: &str,
) {
    let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geometry", 0));
    let storage_field = WKB_GEOGRAPHY.to_storage_field("", true).unwrap();
    let literal: Arc<dyn PhysicalExpr> = Arc::new(Literal::new_with_metadata(
        create_scalar(Some("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))"), &WKB_GEOGRAPHY),
        Some(storage_field.metadata().into()),
    ));
    let func = Arc::new(create_dummy_spatial_function(func_name, 2));

    // Exercise both operand orders so that every geography guard is covered.
    let argument_orders = [
        ("(column, literal)", vec![column.clone(), literal.clone()]),
        ("(literal, column)", vec![literal.clone(), column.clone()]),
    ];
    for (order, args) in argument_orders {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
            func_name,
            Arc::clone(&func),
            args,
            Arc::new(Field::new("", DataType::Boolean, true)),
        ));
        let predicate = SpatialFilter::try_from_expr(&expr).unwrap();
        assert!(
            matches!(predicate, SpatialFilter::Unknown),
            "Function {func_name} called with {order} involving geography should produce Unknown filter"
        );
    }
}

/// Distance predicates with a geography literal must yield
/// `SpatialFilter::Unknown`, both as `ST_DWithin` calls (either geometry
/// argument order) and as `ST_Distance` compared against a threshold
/// (threshold on either side).
#[test]
fn distance_predicate_involving_geography_should_be_transformed_to_unknown() {
    // Operands shared by every expression built below.
    let geog_field = WKB_GEOGRAPHY.to_storage_field("", true).unwrap();
    let geog_point: Arc<dyn PhysicalExpr> = Arc::new(Literal::new_with_metadata(
        create_scalar(Some("POINT (1 2)"), &WKB_GEOGRAPHY),
        Some(geog_field.metadata().into()),
    ));
    let geom_column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geometry", 0));
    let threshold: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Float64(Some(100.0))));

    // ST_DWithin, first with (column, literal) and then with the geometry
    // arguments reversed.
    let st_dwithin = Arc::new(create_dummy_spatial_function("st_dwithin", 3));
    let dwithin_argument_orders = [
        vec![geom_column.clone(), geog_point.clone(), threshold.clone()],
        vec![geog_point.clone(), geom_column.clone(), threshold.clone()],
    ];
    for args in dwithin_argument_orders {
        let dwithin_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
            "st_dwithin",
            Arc::clone(&st_dwithin),
            args,
            Arc::new(Field::new("", DataType::Boolean, true)),
        ));
        let filter = SpatialFilter::try_from_expr(&dwithin_expr).unwrap();
        assert!(
            matches!(filter, SpatialFilter::Unknown),
            "ST_DWithin involving geography should produce Unknown filter"
        );
    }

    // ST_Distance compared with the threshold, written in both directions.
    let st_distance = Arc::new(create_dummy_spatial_function("st_distance", 2));
    let distance_call: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
        "st_distance",
        st_distance,
        vec![geom_column.clone(), geog_point.clone()],
        Arc::new(Field::new("", DataType::Boolean, true)),
    ));
    let comparisons = [
        (
            BinaryExpr::new(distance_call.clone(), Operator::LtEq, threshold.clone()),
            "ST_Distance <= threshold involving geography should produce Unknown filter",
        ),
        (
            BinaryExpr::new(threshold.clone(), Operator::GtEq, distance_call.clone()),
            "threshold >= ST_Distance involving geography should produce Unknown filter",
        ),
    ];
    for (comparison, message) in comparisons {
        let comparison_expr: Arc<dyn PhysicalExpr> = Arc::new(comparison);
        let filter = SpatialFilter::try_from_expr(&comparison_expr).unwrap();
        assert!(matches!(filter, SpatialFilter::Unknown), "{message}");
    }
}

#[test]
fn predicate_from_expr_has_z() {
let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geometry", 0));
Expand Down
42 changes: 37 additions & 5 deletions rust/sedona-spatial-join/src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ mod tests {
use geo_types::{Coord, Rect};
use rstest::rstest;
use sedona_geometry::types::GeometryTypeId;
use sedona_schema::datatypes::WKB_GEOMETRY;
use sedona_schema::datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY};
use sedona_testing::datagen::RandomPartitionedDataBuilder;
use tokio::sync::OnceCell;

Expand All @@ -649,12 +649,13 @@ mod tests {

/// Creates standard test data with left (Polygon) and right (Point) partitions.
///
/// Delegates to `create_test_data_with_size_range` with a size range of
/// `(1.0, 10.0)` and the `WKB_GEOMETRY` type.
fn create_default_test_data() -> Result<(TestPartitions, TestPartitions)> {
    create_test_data_with_size_range((1.0, 10.0), WKB_GEOMETRY)
}

/// Creates test data with custom size range
fn create_test_data_with_size_range(
size_range: (f64, f64),
sedona_type: SedonaType,
) -> Result<(TestPartitions, TestPartitions)> {
let bounds = Rect::new(Coord { x: 0.0, y: 0.0 }, Coord { x: 100.0, y: 100.0 });

Expand All @@ -664,7 +665,7 @@ mod tests {
.batches_per_partition(2)
.rows_per_batch(30)
.geometry_type(GeometryTypeId::Polygon)
.sedona_type(WKB_GEOMETRY)
.sedona_type(sedona_type.clone())
.bounds(bounds)
.size_range(size_range)
.null_rate(0.1)
Expand All @@ -676,7 +677,7 @@ mod tests {
.batches_per_partition(4)
.rows_per_batch(30)
.geometry_type(GeometryTypeId::Point)
.sedona_type(WKB_GEOMETRY)
.sedona_type(sedona_type)
.bounds(bounds)
.size_range(size_range)
.null_rate(0.1)
Expand Down Expand Up @@ -928,7 +929,7 @@ mod tests {
#[tokio::test]
async fn test_spatial_join_with_filter() -> Result<()> {
let ((left_schema, left_partitions), (right_schema, right_partitions)) =
create_test_data_with_size_range((0.1, 10.0))?;
create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;

for max_batch_size in [10, 30, 100] {
let options = SpatialJoinOptions {
Expand Down Expand Up @@ -996,6 +997,37 @@ mod tests {
Ok(())
}

/// A join whose predicate operates on geography columns must not be planned
/// as a `SpatialJoinExec`; the physical plan is inspected to confirm that
/// none is present.
#[tokio::test]
async fn test_geography_join_is_not_optimized() -> Result<()> {
    let ctx = setup_context(Some(SpatialJoinOptions::default()), 10)?;

    // Build both inputs with the geography type and register them as "L" and "R".
    let ((schema_l, partitions_l), (schema_r, partitions_r)) =
        create_test_data_with_size_range((0.1, 10.0), WKB_GEOGRAPHY)?;
    let provider_l: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(schema_l, partitions_l)?);
    let provider_r: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(schema_r, partitions_r)?);
    for (table_name, provider) in [("L", provider_l), ("R", provider_r)] {
        ctx.register_table(table_name, provider)?;
    }

    // Plan (but do not execute) the geography join.
    let plan = ctx
        .sql("SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry)")
        .await?
        .create_physical_plan()
        .await?;

    // The plan must contain no SpatialJoinExec node.
    assert!(
        collect_spatial_join_exec(&plan)?.is_empty(),
        "Geography joins should not be optimized to SpatialJoinExec"
    );

    Ok(())
}

async fn test_with_join_types(join_type: JoinType) -> Result<RecordBatch> {
let ((left_schema, left_partitions), (right_schema, right_partitions)) =
create_test_data_with_empty_partitions()?;
Expand Down
Loading