Skip to content

Commit 54489c4

Browse files
committed
Fix several bugs related to KNN join
1 parent 9bf02f4 commit 54489c4

3 files changed

Lines changed: 259 additions & 45 deletions

File tree

rust/sedona-spatial-join/src/exec.rs

Lines changed: 202 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ impl SpatialJoinExec {
171171
let cache = Self::compute_properties(
172172
&left,
173173
&right,
174+
&on,
174175
Arc::clone(&join_schema),
175176
*join_type,
176177
projection.as_ref(),
@@ -236,9 +237,11 @@ impl SpatialJoinExec {
236237
///
237238
/// When converted from HashJoin, we preserve HashJoin's equivalence properties by extracting
238239
/// equality conditions from the filter.
240+
#[allow(clippy::too_many_arguments)]
239241
fn compute_properties(
240242
left: &Arc<dyn ExecutionPlan>,
241243
right: &Arc<dyn ExecutionPlan>,
244+
on: &SpatialPredicate,
242245
schema: SchemaRef,
243246
join_type: JoinType,
244247
projection: Option<&Vec<usize>>,
@@ -265,7 +268,13 @@ impl SpatialJoinExec {
265268

266269
// Use symmetric partitioning (like HashJoin) when converted from HashJoin
267270
// Otherwise use asymmetric partitioning (like NestedLoopJoin)
268-
let mut output_partitioning = if converted_from_hash_join {
271+
let mut output_partitioning = if let SpatialPredicate::KNearestNeighbors(knn) = on {
272+
match knn.probe_side {
273+
JoinSide::Left => left.output_partitioning().clone(),
274+
JoinSide::Right => right.output_partitioning().clone(),
275+
_ => asymmetric_join_output_partitioning(left, right, &join_type),
276+
}
277+
} else if converted_from_hash_join {
269278
// Replicate HashJoin's symmetric partitioning logic
270279
// HashJoin preserves partitioning from both sides for inner joins
271280
// and from one side for outer joins
@@ -467,7 +476,6 @@ impl ExecutionPlan for SpatialJoinExec {
467476
})?
468477
};
469478

470-
// Column indices for regular joins - no swapping needed
471479
let column_indices_after_projection = match &self.projection {
472480
Some(projection) => projection
473481
.iter()
@@ -559,30 +567,14 @@ impl SpatialJoinExec {
559567
})?
560568
};
561569

562-
// Handle column indices for KNN - need to swap if we swapped execution plans
563-
let mut column_indices_after_projection = match &self.projection {
570+
let column_indices_after_projection = match &self.projection {
564571
Some(projection) => projection
565572
.iter()
566573
.map(|i| self.column_indices[*i].clone())
567574
.collect(),
568575
None => self.column_indices.clone(),
569576
};
570577

571-
// If we swapped execution plans for KNN, we need to swap the column indices too
572-
if !actual_probe_plan_is_left {
573-
for col_idx in &mut column_indices_after_projection {
574-
match col_idx.side {
575-
datafusion_common::JoinSide::Left => {
576-
col_idx.side = datafusion_common::JoinSide::Right
577-
}
578-
datafusion_common::JoinSide::Right => {
579-
col_idx.side = datafusion_common::JoinSide::Left
580-
}
581-
datafusion_common::JoinSide::None => {} // No change needed
582-
}
583-
}
584-
}
585-
586578
let join_metrics = SpatialJoinProbeMetrics::new(partition, &self.metrics);
587579
let probe_stream = probe_plan.execute(partition, Arc::clone(&context))?;
588580

@@ -614,20 +606,23 @@ impl SpatialJoinExec {
614606

615607
#[cfg(test)]
616608
mod tests {
617-
use arrow_array::RecordBatch;
609+
use arrow_array::{Array, RecordBatch};
618610
use arrow_schema::{DataType, Field, Schema};
619611
use datafusion::{
620612
catalog::{MemTable, TableProvider},
621613
execution::SessionStateBuilder,
622614
prelude::{SessionConfig, SessionContext},
623615
};
624616
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
617+
use geo::{Distance, Euclidean};
625618
use geo_types::{Coord, Rect};
626619
use rstest::rstest;
620+
use sedona_geo::to_geo::item_to_geometry;
627621
use sedona_geometry::types::GeometryTypeId;
628622
use sedona_schema::datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY};
629623
use sedona_testing::datagen::RandomPartitionedDataBuilder;
630624
use tokio::sync::OnceCell;
625+
use wkb::reader::read_wkb;
631626

632627
use crate::register_spatial_join_optimizer;
633628
use sedona_common::{
@@ -691,6 +686,40 @@ mod tests {
691686
Ok((left_data, right_data))
692687
}
693688

689+
/// Creates test data for KNN join (Point-Point)
690+
fn create_knn_test_data(
691+
size_range: (f64, f64),
692+
sedona_type: SedonaType,
693+
) -> Result<(TestPartitions, TestPartitions)> {
694+
let bounds = Rect::new(Coord { x: 0.0, y: 0.0 }, Coord { x: 100.0, y: 100.0 });
695+
696+
let left_data = RandomPartitionedDataBuilder::new()
697+
.seed(1)
698+
.num_partitions(2)
699+
.batches_per_partition(2)
700+
.rows_per_batch(30)
701+
.geometry_type(GeometryTypeId::Point)
702+
.sedona_type(sedona_type.clone())
703+
.bounds(bounds)
704+
.size_range(size_range)
705+
.null_rate(0.1)
706+
.build()?;
707+
708+
let right_data = RandomPartitionedDataBuilder::new()
709+
.seed(2)
710+
.num_partitions(4)
711+
.batches_per_partition(4)
712+
.rows_per_batch(30)
713+
.geometry_type(GeometryTypeId::Point)
714+
.sedona_type(sedona_type)
715+
.bounds(bounds)
716+
.size_range(size_range)
717+
.null_rate(0.1)
718+
.build()?;
719+
720+
Ok((left_data, right_data))
721+
}
722+
694723
fn setup_context(
695724
options: Option<SpatialJoinOptions>,
696725
batch_size: usize,
@@ -1173,4 +1202,157 @@ mod tests {
11731202
})?;
11741203
Ok(spatial_join_execs)
11751204
}
1205+
1206+
fn extract_geoms_and_ids(partitions: &[Vec<RecordBatch>]) -> Vec<(i32, geo::Geometry<f64>)> {
1207+
let mut result = Vec::new();
1208+
for partition in partitions {
1209+
for batch in partition {
1210+
let id_idx = batch.schema().index_of("id").expect("Id column not found");
1211+
let ids = batch
1212+
.column(id_idx)
1213+
.as_any()
1214+
.downcast_ref::<arrow_array::Int32Array>()
1215+
.expect("Column 'id' should be Int32");
1216+
1217+
let geom_idx = batch
1218+
.schema()
1219+
.index_of("geometry")
1220+
.expect("Geometry column not found");
1221+
let geoms_col = batch.column(geom_idx);
1222+
let geoms_binary = geoms_col
1223+
.as_any()
1224+
.downcast_ref::<arrow_array::BinaryArray>();
1225+
let geoms_binary_view = geoms_col
1226+
.as_any()
1227+
.downcast_ref::<arrow_array::BinaryViewArray>();
1228+
1229+
if geoms_binary.is_none() && geoms_binary_view.is_none() {
1230+
panic!(
1231+
"Column 'geometry' should be Binary or BinaryView. Schema: {:?}",
1232+
batch.schema()
1233+
);
1234+
}
1235+
1236+
for i in 0..batch.num_rows() {
1237+
if ids.is_null(i) {
1238+
continue;
1239+
}
1240+
let id = ids.value(i);
1241+
1242+
let geom_bytes = if let Some(arr) = geoms_binary {
1243+
if arr.is_null(i) {
1244+
continue;
1245+
}
1246+
arr.value(i)
1247+
} else {
1248+
let arr = geoms_binary_view.unwrap();
1249+
if arr.is_null(i) {
1250+
continue;
1251+
}
1252+
arr.value(i)
1253+
};
1254+
1255+
let geom_wkb = read_wkb(&mut &*geom_bytes).expect("Failed to parse WKB");
1256+
let geom = item_to_geometry(geom_wkb).expect("Failed to parse WKB");
1257+
result.push((id, geom));
1258+
}
1259+
}
1260+
}
1261+
result
1262+
}
1263+
1264+
fn compute_knn_ground_truth(
1265+
left_partitions: &[Vec<RecordBatch>],
1266+
right_partitions: &[Vec<RecordBatch>],
1267+
k: usize,
1268+
) -> Vec<(i32, i32, f64)> {
1269+
let left_data = extract_geoms_and_ids(left_partitions);
1270+
let right_data = extract_geoms_and_ids(right_partitions);
1271+
1272+
let mut results = Vec::new();
1273+
1274+
for (l_id, l_geom) in left_data {
1275+
let mut distances: Vec<(i32, f64)> = right_data
1276+
.iter()
1277+
.map(|(r_id, r_geom)| (*r_id, Euclidean.distance(&l_geom, r_geom)))
1278+
.collect();
1279+
1280+
// Sort by distance, then by ID for stability
1281+
distances.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
1282+
1283+
for i in 0..k.min(distances.len()) {
1284+
results.push((l_id, distances[i].0, distances[i].1));
1285+
}
1286+
}
1287+
1288+
// Sort results by L.id, R.id
1289+
results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
1290+
results
1291+
}
1292+
1293+
#[tokio::test]
1294+
async fn test_knn_join_correctness() -> Result<()> {
1295+
// Generate slightly larger data
1296+
let ((left_schema, left_partitions), (right_schema, right_partitions)) =
1297+
create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)?;
1298+
1299+
let options = SpatialJoinOptions::default();
1300+
let k = 3;
1301+
1302+
let sql1 = format!(
1303+
"SELECT L.id, R.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) ORDER BY L.id, R.id",
1304+
k
1305+
);
1306+
let expected1 = compute_knn_ground_truth(&left_partitions, &right_partitions, k)
1307+
.into_iter()
1308+
.map(|(l, r, _)| (l, r))
1309+
.collect::<Vec<_>>();
1310+
1311+
let sql2 = format!(
1312+
"SELECT R.id, L.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(R.geometry, L.geometry, {}, false) ORDER BY R.id, L.id",
1313+
k
1314+
);
1315+
let expected2 = compute_knn_ground_truth(&right_partitions, &left_partitions, k)
1316+
.into_iter()
1317+
.map(|(l, r, _)| (l, r))
1318+
.collect::<Vec<_>>();
1319+
1320+
let sqls = [(&sql1, &expected1), (&sql2, &expected2)];
1321+
1322+
for (sql, expected_results) in sqls {
1323+
let batches = run_spatial_join_query(
1324+
&left_schema,
1325+
&right_schema,
1326+
left_partitions.clone(),
1327+
right_partitions.clone(),
1328+
Some(options.clone()),
1329+
10,
1330+
&sql,
1331+
)
1332+
.await?;
1333+
1334+
// Collect actual results
1335+
let mut actual_results = Vec::new();
1336+
let combined_batch = arrow::compute::concat_batches(&batches.schema(), &[batches])?;
1337+
let l_ids = combined_batch
1338+
.column(0)
1339+
.as_any()
1340+
.downcast_ref::<arrow_array::Int32Array>()
1341+
.unwrap();
1342+
let r_ids = combined_batch
1343+
.column(1)
1344+
.as_any()
1345+
.downcast_ref::<arrow_array::Int32Array>()
1346+
.unwrap();
1347+
1348+
for i in 0..combined_batch.num_rows() {
1349+
actual_results.push((l_ids.value(i), r_ids.value(i)));
1350+
}
1351+
actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
1352+
1353+
assert_eq!(actual_results, *expected_results);
1354+
}
1355+
1356+
Ok(())
1357+
}
11761358
}

0 commit comments

Comments
 (0)