diff --git a/rust/sedona-spatial-join/src/optimizer.rs b/rust/sedona-spatial-join/src/optimizer.rs index 6440da264..0a0ea5486 100644 --- a/rust/sedona-spatial-join/src/optimizer.rs +++ b/rust/sedona-spatial-join/src/optimizer.rs @@ -147,7 +147,7 @@ impl SpatialJoinOptimizer { &spatial_predicate, &left.schema(), &right.schema(), - ) { + )? { return Ok(None); } @@ -190,7 +190,7 @@ impl SpatialJoinOptimizer { &spatial_predicate, &hash_join.left().schema(), &hash_join.right().schema(), - ) { + )? { return Ok(None); } @@ -877,26 +877,40 @@ fn is_spatial_predicate_supported( spatial_predicate: &SpatialPredicate, left_schema: &Schema, right_schema: &Schema, -) -> bool { +) -> Result<bool> { /// Only spatial predicates working with planar geometry are supported for optimization. /// Geography (spherical) types are explicitly excluded and will not trigger optimized spatial joins. - fn is_geometry_type_supported(expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool { - let Ok(left_return_field) = expr.return_field(schema) else { - return false; - }; - let Ok(sedona_type) = SedonaType::from_storage_field(&left_return_field) else { - return false; - }; + fn is_geometry_type_supported(expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<bool> { + let left_return_field = expr.return_field(schema)?; + let sedona_type = SedonaType::from_storage_field(&left_return_field)?; let matcher = ArgMatcher::is_geometry(); - matcher.match_type(&sedona_type) + Ok(matcher.match_type(&sedona_type)) } match spatial_predicate { SpatialPredicate::Relation(RelationPredicate { left, right, .. }) - | SpatialPredicate::Distance(DistancePredicate { left, right, .. }) - | SpatialPredicate::KNearestNeighbors(KNNPredicate { left, right, .. }) => { - is_geometry_type_supported(left, left_schema) - && is_geometry_type_supported(right, right_schema) + | SpatialPredicate::Distance(DistancePredicate { left, right, .. }) => { + Ok(is_geometry_type_supported(left, left_schema)? + && is_geometry_type_supported(right, right_schema)?)
+ } + SpatialPredicate::KNearestNeighbors(KNNPredicate { + left, + right, + probe_side, + .. + }) => { + let (left, right) = match probe_side { + JoinSide::Left => (left, right), + JoinSide::Right => (right, left), + _ => { + return sedona_internal_err!( + "Invalid probe side in KNN predicate: {:?}", + probe_side + ) + } + }; + Ok(is_geometry_type_supported(left, left_schema)? + && is_geometry_type_supported(right, right_schema)?) } } } @@ -2479,11 +2493,7 @@ mod tests { SpatialRelationType::Intersects, ); let spatial_pred = SpatialPredicate::Relation(rel_pred); - assert!(super::is_spatial_predicate_supported( - &spatial_pred, - &schema, - &schema - )); + assert!(super::is_spatial_predicate_supported(&spatial_pred, &schema, &schema).unwrap()); // Geography field (should NOT be supported) let geog_field = WKB_GEOGRAPHY.to_storage_field("geog", false).unwrap(); @@ -2499,6 +2509,65 @@ mod tests { &spatial_pred_geog, &geog_schema, &geog_schema + ) + .unwrap()); + } + + #[test] + fn test_is_knn_predicate_supported() { + // ST_KNN(left, right) + let left_schema = Arc::new(Schema::new(vec![WKB_GEOMETRY + .to_storage_field("geom", false) + .unwrap()])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + WKB_GEOMETRY.to_storage_field("geom", false).unwrap(), + ])); + let left_col_expr = Arc::new(Column::new("geom", 0)) as Arc<dyn PhysicalExpr>; + let right_col_expr = Arc::new(Column::new("geom", 1)) as Arc<dyn PhysicalExpr>; + let knn_pred = SpatialPredicate::KNearestNeighbors(KNNPredicate::new( + left_col_expr.clone(), + right_col_expr.clone(), + 5, + false, + JoinSide::Left, + )); + assert!( + super::is_spatial_predicate_supported(&knn_pred, &left_schema, &right_schema).unwrap()
+ ); + + // ST_KNN with geography (should NOT be supported) + let left_geog_schema = Arc::new(Schema::new(vec![WKB_GEOGRAPHY + .to_storage_field("geog", false) + .unwrap()])); + assert!(!super::is_spatial_predicate_supported( + &knn_pred, + &left_geog_schema, + &right_schema + ) + .unwrap()); + + let right_geog_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + WKB_GEOGRAPHY.to_storage_field("geog", false).unwrap(), + ])); + assert!(!super::is_spatial_predicate_supported( + &knn_pred, + &left_schema, + &right_geog_schema + ) + .unwrap()); } }