@@ -171,6 +171,7 @@ impl SpatialJoinExec {
171171 let cache = Self :: compute_properties (
172172 & left,
173173 & right,
174+ & on,
174175 Arc :: clone ( & join_schema) ,
175176 * join_type,
176177 projection. as_ref ( ) ,
@@ -236,9 +237,11 @@ impl SpatialJoinExec {
236237 ///
237238 /// When converted from HashJoin, we preserve HashJoin's equivalence properties by extracting
238239 /// equality conditions from the filter.
240+ #[ allow( clippy:: too_many_arguments) ]
239241 fn compute_properties (
240242 left : & Arc < dyn ExecutionPlan > ,
241243 right : & Arc < dyn ExecutionPlan > ,
244+ on : & SpatialPredicate ,
242245 schema : SchemaRef ,
243246 join_type : JoinType ,
244247 projection : Option < & Vec < usize > > ,
@@ -265,7 +268,13 @@ impl SpatialJoinExec {
265268
266269 // Use symmetric partitioning (like HashJoin) when converted from HashJoin
267270 // Otherwise use asymmetric partitioning (like NestedLoopJoin)
268- let mut output_partitioning = if converted_from_hash_join {
271+ let mut output_partitioning = if let SpatialPredicate :: KNearestNeighbors ( knn) = on {
272+ match knn. probe_side {
273+ JoinSide :: Left => left. output_partitioning ( ) . clone ( ) ,
274+ JoinSide :: Right => right. output_partitioning ( ) . clone ( ) ,
275+ _ => asymmetric_join_output_partitioning ( left, right, & join_type) ,
276+ }
277+ } else if converted_from_hash_join {
269278 // Replicate HashJoin's symmetric partitioning logic
270279 // HashJoin preserves partitioning from both sides for inner joins
271280 // and from one side for outer joins
@@ -467,7 +476,6 @@ impl ExecutionPlan for SpatialJoinExec {
467476 } ) ?
468477 } ;
469478
470- // Column indices for regular joins - no swapping needed
471479 let column_indices_after_projection = match & self . projection {
472480 Some ( projection) => projection
473481 . iter ( )
@@ -559,30 +567,14 @@ impl SpatialJoinExec {
559567 } ) ?
560568 } ;
561569
562- // Handle column indices for KNN - need to swap if we swapped execution plans
563- let mut column_indices_after_projection = match & self . projection {
570+ let column_indices_after_projection = match & self . projection {
564571 Some ( projection) => projection
565572 . iter ( )
566573 . map ( |i| self . column_indices [ * i] . clone ( ) )
567574 . collect ( ) ,
568575 None => self . column_indices . clone ( ) ,
569576 } ;
570577
571- // If we swapped execution plans for KNN, we need to swap the column indices too
572- if !actual_probe_plan_is_left {
573- for col_idx in & mut column_indices_after_projection {
574- match col_idx. side {
575- datafusion_common:: JoinSide :: Left => {
576- col_idx. side = datafusion_common:: JoinSide :: Right
577- }
578- datafusion_common:: JoinSide :: Right => {
579- col_idx. side = datafusion_common:: JoinSide :: Left
580- }
581- datafusion_common:: JoinSide :: None => { } // No change needed
582- }
583- }
584- }
585-
586578 let join_metrics = SpatialJoinProbeMetrics :: new ( partition, & self . metrics ) ;
587579 let probe_stream = probe_plan. execute ( partition, Arc :: clone ( & context) ) ?;
588580
@@ -614,20 +606,23 @@ impl SpatialJoinExec {
614606
615607#[ cfg( test) ]
616608mod tests {
617- use arrow_array:: RecordBatch ;
609+ use arrow_array:: { Array , RecordBatch } ;
618610 use arrow_schema:: { DataType , Field , Schema } ;
619611 use datafusion:: {
620612 catalog:: { MemTable , TableProvider } ,
621613 execution:: SessionStateBuilder ,
622614 prelude:: { SessionConfig , SessionContext } ,
623615 } ;
624616 use datafusion_common:: tree_node:: { TreeNode , TreeNodeRecursion } ;
617+ use geo:: { Distance , Euclidean } ;
625618 use geo_types:: { Coord , Rect } ;
626619 use rstest:: rstest;
620+ use sedona_geo:: to_geo:: item_to_geometry;
627621 use sedona_geometry:: types:: GeometryTypeId ;
628622 use sedona_schema:: datatypes:: { SedonaType , WKB_GEOGRAPHY , WKB_GEOMETRY } ;
629623 use sedona_testing:: datagen:: RandomPartitionedDataBuilder ;
630624 use tokio:: sync:: OnceCell ;
625+ use wkb:: reader:: read_wkb;
631626
632627 use crate :: register_spatial_join_optimizer;
633628 use sedona_common:: {
@@ -691,6 +686,40 @@ mod tests {
691686 Ok ( ( left_data, right_data) )
692687 }
693688
689+ /// Creates test data for KNN join (Point-Point)
690+ fn create_knn_test_data (
691+ size_range : ( f64 , f64 ) ,
692+ sedona_type : SedonaType ,
693+ ) -> Result < ( TestPartitions , TestPartitions ) > {
694+ let bounds = Rect :: new ( Coord { x : 0.0 , y : 0.0 } , Coord { x : 100.0 , y : 100.0 } ) ;
695+
696+ let left_data = RandomPartitionedDataBuilder :: new ( )
697+ . seed ( 1 )
698+ . num_partitions ( 2 )
699+ . batches_per_partition ( 2 )
700+ . rows_per_batch ( 30 )
701+ . geometry_type ( GeometryTypeId :: Point )
702+ . sedona_type ( sedona_type. clone ( ) )
703+ . bounds ( bounds)
704+ . size_range ( size_range)
705+ . null_rate ( 0.1 )
706+ . build ( ) ?;
707+
708+ let right_data = RandomPartitionedDataBuilder :: new ( )
709+ . seed ( 2 )
710+ . num_partitions ( 4 )
711+ . batches_per_partition ( 4 )
712+ . rows_per_batch ( 30 )
713+ . geometry_type ( GeometryTypeId :: Point )
714+ . sedona_type ( sedona_type)
715+ . bounds ( bounds)
716+ . size_range ( size_range)
717+ . null_rate ( 0.1 )
718+ . build ( ) ?;
719+
720+ Ok ( ( left_data, right_data) )
721+ }
722+
694723 fn setup_context (
695724 options : Option < SpatialJoinOptions > ,
696725 batch_size : usize ,
@@ -1173,4 +1202,157 @@ mod tests {
11731202 } ) ?;
11741203 Ok ( spatial_join_execs)
11751204 }
1205+
1206+ fn extract_geoms_and_ids ( partitions : & [ Vec < RecordBatch > ] ) -> Vec < ( i32 , geo:: Geometry < f64 > ) > {
1207+ let mut result = Vec :: new ( ) ;
1208+ for partition in partitions {
1209+ for batch in partition {
1210+ let id_idx = batch. schema ( ) . index_of ( "id" ) . expect ( "Id column not found" ) ;
1211+ let ids = batch
1212+ . column ( id_idx)
1213+ . as_any ( )
1214+ . downcast_ref :: < arrow_array:: Int32Array > ( )
1215+ . expect ( "Column 'id' should be Int32" ) ;
1216+
1217+ let geom_idx = batch
1218+ . schema ( )
1219+ . index_of ( "geometry" )
1220+ . expect ( "Geometry column not found" ) ;
1221+ let geoms_col = batch. column ( geom_idx) ;
1222+ let geoms_binary = geoms_col
1223+ . as_any ( )
1224+ . downcast_ref :: < arrow_array:: BinaryArray > ( ) ;
1225+ let geoms_binary_view = geoms_col
1226+ . as_any ( )
1227+ . downcast_ref :: < arrow_array:: BinaryViewArray > ( ) ;
1228+
1229+ if geoms_binary. is_none ( ) && geoms_binary_view. is_none ( ) {
1230+ panic ! (
1231+ "Column 'geometry' should be Binary or BinaryView. Schema: {:?}" ,
1232+ batch. schema( )
1233+ ) ;
1234+ }
1235+
1236+ for i in 0 ..batch. num_rows ( ) {
1237+ if ids. is_null ( i) {
1238+ continue ;
1239+ }
1240+ let id = ids. value ( i) ;
1241+
1242+ let geom_bytes = if let Some ( arr) = geoms_binary {
1243+ if arr. is_null ( i) {
1244+ continue ;
1245+ }
1246+ arr. value ( i)
1247+ } else {
1248+ let arr = geoms_binary_view. unwrap ( ) ;
1249+ if arr. is_null ( i) {
1250+ continue ;
1251+ }
1252+ arr. value ( i)
1253+ } ;
1254+
1255+ let geom_wkb = read_wkb ( & mut & * geom_bytes) . expect ( "Failed to parse WKB" ) ;
1256+ let geom = item_to_geometry ( geom_wkb) . expect ( "Failed to parse WKB" ) ;
1257+ result. push ( ( id, geom) ) ;
1258+ }
1259+ }
1260+ }
1261+ result
1262+ }
1263+
1264+ fn compute_knn_ground_truth (
1265+ left_partitions : & [ Vec < RecordBatch > ] ,
1266+ right_partitions : & [ Vec < RecordBatch > ] ,
1267+ k : usize ,
1268+ ) -> Vec < ( i32 , i32 , f64 ) > {
1269+ let left_data = extract_geoms_and_ids ( left_partitions) ;
1270+ let right_data = extract_geoms_and_ids ( right_partitions) ;
1271+
1272+ let mut results = Vec :: new ( ) ;
1273+
1274+ for ( l_id, l_geom) in left_data {
1275+ let mut distances: Vec < ( i32 , f64 ) > = right_data
1276+ . iter ( )
1277+ . map ( |( r_id, r_geom) | ( * r_id, Euclidean . distance ( & l_geom, r_geom) ) )
1278+ . collect ( ) ;
1279+
1280+ // Sort by distance, then by ID for stability
1281+ distances. sort_by ( |a, b| a. 1 . total_cmp ( & b. 1 ) . then_with ( || a. 0 . cmp ( & b. 0 ) ) ) ;
1282+
1283+ for i in 0 ..k. min ( distances. len ( ) ) {
1284+ results. push ( ( l_id, distances[ i] . 0 , distances[ i] . 1 ) ) ;
1285+ }
1286+ }
1287+
1288+ // Sort results by L.id, R.id
1289+ results. sort_by ( |a, b| a. 0 . cmp ( & b. 0 ) . then_with ( || a. 1 . cmp ( & b. 1 ) ) ) ;
1290+ results
1291+ }
1292+
1293+ #[ tokio:: test]
1294+ async fn test_knn_join_correctness ( ) -> Result < ( ) > {
1295+ // Generate slightly larger data
1296+ let ( ( left_schema, left_partitions) , ( right_schema, right_partitions) ) =
1297+ create_knn_test_data ( ( 0.1 , 10.0 ) , WKB_GEOMETRY ) ?;
1298+
1299+ let options = SpatialJoinOptions :: default ( ) ;
1300+ let k = 3 ;
1301+
1302+ let sql1 = format ! (
1303+ "SELECT L.id, R.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) ORDER BY L.id, R.id" ,
1304+ k
1305+ ) ;
1306+ let expected1 = compute_knn_ground_truth ( & left_partitions, & right_partitions, k)
1307+ . into_iter ( )
1308+ . map ( |( l, r, _) | ( l, r) )
1309+ . collect :: < Vec < _ > > ( ) ;
1310+
1311+ let sql2 = format ! (
1312+ "SELECT R.id, L.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(R.geometry, L.geometry, {}, false) ORDER BY R.id, L.id" ,
1313+ k
1314+ ) ;
1315+ let expected2 = compute_knn_ground_truth ( & right_partitions, & left_partitions, k)
1316+ . into_iter ( )
1317+ . map ( |( l, r, _) | ( l, r) )
1318+ . collect :: < Vec < _ > > ( ) ;
1319+
1320+ let sqls = [ ( & sql1, & expected1) , ( & sql2, & expected2) ] ;
1321+
1322+ for ( sql, expected_results) in sqls {
1323+ let batches = run_spatial_join_query (
1324+ & left_schema,
1325+ & right_schema,
1326+ left_partitions. clone ( ) ,
1327+ right_partitions. clone ( ) ,
1328+ Some ( options. clone ( ) ) ,
1329+ 10 ,
1330+ & sql,
1331+ )
1332+ . await ?;
1333+
1334+ // Collect actual results
1335+ let mut actual_results = Vec :: new ( ) ;
1336+ let combined_batch = arrow:: compute:: concat_batches ( & batches. schema ( ) , & [ batches] ) ?;
1337+ let l_ids = combined_batch
1338+ . column ( 0 )
1339+ . as_any ( )
1340+ . downcast_ref :: < arrow_array:: Int32Array > ( )
1341+ . unwrap ( ) ;
1342+ let r_ids = combined_batch
1343+ . column ( 1 )
1344+ . as_any ( )
1345+ . downcast_ref :: < arrow_array:: Int32Array > ( )
1346+ . unwrap ( ) ;
1347+
1348+ for i in 0 ..combined_batch. num_rows ( ) {
1349+ actual_results. push ( ( l_ids. value ( i) , r_ids. value ( i) ) ) ;
1350+ }
1351+ actual_results. sort_by ( |a, b| a. 0 . cmp ( & b. 0 ) . then_with ( || a. 1 . cmp ( & b. 1 ) ) ) ;
1352+
1353+ assert_eq ! ( actual_results, * expected_results) ;
1354+ }
1355+
1356+ Ok ( ( ) )
1357+ }
11761358}
0 commit comments