diff --git a/rust/sedona-spatial-join/src/exec.rs b/rust/sedona-spatial-join/src/exec.rs index 3b96489f4..84c272d53 100644 --- a/rust/sedona-spatial-join/src/exec.rs +++ b/rust/sedona-spatial-join/src/exec.rs @@ -1322,19 +1322,6 @@ mod tests { Ok(spatial_join_execs) } - fn collect_nested_loop_join_exec( - plan: &Arc, - ) -> Result> { - let mut execs = Vec::new(); - plan.apply(|node| { - if let Some(exec) = node.as_any().downcast_ref::() { - execs.push(exec); - } - Ok(TreeNodeRecursion::Continue) - })?; - Ok(execs) - } - async fn test_mark_join( join_type: JoinType, options: SpatialJoinOptions, @@ -1379,7 +1366,17 @@ mod tests { ctx_no_opt.register_table("R", mem_table_right)?; let df_no_opt = ctx_no_opt.sql(sql).await?; let plan_no_opt = df_no_opt.create_physical_plan().await?; - let nlj_execs = collect_nested_loop_join_exec(&plan_no_opt)?; + fn collect_nlj_exec(plan: &Arc) -> Result> { + let mut execs = Vec::new(); + plan.apply(|node| { + if let Some(exec) = node.as_any().downcast_ref::() { + execs.push(exec); + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(execs) + } + let nlj_execs = collect_nlj_exec(&plan_no_opt)?; assert_eq!(nlj_execs.len(), 1); let original_nlj = nlj_execs[0]; let mark_nlj = NestedLoopJoinExec::try_new( @@ -1449,11 +1446,22 @@ mod tests { result } - fn compute_knn_ground_truth( + fn compute_knn_ground_truth_with_pair_filter( left_partitions: &[Vec], right_partitions: &[Vec], k: usize, - ) -> Vec<(i32, i32, f64)> { + keep_pair: F, + ) -> Vec<(i32, i32, f64)> + where + F: Fn(i32, i32) -> bool, + { + // NOTE: This helper mirrors our KNN semantics used in execution: + // - select top-K unfiltered candidates by distance (stable by r_id) + // - then apply a cross-side predicate to decide which pairs to keep + // (can yield < K results per probe row) + // + // The predicate is intentionally *post* top-K selection. + // (See `test_knn_join_with_filter_correctness`.) let left_data = extract_geoms_and_ids(left_partitions); let right_data = extract_geoms_and_ids(right_partitions); @@ -1468,8 +1476,11 @@ mod tests { // Sort by distance, then by ID for stability distances.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0))); + // KNN semantics: pick top-K unfiltered, then optionally post-filter. for (r_id, dist) in distances.iter().take(k.min(distances.len())) { - results.push((l_id, *r_id, *dist)); + if keep_pair(l_id, *r_id) { + results.push((l_id, *r_id, *dist)); + } } } @@ -1478,32 +1489,60 @@ mod tests { results } + #[rstest] #[tokio::test] - async fn test_knn_join_correctness() -> Result<()> { + async fn test_knn_join_correctness( + // TODO: Currently the underlying geo-index KNN implementation has bugs working with non-point + // geometries, so this test is restricted to point_only = true. Once + // https://github.com/georust/geo-index/pull/151 (fixing non-point KNN support) is + // released, add #[values(true, false)] here to also exercise non-point data. + #[values(true)] point_only: bool, + #[values(1, 2, 3, 4)] num_partitions: usize, + #[values(10, 30, 1000)] max_batch_size: usize, + ) -> Result<()> { // Generate slightly larger data - let ((left_schema, left_partitions), (right_schema, right_partitions)) = - create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)?; + let ((left_schema, left_partitions), (right_schema, right_partitions)) = if point_only { + create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)? + } else { + create_default_test_data()? 
+ }; - let options = SpatialJoinOptions::default(); - let k = 3; + // Pin the number of spatial partitions so the test covers both single-partition and multi-partition execution + let options = SpatialJoinOptions { + debug: SpatialJoinDebugOptions { + num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(num_partitions), + ..Default::default() + }, + ..Default::default() + }; + let k = 6; let sql1 = format!( "SELECT L.id, R.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) ORDER BY L.id, R.id", k ); - let expected1 = compute_knn_ground_truth(&left_partitions, &right_partitions, k) - .into_iter() - .map(|(l, r, _)| (l, r)) - .collect::<Vec<_>>(); - + let expected1 = compute_knn_ground_truth_with_pair_filter( + &left_partitions, + &right_partitions, + k, + |_l_id, _r_id| true, + ) + .into_iter() + .map(|(l, r, _)| (l, r)) + .collect::<Vec<_>>(); let sql2 = format!( "SELECT R.id, L.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(R.geometry, L.geometry, {}, false) ORDER BY R.id, L.id", k ); - let expected2 = compute_knn_ground_truth(&right_partitions, &left_partitions, k) - .into_iter() - .map(|(l, r, _)| (l, r)) - .collect::<Vec<_>>(); + let expected2 = compute_knn_ground_truth_with_pair_filter( + &right_partitions, + &left_partitions, + k, + |_l_id, _r_id| true, + ) + .into_iter() + .map(|(l, r, _)| (l, r)) + .collect::<Vec<_>>(); let sqls = [(&sql1, &expected1), (&sql2, &expected2)]; @@ -1514,7 +1553,7 @@ mod tests { left_partitions.clone(), right_partitions.clone(), Some(options.clone()), - 10, + max_batch_size, sql, ) .await?; @@ -1543,4 +1582,306 @@ mod tests { Ok(()) } + + #[rstest] + #[tokio::test] + async fn test_knn_join_with_filter_correctness( + #[values(1, 2, 3, 4)] num_partitions: usize, + #[values(10, 30, 1000)] max_batch_size: usize, + ) -> Result<()> { + let ((left_schema, left_partitions), (right_schema, right_partitions)) = + create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)?; + + let options = SpatialJoinOptions { + debug: SpatialJoinDebugOptions { + num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(num_partitions), + ..Default::default() + }, + ..Default::default() + }; + + let k = 3; + let sql = format!( + "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)", + k + ); + + let batches = run_spatial_join_query( + &left_schema, + &right_schema, + left_partitions.clone(), + right_partitions.clone(), + Some(options), + max_batch_size, + &sql, + ) + .await?; + + let mut actual_results = Vec::new(); + let combined_batch = arrow::compute::concat_batches(&batches.schema(), &[batches])?; + let l_ids = combined_batch + .column(0) + .as_any() + .downcast_ref::<Int32Array>() + .unwrap(); + let r_ids = combined_batch + .column(1) + .as_any() + .downcast_ref::<Int32Array>() + .unwrap(); + + for i in 0..combined_batch.num_rows() { + actual_results.push((l_ids.value(i), r_ids.value(i))); + } + actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1))); + + // Prove the test actually exercises the "< K rows after filtering" case. + // Build a list of all probe-side IDs and count how many results each has.
+ let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions) + .into_iter() + .map(|(id, _)| id) + .collect(); + let mut per_left_counts: std::collections::HashMap<i32, usize> = + std::collections::HashMap::new(); + for (l_id, _) in &actual_results { + *per_left_counts.entry(*l_id).or_default() += 1; + } + let min_count = all_left_ids + .iter() + .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0)) + .min() + .unwrap_or(0); + assert!( + min_count < k, + "expected at least one probe row to produce < K rows after filtering; min_count={min_count}, k={k}" + ); + + let expected_results = compute_knn_ground_truth_with_pair_filter( + &left_partitions, + &right_partitions, + k, + |l_id, r_id| (l_id.rem_euclid(7)) == (r_id.rem_euclid(7)), + ) + .into_iter() + .map(|(l, r, _)| (l, r)) + .collect::<Vec<_>>(); + + assert_eq!(actual_results, expected_results); + + Ok(()) + } + + #[rstest] + #[tokio::test] + async fn test_knn_join_include_tie_breakers( + #[values(1, 2, 3, 4)] num_partitions: usize, + #[values(10, 100)] max_batch_size: usize, + ) -> Result<()> { + // Construct a larger dataset with *guaranteed* exact ties at the kth distance. + // + // For each probe point at (10*i, 0), we create two candidate points at (10*i-1, 0) + // and (10*i+1, 0). Those two candidates are tied (distance = 1). + // A third candidate at (10*i+2, 0) ensures there are also non-tied options. + // Spacing by 10 keeps other probes' candidates far enough away that they never interfere. + // + // With k=1: + // - knn_include_tie_breakers=false should return exactly 1 match per probe row. + // - knn_include_tie_breakers=true should return 2 matches per probe row (both ties). + // + // The exact choice of which tied row is returned when tie-breakers are disabled is not + // asserted (it is allowed to be either tied candidate). + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("wkt", DataType::Utf8, false), + ])); + + let num_probe_rows: i32 = 120; + let k = 1; + + let input_batches_left = 6; + let input_batches_right = 6; + + fn make_batches( + schema: SchemaRef, + ids: Vec<i32>, + wkts: Vec<String>, + num_batches: usize, + ) -> Result<Vec<RecordBatch>> { + assert_eq!(ids.len(), wkts.len()); + let total = ids.len(); + let chunk = total.div_ceil(num_batches); + + let mut batches = Vec::new(); + for b in 0..num_batches { + let start = b * chunk; + if start >= total { + break; + } + let end = ((b + 1) * chunk).min(total); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow_array::Int32Array::from(ids[start..end].to_vec())), + Arc::new(arrow_array::StringArray::from( + wkts[start..end] + .iter() + .map(|s| s.as_str()) + .collect::<Vec<_>>(), + )), + ], + )?; + batches.push(batch); + } + Ok(batches) + } + + let mut left_ids = Vec::with_capacity(num_probe_rows as usize); + let mut left_wkts = Vec::with_capacity(num_probe_rows as usize); + + let mut right_ids = Vec::with_capacity((num_probe_rows as usize) * 3); + let mut right_wkts = Vec::with_capacity((num_probe_rows as usize) * 3); + + for i in 0..num_probe_rows { + let cx = (i as i64) * 10; + left_ids.push(i); + left_wkts.push(format!("POINT ({cx} 0)")); + + // Two tied candidates at distance 1. + let base = i * 10; + right_ids.push(base + 1); + right_wkts.push(format!("POINT ({x} 0)", x = cx - 1)); + + right_ids.push(base + 2); + right_wkts.push(format!("POINT ({x} 0)", x = cx + 1)); + + // One non-tied candidate.
+ right_ids.push(base + 3); + right_wkts.push(format!("POINT ({x} 0)", x = cx + 2)); + } + + let left_batches = make_batches(schema.clone(), left_ids, left_wkts, input_batches_left)?; + let right_batches = + make_batches(schema.clone(), right_ids, right_wkts, input_batches_right)?; + + // Put each side into a single MemTable partition, but with multiple batches. + // This ensures the build/probe collectors see 4–8 batches and the round-robin batch + // partitioner has something to distribute. + let left_partitions = vec![left_batches]; + let right_partitions = vec![right_batches]; + + let sql = format!( + "SELECT L.id AS l_id, R.id AS r_id \ + FROM L JOIN R \ + ON ST_KNN(ST_GeomFromWKT(L.wkt), ST_GeomFromWKT(R.wkt), {k}, false)" + ); + + let base_options = SpatialJoinOptions { + debug: SpatialJoinDebugOptions { + num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(num_partitions), + ..Default::default() + }, + ..Default::default() + }; + + // Without tie-breakers: exactly 1 match per probe row. + let out_no_ties = run_spatial_join_query( + &schema, + &schema, + left_partitions.clone(), + right_partitions.clone(), + Some(SpatialJoinOptions { + knn_include_tie_breakers: false, + ..base_options.clone() + }), + max_batch_size, + &sql, + ) + .await?; + let combined = arrow::compute::concat_batches(&out_no_ties.schema(), &[out_no_ties])?; + + let l_ids = combined + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let r_ids = combined + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + let mut per_left: std::collections::HashMap> = + std::collections::HashMap::new(); + for i in 0..combined.num_rows() { + per_left + .entry(l_ids.value(i)) + .or_default() + .push(r_ids.value(i)); + } + + assert_eq!(per_left.len() as i32, num_probe_rows); + for l_id in 0..num_probe_rows { + let r_list = per_left.get(&l_id).unwrap(); + assert_eq!( + r_list.len(), + 1, + "expected exactly 1 match for l_id={l_id} when tie-breakers are disabled" + ); + let base = l_id * 10; + let r_id = r_list[0]; + assert!( + r_id == base + 1 || r_id == base + 2, + "expected a tied nearest neighbor for l_id={l_id}, got r_id={r_id}" + ); + } + + // With tie-breakers: exactly 2 matches per probe row (both tied candidates). 
+ let out_with_ties = run_spatial_join_query( + &schema, + &schema, + left_partitions.clone(), + right_partitions.clone(), + Some(SpatialJoinOptions { + knn_include_tie_breakers: true, + ..base_options + }), + max_batch_size, + &sql, + ) + .await?; + let combined = arrow::compute::concat_batches(&out_with_ties.schema(), &[out_with_ties])?; + let l_ids = combined + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let r_ids = combined + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + let mut per_left: std::collections::HashMap> = + std::collections::HashMap::new(); + for i in 0..combined.num_rows() { + per_left + .entry(l_ids.value(i)) + .or_default() + .push(r_ids.value(i)); + } + assert_eq!(per_left.len() as i32, num_probe_rows); + for l_id in 0..num_probe_rows { + let mut r_list = per_left.get(&l_id).unwrap().clone(); + r_list.sort(); + let base = l_id * 10; + assert_eq!( + r_list, + vec![base + 1, base + 2], + "expected both tied nearest neighbors for l_id={l_id}" + ); + } + + Ok(()) + } } diff --git a/rust/sedona-spatial-join/src/index/spatial_index.rs b/rust/sedona-spatial-join/src/index/spatial_index.rs index bff7895df..de1213dea 100644 --- a/rust/sedona-spatial-join/src/index/spatial_index.rs +++ b/rust/sedona-spatial-join/src/index/spatial_index.rs @@ -54,6 +54,8 @@ use crate::{ use arrow::array::BooleanBufferBuilder; use sedona_common::{option::SpatialJoinOptions, sedona_internal_err, ExecutionMode}; +pub const DISTANCE_TOLERANCE: f64 = 1e-9; + pub struct SpatialIndex { pub(crate) schema: SchemaRef, pub(crate) options: SpatialJoinOptions, @@ -213,6 +215,7 @@ impl SpatialIndex { /// # Returns /// /// * `JoinResultMetrics` containing the number of actual matches and candidates processed + #[allow(unused)] pub(crate) fn query_knn( &self, probe_wkb: &Wkb, @@ -220,6 +223,25 @@ impl SpatialIndex { use_spheroid: bool, include_tie_breakers: bool, build_batch_positions: &mut Vec<(i32, i32)>, + ) -> Result { + self.query_knn_with_distance( + probe_wkb, + k, + use_spheroid, + include_tie_breakers, + build_batch_positions, + None, + ) + } + + pub(crate) fn query_knn_with_distance( + &self, + probe_wkb: &Wkb, + k: u32, + use_spheroid: bool, + include_tie_breakers: bool, + build_batch_positions: &mut Vec<(i32, i32)>, + mut distances: Option<&mut Vec>, ) -> Result { if k == 0 { return Ok(QueryResultMetrics { @@ -336,7 +358,7 @@ impl SpatialIndex { max_y + distance_f32, ); - // Use rtree.search() with envelope bounds (like the old code) + // Use rtree.search() with envelope bounds let expanded_results = self.rtree.search(min_x, min_y, max_x, max_y); candidate_count = expanded_results.len(); @@ -362,7 +384,6 @@ impl SpatialIndex { .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); // Include all results up to and including those with the same distance as the k-th result - const DISTANCE_TOLERANCE: f64 = 1e-9; let mut tie_breaker_results: Vec = Vec::new(); for (i, &(distance, result_idx)) in all_distances_with_indices.iter().enumerate() { @@ -391,6 +412,17 @@ impl SpatialIndex { for &result_idx in &final_results { if (result_idx as usize) < self.data_id_to_batch_pos.len() { build_batch_positions.push(self.data_id_to_batch_pos[result_idx as usize]); + + if let Some(dists) = distances.as_mut() { + let mut dist = f64::NAN; + if let Some(item_geom) = geometry_accessor.get_geometry(result_idx as usize) { + dist = distance_metric + .distance_to_geometry(&probe_geom, item_geom) + .to_f64() + .unwrap_or(f64::NAN); + } + dists.push(dist); + } } } diff --git 
a/rust/sedona-spatial-join/src/partitioning.rs b/rust/sedona-spatial-join/src/partitioning.rs index fe495f6bd..60028974f 100644 --- a/rust/sedona-spatial-join/src/partitioning.rs +++ b/rust/sedona-spatial-join/src/partitioning.rs @@ -18,9 +18,11 @@ use datafusion_common::Result; use sedona_geometry::bounding_box::BoundingBox; +pub mod broadcast; pub mod flat; pub mod kdb; pub(crate) mod partition_slots; +pub mod round_robin; pub mod rtree; pub mod stream_repartitioner; pub(crate) mod util; diff --git a/rust/sedona-spatial-join/src/partitioning/broadcast.rs b/rust/sedona-spatial-join/src/partitioning/broadcast.rs new file mode 100644 index 000000000..6f767c8bf --- /dev/null +++ b/rust/sedona-spatial-join/src/partitioning/broadcast.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::Result; +use sedona_common::sedona_internal_err; +use sedona_geometry::bounding_box::BoundingBox; + +use crate::partitioning::{SpatialPartition, SpatialPartitioner}; + +/// A partitioner that assigns everything to the Multi partition. +/// +/// This partitioner is useful when we want to broadcast the data to all partitions. +pub struct BroadcastPartitioner { + num_partitions: usize, +} + +impl BroadcastPartitioner { + pub fn new(num_partitions: usize) -> Self { + Self { num_partitions } + } +} + +impl SpatialPartitioner for BroadcastPartitioner { + fn num_regular_partitions(&self) -> usize { + self.num_partitions + } + + fn partition(&self, _bbox: &BoundingBox) -> Result { + Ok(SpatialPartition::Multi) + } + + fn partition_no_multi(&self, _bbox: &BoundingBox) -> Result { + sedona_internal_err!("BroadcastPartitioner does not support partition_no_multi") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_broadcast_partitioner() { + let num_partitions = 4; + let partitioner = BroadcastPartitioner::new(num_partitions); + assert_eq!(partitioner.num_regular_partitions(), num_partitions); + + let bbox = BoundingBox::xy((0.0, 10.0), (0.0, 10.0)); + + // Test partition + assert_eq!( + partitioner.partition(&bbox).unwrap(), + SpatialPartition::Multi + ); + + // Test partition_no_multi + assert!(partitioner.partition_no_multi(&bbox).is_err()); + } +} diff --git a/rust/sedona-spatial-join/src/partitioning/round_robin.rs b/rust/sedona-spatial-join/src/partitioning/round_robin.rs new file mode 100644 index 000000000..a5d731170 --- /dev/null +++ b/rust/sedona-spatial-join/src/partitioning/round_robin.rs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::atomic::{AtomicUsize, Ordering}; + +use datafusion_common::Result; +use sedona_geometry::bounding_box::BoundingBox; + +use crate::partitioning::{SpatialPartition, SpatialPartitioner}; + +/// A partitioner that assigns partitions in a round-robin fashion. +/// +/// This partitioner is used for KNN join, where the build side is partitioned +/// into `num_partitions` partitions, and the probe side is assigned to the +/// `Multi` partition (i.e., broadcast to all partitions). +pub struct RoundRobinPartitioner { + num_partitions: usize, + counter: AtomicUsize, +} + +impl RoundRobinPartitioner { + pub fn new(num_partitions: usize) -> Self { + Self { + num_partitions, + counter: AtomicUsize::new(0), + } + } +} + +impl SpatialPartitioner for RoundRobinPartitioner { + fn num_regular_partitions(&self) -> usize { + self.num_partitions + } + + fn partition(&self, bbox: &BoundingBox) -> Result { + self.partition_no_multi(bbox) + } + + fn partition_no_multi(&self, _bbox: &BoundingBox) -> Result { + let idx = self.counter.fetch_add(1, Ordering::Relaxed); + Ok(SpatialPartition::Regular( + (idx % self.num_partitions) as u32, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_round_robin_partitioner() { + let num_partitions = 4; + let partitioner = RoundRobinPartitioner::new(num_partitions); + assert_eq!(partitioner.num_regular_partitions(), num_partitions); + + let bbox = BoundingBox::xy((0.0, 10.0), (0.0, 10.0)); + + for i in 0..10 { + assert_eq!( + partitioner.partition_no_multi(&bbox).unwrap(), + SpatialPartition::Regular((i % num_partitions) as u32) + ); + } + } +} diff --git a/rust/sedona-spatial-join/src/prepare.rs b/rust/sedona-spatial-join/src/prepare.rs index 7198c2e5b..309eca12b 100644 --- a/rust/sedona-spatial-join/src/prepare.rs +++ b/rust/sedona-spatial-join/src/prepare.rs @@ -40,8 +40,10 @@ use crate::{ SpatialJoinBuildMetrics, }, partitioning::{ + broadcast::BroadcastPartitioner, flat::FlatPartitioner, kdb::KDBPartitioner, + round_robin::RoundRobinPartitioner, stream_repartitioner::{SpilledPartition, SpilledPartitions, StreamRepartitioner}, PartitionedSide, SpatialPartition, SpatialPartitioner, }, @@ -243,14 +245,16 @@ impl SpatialJoinComponentsBuilder { build_partitions: &mut Vec, seed: u64, ) -> Result> { - if matches!( + let build_partitioner: Arc = if matches!( self.spatial_predicate, - SpatialPredicate::KNearestNeighbors(..) + SpatialPredicate::KNearestNeighbors(_) ) { - return sedona_internal_err!("Partitioned KNN join is not supported yet"); - } - - let build_partitioner: Arc = { + // Spatial partitioning does not work well for KNN joins, so we simply use round-robin + // partitioning to spread the indexed data evenly to make each index fit in memory, and + // the probe side will be broadcasted to all partitions by partitioning all of them to + // the Multi partition. 
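+            // The probe-side counterpart of this choice (a BroadcastPartitioner that routes every
+            // probe row to the Multi partition) is selected further down in this file.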
+ Arc::new(RoundRobinPartitioner::new(num_partitions)) + } else { // Use spatial partitioners to partition the build side and the probe side, this will // reduce the amount of work needed for probing each partitioned index. // The KDB partitioner is built using the collected bounding box samples. @@ -299,7 +303,12 @@ impl SpatialJoinComponentsBuilder { num_partitions: usize, merged_spilled_partitions: &SpilledPartitions, ) -> Result> { - let probe_partitioner: Arc = { + let probe_partitioner: Arc = if matches!( + self.spatial_predicate, + SpatialPredicate::KNearestNeighbors(_) + ) { + Arc::new(BroadcastPartitioner::new(num_partitions)) + } else { // Build a flat partitioner using these partitions let mut partition_bounds = Vec::with_capacity(num_partitions); for k in 0..num_partitions { diff --git a/rust/sedona-spatial-join/src/probe.rs b/rust/sedona-spatial-join/src/probe.rs index 6f749b8f6..82290095e 100644 --- a/rust/sedona-spatial-join/src/probe.rs +++ b/rust/sedona-spatial-join/src/probe.rs @@ -41,5 +41,6 @@ impl ProbeStreamMetrics { } pub(crate) mod first_pass_stream; +pub(crate) mod knn_results_merger; pub(crate) mod non_partitioned_stream; pub(crate) mod partitioned_stream_provider; diff --git a/rust/sedona-spatial-join/src/probe/knn_results_merger.rs b/rust/sedona-spatial-join/src/probe/knn_results_merger.rs new file mode 100644 index 000000000..978f7f118 --- /dev/null +++ b/rust/sedona-spatial-join/src/probe/knn_results_merger.rs @@ -0,0 +1,2226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use core::f64; +use std::ops::Range; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayBuilder, AsArray, Float64Array, ListArray, OffsetBufferBuilder, PrimitiveBuilder, + RecordBatch, StructArray, UInt64Array, +}; +use arrow::buffer::OffsetBuffer; +use arrow::compute::{concat, concat_batches, interleave}; +use arrow::datatypes::{DataType, Field, Float64Type, Schema, SchemaRef, UInt64Type}; +use arrow_array::ArrayRef; +use arrow_schema::Fields; +use datafusion::config::SpillCompression; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_physical_plan::metrics::SpillMetrics; +use sedona_common::sedona_internal_err; + +use crate::index::spatial_index::DISTANCE_TOLERANCE; +use crate::utils::spill::{RecordBatchSpillReader, RecordBatchSpillWriter}; + +/// [UnprocessedKNNResultBatch] represents the KNN results produced by probing the spatial index. +/// An [UnprocessedKNNResultBatch] may include KNN results for multiple probe rows. +/// +/// The KNN results are stored in a StructArray, where each row corresponds to a KNN result. 
+/// The results for the same probe row are stored in contiguous rows, and the offsets to +/// split the results into groups per probe row are stored in the `offsets` field. +/// +/// Each probe row has a unique index. The index must be strictly increasing +/// across probe rows. The sequence of index across the entire sequence of ingested +/// [UnprocessedKNNResultBatch] must also be strictly increasing. The index is computed based on +/// the 0-based index of the probe row in this probe partition. +/// +/// The KNN results are filtered, meaning that the original KNN results obtained by probing +/// the spatial index may be further filtered based on some predicates. It is also possible that +/// all the KNN results for a probe row are filtered out. However, we still need to keep track of the +/// distances of unfiltered results to correctly compute the top-K distances before filtering. This +/// is critical for correctly merging KNN results from multiple partitions. +/// +/// Imagine that a KNN query for a probe row yields the following 5 results (K = 5): +/// +/// ```text +/// D0 D1 D2 D3 D4 +/// R0 R2 R3 +/// ``` +/// +/// Where Di is the distance of the i-th nearest neighbor, and Ri is the result row index. +/// R1 and R4 are filtered out based on some predicate, so the final results only contain R0, R2, and R3. +/// The core idea is that the filtering is applied AFTER determining the top-K distances, so the number +/// of final results may be less than K. +/// +/// However, if we split the object side of KNN join into 2 partitions, and the KNN results from +/// each partition are as follows: +/// +/// ```text +/// Partition 0: +/// D1 D3 D5 D6 D7 +/// R3 R6 R7 +/// +/// Partition 1: +/// D0 D2 D4 D8 D9 +/// R0 R2 R8 +/// ``` +/// +/// If we blindly merge the filtered results from both partitions and take top-k, we would get: +/// +/// ```text +/// D0 D2 D3 D6 D8 +/// R0 R2 R3 R6 R8 +/// ``` +/// +/// Which contains more results than single-partitioned KNN join (i.e., 5 results instead of 3). This is +/// incorrect. +/// +/// When merging the results from both partitions, we need to consider the distances of all unfiltered +/// results to correctly determine the top-K distances before filtering. In this case, the top-5 distances +/// are D0, D1, D2, D3, and D4. We take D4 as the distance threshold to filter merged results. After filtering, +/// we still get R0, R2, and R3 as the final results. +/// +/// Please note that the KNN results for the last probe row in this array may be incomplete, +/// this is due to batch slicing during probe result batch production. We should be cautious +/// and correctly handle the KNN results for each probe row across multiple slices. +/// +/// Here is a concrete example: the [UnprocessedKNNResultBatch] may contain KNN results for 3 probe rows: +/// +/// ```text +/// [P0, R00] +/// [P0, R01] +/// [P0, R02] +/// [P1, R10] +/// [P1, R11] +/// [P1, R12] +/// [P2, R20] +/// ``` +/// +/// Where Pi is the i-th probe row, and Rij is the j-th KNN result for probe row Pi. +/// The KNN results for probe row P2 could be incomplete, and the next ingested KNN result batch +/// may contain more results for probe row P2: +/// +/// ```text +/// [P2, R21] +/// [P2, R22] +/// [P3, R30] +/// ... +/// ``` +/// +/// In practice, we process the KNN results or a probe row only when we have seen all its results. 
+/// The may-be incomplete tail part of an ingested [UnprocessedKNNResultBatch] is sliced and concatenated with +/// the next ingested [UnprocessedKNNResultBatch] to form a complete set of KNN results for that probe row. +/// This slicing and concatenating won't happen frequently in practice (once per ingested batch +/// on average), so the performance impact is minimal. +struct UnprocessedKNNResultBatch { + row_array: StructArray, + probe_indices: Vec, + distances: Vec, + unfiltered_probe_indices: Vec, + unfiltered_distances: Vec, +} + +impl UnprocessedKNNResultBatch { + fn new( + row_array: StructArray, + probe_indices: Vec, + distances: Vec, + unfiltered_probe_indices: Vec, + unfiltered_distances: Vec, + ) -> Self { + Self { + row_array, + probe_indices, + distances, + unfiltered_probe_indices, + unfiltered_distances, + } + } + + /// Create a new [UnprocessedKNNResultBatch] representing the unprocessed tail KNN results + /// from an unprocessed [KNNProbeResult]. + fn new_unprocessed_tail(tail: KNNProbeResult<'_>, row_array: &StructArray) -> Self { + let index = tail.probe_row_index; + let num_rows = tail.row_range.len(); + let num_unfiltered_rows = tail.unfiltered_distances.len(); + + let sliced_row_array = row_array.slice(tail.row_range.start, num_rows); + let probe_indices = vec![index; num_rows]; + let distances = tail.distances.to_vec(); + let unfiltered_probe_indices = vec![index; num_unfiltered_rows]; + let unfiltered_distances = tail.unfiltered_distances.to_vec(); + + Self { + row_array: sliced_row_array, + probe_indices, + distances, + unfiltered_probe_indices, + unfiltered_distances, + } + } + + /// Merge the current [UnprocessedKNNResultBatch] with another one, producing a new + /// [UnprocessedKNNResultBatch]. + fn merge(self, other: Self) -> Result { + let concat_array = + concat(&[&self.row_array, &other.row_array]).map_err(|e| arrow_datafusion_err!(e))?; + let mut probe_indices = self.probe_indices; + probe_indices.extend(other.probe_indices); + let mut distances = self.distances; + distances.extend(other.distances); + let mut unfiltered_probe_indices = self.unfiltered_probe_indices; + unfiltered_probe_indices.extend(other.unfiltered_probe_indices); + let mut unfiltered_distances = self.unfiltered_distances; + unfiltered_distances.extend(other.unfiltered_distances); + + Ok(Self { + row_array: concat_array.as_struct().clone(), + probe_indices, + distances, + unfiltered_probe_indices, + unfiltered_distances, + }) + } +} + +/// Reorganize [UnprocessedKNNResultBatch] for easier processing. The main goal is to group KNN results by +/// probe row index. There is an iterator implementation [KNNProbeResultIterator] that yields +/// [KNNProbeResult] for each probe row in order. +struct KNNResultArray { + /// The KNN result batches produced by probing the spatial index with a probe batch + array: StructArray, + /// Distance for each KNN result row + distances: Vec, + /// Index for each probe row, this must be strictly increasing. + indices: Vec, + /// Offsets to split the batches into groups per probe row. It is always of length + /// `indices.len() + 1`. + offsets: Vec, + /// Indices for each unfiltered probe row, This is a superset of `indices`. + /// This must be strictly increasing. + unfiltered_indices: Vec, + /// Distances for each unfiltered KNN result row. This is a superset of `distances`. + unfiltered_distances: Vec, + /// Offsets to split the unfiltered distances into groups per probe row. It is always of length + /// `unfiltered_indices.len() + 1`. 
+ unfiltered_offsets: Vec, +} + +impl KNNResultArray { + fn new(unprocessed_batch: UnprocessedKNNResultBatch) -> Self { + let UnprocessedKNNResultBatch { + row_array, + probe_indices, + distances, + unfiltered_probe_indices, + unfiltered_distances, + .. + } = unprocessed_batch; + + assert_eq!(row_array.len(), probe_indices.len()); + assert_eq!(probe_indices.len(), distances.len()); + assert_eq!(unfiltered_probe_indices.len(), unfiltered_distances.len()); + assert!(probe_indices.len() <= unfiltered_probe_indices.len()); + + let compute_range_encoding = |mut indices: Vec| { + let mut offsets = Vec::with_capacity(indices.len() + 1); + offsets.push(0); + if indices.is_empty() { + return (offsets, Vec::new()); + } + + let mut prev = indices[0]; + let mut pos = 1; + for i in 1..indices.len() { + if indices[i] != prev { + assert!(indices[i] > prev, "indices must be non-decreasing"); + offsets.push(i); + indices[pos] = indices[i]; + pos += 1; + } + prev = indices[i]; + } + offsets.push(indices.len()); + indices.truncate(pos); + (offsets, indices) + }; + + let (offsets, indices) = compute_range_encoding(probe_indices); + let (unfiltered_offsets, unfiltered_indices) = + compute_range_encoding(unfiltered_probe_indices); + + // The iterator implementation relies on `indices` being an in-order subsequence + // of `unfiltered_indices`. + debug_assert!({ + let mut j = 0; + let mut ok = true; + for &g in &indices { + while j < unfiltered_indices.len() && unfiltered_indices[j] < g { + j += 1; + } + if j >= unfiltered_indices.len() || unfiltered_indices[j] != g { + ok = false; + break; + } + } + ok + }); + Self { + array: row_array, + distances, + indices, + offsets, + unfiltered_indices, + unfiltered_distances, + unfiltered_offsets, + } + } +} + +/// KNNProbeResult represents a unified view for the KNN results for a single probe row. +/// The KNN results can be from a spilled batch or an ingested batch. This intermediate +/// data structure is for working with both spilled and ingested KNN results uniformly. +/// +/// KNNProbeResult can also be used to represent KNN results for a probe row that has +/// no filtered results. In this case, the `row_range` will be an empty range, and the +/// `distances` will be an empty slice. +struct KNNProbeResult<'a> { + /// Index of the probe row + probe_row_index: usize, + /// Range of KNN result rows in the implicitly referenced StructArray. The referenced + /// StructArray only contains filtered results. + row_range: Range, + /// Distances for each KNN result row + distances: &'a [f64], + /// Distances for each unfiltered result row. Some of the results were filtered so they + /// do not appear in the StructArray, but we still need the distances of all unfiltered + /// results to correctly compute the top-K distances before the filtering. + unfiltered_distances: &'a [f64], +} + +impl<'a> KNNProbeResult<'a> { + fn new( + probe_row_index: usize, + row_range: Range, + distances: &'a [f64], + unfiltered_distances: &'a [f64], + ) -> Self { + assert_eq!(row_range.len(), distances.len()); + // Please note that we don't have `unfiltered_distances.len() >= distances.len()` here. + // We may have ties in `distances`, which may exceed K even after filtering. + // `unfiltered_distances` does not include distances that are tied with the K-th distance. 
+ Self { + probe_row_index, + row_range, + distances, + unfiltered_distances, + } + } +} + +/// Iterator over [KNNProbeResult] in a [KNNResultArray] +struct KNNProbeResultIterator<'a> { + array: &'a KNNResultArray, + unfiltered_pos: usize, + pos: usize, +} + +impl KNNProbeResultIterator<'_> { + fn new(array: &KNNResultArray) -> KNNProbeResultIterator<'_> { + KNNProbeResultIterator { + array, + unfiltered_pos: 0, + pos: 0, + } + } +} + +impl<'a> Iterator for KNNProbeResultIterator<'a> { + type Item = KNNProbeResult<'a>; + + /// This iterator yields KNNProbeResult for each probe row in the [KNNResultArray]. + /// Given that the [KNNResultArray::indices] is strictly increasing, + /// The [KNNProbeResult] it yields has strictly increasing [KNNProbeResult::probe_row_index]. + fn next(&mut self) -> Option { + if self.unfiltered_pos >= self.array.unfiltered_indices.len() { + return None; + } + + let unfiltered_start = self.array.unfiltered_offsets[self.unfiltered_pos]; + let unfiltered_end = self.array.unfiltered_offsets[self.unfiltered_pos + 1]; + let unfiltered_index = self.array.unfiltered_indices[self.unfiltered_pos]; + + let start = self.array.offsets[self.pos]; + let index = if self.pos >= self.array.indices.len() { + // All filtered results have been consumed. + usize::MAX + } else { + self.array.indices[self.pos] + }; + + assert!(index >= unfiltered_index); + + let result = if index == unfiltered_index { + // This probe row has filtered results + let end = self.array.offsets[self.pos + 1]; + let row_range = start..end; + let distances = &self.array.distances[start..end]; + let unfiltered_distances = + &self.array.unfiltered_distances[unfiltered_start..unfiltered_end]; + self.pos += 1; + KNNProbeResult::new(index, row_range, distances, unfiltered_distances) + } else { + // This probe row has no filtered results + KNNProbeResult::new( + unfiltered_index, + start..start, + &[], + &self.array.unfiltered_distances[unfiltered_start..unfiltered_end], + ) + }; + + self.unfiltered_pos += 1; + Some(result) + } +} + +/// Access arrays in a spilled KNN result batch. Provides easy access to KNN results of +/// probe rows as [KNNProbeResult]. +struct SpilledBatchArrays { + indices: SpilledBatchIndexArray, + distances: Float64Array, + unfiltered_distances: Float64Array, + offsets: OffsetBuffer, + unfiltered_offsets: OffsetBuffer, + rows: StructArray, +} + +impl SpilledBatchArrays { + fn new(batch: &RecordBatch) -> Self { + let index_col = batch + .column(0) + .as_primitive::(); + + let unfiltered_dist_list_array = batch.column(2).as_list::(); + let unfiltered_offset = unfiltered_dist_list_array.offsets(); + let unfiltered_distances = unfiltered_dist_list_array + .values() + .as_primitive::(); + + let row_and_dist_list_array = batch.column(1).as_list::(); + let offsets = row_and_dist_list_array.offsets(); + let row_and_dist_array = row_and_dist_list_array.values().as_struct(); + let dist_array = row_and_dist_array.column(1).as_primitive::(); + + let rows = row_and_dist_array.column(0).as_struct(); + + Self { + indices: SpilledBatchIndexArray::new(index_col.clone()), + distances: dist_array.clone(), + unfiltered_distances: unfiltered_distances.clone(), + offsets: offsets.clone(), + unfiltered_offsets: unfiltered_offset.clone(), + rows: rows.clone(), + } + } + + /// Get [KNNProbeResult] for the given probe row index inside the spilled batch. + /// The `row_idx` must be within the range of indices in this spilled batch. 
+ fn get_probe_result(&self, row_idx: usize) -> KNNProbeResult<'_> { + let indices = self.indices.array.values().as_ref(); + let unfiltered_offsets = self.unfiltered_offsets.as_ref(); + let unfiltered_start = unfiltered_offsets[row_idx] as usize; + let unfiltered_end = unfiltered_offsets[row_idx + 1] as usize; + let unfiltered_distances = self.unfiltered_distances.values().as_ref(); + let offsets = self.offsets.as_ref(); + let start = offsets[row_idx] as usize; + let end = offsets[row_idx + 1] as usize; + let distances = self.distances.values().as_ref(); + KNNProbeResult::new( + indices[row_idx] as usize, + start..end, + &distances[start..end], + &unfiltered_distances[unfiltered_start..unfiltered_end], + ) + } +} + +/// Index array with a cursor for keeping track of the progress of iterating over a +/// spilled batch. +struct SpilledBatchIndexArray { + array: UInt64Array, + pos: usize, +} + +struct AdvanceToResult { + skipped_range: Range, + found_target: HasFoundIndex, +} + +enum HasFoundIndex { + Found, + NotFound { should_load_next_batch: bool }, +} + +impl SpilledBatchIndexArray { + fn new(array: UInt64Array) -> Self { + // Values in the index array should be strictly increasing. + let values = array.values().as_ref(); + for i in 1..values.len() { + assert!(values[i] > values[i - 1]); + } + + Self { array, pos: 0 } + } + + /// Advance the cursor to target index. The `target` is expected to be monotonically increasing + /// across calls. We still tolerate the case where `target` is smaller than the current position, + /// in which case we simply return [HasFoundIndex::NotFound]. + /// + /// Please note that once a `target` is found, the cursor is advanced to the next position. + /// Advancing to the same `target` again will yield [HasFoundIndex::NotFound]. + fn advance_to(&mut self, target: usize) -> AdvanceToResult { + let values = self.array.values().as_ref(); + let begin_pos = self.pos; + + // Directly jump to the end if target is larger than the last value, and signal the + // caller that we should load the next batch. + if values.last().is_none_or(|last| (*last as usize) < target) { + self.pos = values.len(); + return AdvanceToResult { + skipped_range: begin_pos..self.pos, + found_target: HasFoundIndex::NotFound { + should_load_next_batch: true, + }, + }; + } + + // Iterate over the array from current position, until we hit or exceed target. + while self.pos < values.len() { + let value = values[self.pos] as usize; + if value <= target { + self.pos += 1; + if value == target { + return AdvanceToResult { + skipped_range: begin_pos..self.pos, + found_target: HasFoundIndex::Found, + }; + } + } else { + return AdvanceToResult { + skipped_range: begin_pos..self.pos, + found_target: HasFoundIndex::NotFound { + should_load_next_batch: false, + }, + }; + } + } + + // Reached the end without finding target. + AdvanceToResult { + skipped_range: begin_pos..self.pos, + found_target: HasFoundIndex::NotFound { + should_load_next_batch: false, + }, + } + } +} + +/// KNNResultsMerger handles the merging of KNN "nearest so far" results from multiple partitions. +/// It maintains spill files to store intermediate results. 
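+/// Typical call pattern (inferred from `ingest`, `produce_batch_until` and `rotate` below):
+/// probe results for one indexed partition are fed through `ingest`, any buffered tail is
+/// flushed with `produce_batch_until`, and `rotate` then turns the spill file just written
+/// into the "previous" input that the next partition's results are merged against. Passing
+/// `probing_last_index = true` to `rotate` skips creating a new writer, so merged results
+/// for the final partition are returned as output batches instead of being spilled.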
+pub struct KNNResultsMerger { + k: usize, + include_tie_breaker: bool, + /// Schema for the intermediate spill files + spill_schema: SchemaRef, + /// Runtime env + runtime_env: Arc, + /// Spill compression + spill_compression: SpillCompression, + /// Spill metrics + spill_metrics: SpillMetrics, + /// Internal state + state: MergerState, +} + +struct MergerState { + /// File containing results from previous (0..N-1) partitions + previous_file: Option, + /// Reader for previous file + previous_reader: Option, + /// Spill writer for current (0..N) partitions + current_writer: Option, + /// Spilled batches loaded from previous file + spilled_batches: Vec, + /// Builder for merged KNN result batches or spilled batches + batch_builder: KNNResultBatchBuilder, + /// Unprocessed tail KNN results from the last ingested batch + unprocessed_tail: Option, +} + +impl KNNResultsMerger { + pub fn try_new( + k: usize, + include_tie_breaker: bool, + target_batch_size: usize, + runtime_env: Arc, + spill_compression: SpillCompression, + result_schema: SchemaRef, + spill_metrics: SpillMetrics, + ) -> Result { + let spill_schema = create_spill_schema(Arc::clone(&result_schema)); + let batch_builder = + KNNResultBatchBuilder::new(Arc::clone(&result_schema), target_batch_size); + + let writer = RecordBatchSpillWriter::try_new( + runtime_env.clone(), + spill_schema.clone(), + "knn_spill", + spill_compression, + spill_metrics.clone(), + None, + )?; + Ok(Self { + k, + include_tie_breaker, + spill_schema, + runtime_env, + spill_compression, + spill_metrics, + state: MergerState { + previous_file: None, + previous_reader: None, + current_writer: Some(writer), + spilled_batches: Vec::new(), + batch_builder, + unprocessed_tail: None, + }, + }) + } + + pub fn rotate(&mut self, probing_last_index: bool) -> Result<()> { + self.state.previous_file = self + .state + .current_writer + .take() + .map(|w| w.finish()) + .transpose()?; + self.state.previous_reader = None; + assert!(self.state.unprocessed_tail.is_none()); + assert!(self.state.batch_builder.is_empty()); + self.state.spilled_batches.clear(); + + if let Some(file) = &self.state.previous_file { + self.state.previous_reader = Some(RecordBatchSpillReader::try_new(file)?); + } + + if !probing_last_index { + self.state.current_writer = Some(RecordBatchSpillWriter::try_new( + self.runtime_env.clone(), + self.spill_schema.clone(), + "knn_spill", + self.spill_compression, + self.spill_metrics.clone(), + None, + )?); + } + + Ok(()) + } + + pub fn ingest( + &mut self, + batch: RecordBatch, + probe_indices: Vec, + distances: Vec, + unfiltered_probe_indices: Vec, + unfiltered_distances: Vec, + ) -> Result> { + let row_array = StructArray::from(batch); + let ingested_batch = UnprocessedKNNResultBatch::new( + row_array, + probe_indices, + distances, + unfiltered_probe_indices, + unfiltered_distances, + ); + let unprocessed_batch = if let Some(tail) = self.state.unprocessed_tail.take() { + tail.merge(ingested_batch)? + } else { + ingested_batch + }; + + let knn_result_array = KNNResultArray::new(unprocessed_batch); + let knn_query_result_iterator = KNNProbeResultIterator::new(&knn_result_array); + + let mut prev_result_opt: Option> = None; + for result in knn_query_result_iterator { + // Only the previous result is guaranteed to be complete. + if let Some(result) = prev_result_opt { + self.merge_and_append_result(&result)?; + } + + prev_result_opt = Some(result); + } + + // Assembled this batch. Write to spill file or produce output batch. 
+ let result_batch_opt = self.flush_merged_batch(Some(&knn_result_array))?; + + // Prepare for ingesting the next batch + if let Some(unprocessed_result) = prev_result_opt { + self.state.unprocessed_tail = Some(UnprocessedKNNResultBatch::new_unprocessed_tail( + unprocessed_result, + &knn_result_array.array, + )); + } + + Ok(result_batch_opt) + } + + /// Flushes any pending buffered probe index at the end of a probe batch iterator. + /// + /// This is used to emit the final probe index that may have been kept buffered because + /// it could continue in the next produced slice. + /// + /// Returns `Ok(Some(batch))` at most once per pending buffered index; if there is nothing + /// pending (or results are being spilled to disk for non-final indexed partitions), returns + /// `Ok(None)`. + pub fn produce_batch_until( + &mut self, + end_index_exclusive: usize, + ) -> Result> { + // Consume and process any unprocessed tail from previous ingested batch + let tail_batch_opt = if let Some(tail) = self.state.unprocessed_tail.take() { + let knn_result_array = KNNResultArray::new(tail); + let knn_query_result_iterator = KNNProbeResultIterator::new(&knn_result_array); + for result in knn_query_result_iterator { + self.merge_and_append_result(&result)?; + } + self.flush_merged_batch(Some(&knn_result_array))? + } else { + None + }; + + // Load spilled batches up to end_index_exclusive, if there's any. + let spilled_batch_opt = if end_index_exclusive > 0 { + let end_target_idx = end_index_exclusive - 1; + // `end_target_idx` might have already been loaded before, but that's fine. The following operation + // will be a no-op in that case. + if let Some((batch_idx, row_idx)) = self.load_spilled_batches_up_to(end_target_idx)? { + let loaded_range = row_idx..(row_idx + 1); + self.append_spilled_results_in_range(batch_idx, &loaded_range); + } + self.flush_merged_batch(None)? + } else { + None + }; + + match (tail_batch_opt, spilled_batch_opt) { + (Some(batch), None) | (None, Some(batch)) => Ok(Some(batch)), + (None, None) => Ok(None), + (Some(tail_batch), Some(spilled_batch)) => { + let result_batch = + concat_batches(tail_batch.schema_ref(), [&tail_batch, &spilled_batch]) + .map_err(|e| arrow_datafusion_err!(e))?; + Ok(Some(result_batch)) + } + } + } + + fn merge_and_append_result(&mut self, result: &KNNProbeResult<'_>) -> Result<()> { + if let Some((spilled_batch_idx, row_idx)) = + self.load_spilled_batches_up_to(result.probe_row_index)? + { + let spilled_batch_array = &self.state.spilled_batches[spilled_batch_idx]; + let spilled_result = spilled_batch_array.get_probe_result(row_idx); + self.state.batch_builder.merge_and_append( + &spilled_result, + spilled_batch_idx, + result, + self.k, + self.include_tie_breaker, + ); + } else { + // No spilled results for this index + self.state + .batch_builder + .append(result, RowSelector::FromIngested { row_idx: 0 }); + } + Ok(()) + } + + /// Load spilled batches until we find the target index, or exhaust all spilled batches. + /// Returns the (batch_idx, row_idx) of the found target index within the spilled batches, + /// or None if the target index is not found in any spilled batch. 
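+    /// Spilled probe rows that are skipped over while advancing towards the target have no
+    /// new results to merge with, so they are forwarded unchanged to the batch builder via
+    /// `append_spilled_results_in_range`.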
+ fn load_spilled_batches_up_to(&mut self, target_idx: usize) -> Result> { + loop { + if !self.state.spilled_batches.is_empty() { + let batch_idx = self.state.spilled_batches.len() - 1; + let spilled_batch = &mut self.state.spilled_batches[batch_idx]; + + let res = spilled_batch.indices.advance_to(target_idx); + + match res.found_target { + HasFoundIndex::Found => { + // Found within current batch + let row_idx = res.skipped_range.end - 1; + self.append_spilled_results_in_range( + batch_idx, + &(res.skipped_range.start..row_idx), + ); + return Ok(Some((batch_idx, row_idx))); + } + HasFoundIndex::NotFound { + should_load_next_batch, + } => { + self.append_spilled_results_in_range(batch_idx, &res.skipped_range); + if !should_load_next_batch { + // Not found, but no need to load the next batch + return Ok(None); + } + } + } + } + + // Load next batch + let Some(prev_reader) = self.state.previous_reader.as_mut() else { + return Ok(None); + }; + let Some(batch) = prev_reader.next_batch() else { + return Ok(None); + }; + let batch = batch?; + self.state + .spilled_batches + .push(SpilledBatchArrays::new(&batch)); + } + } + + fn append_spilled_results_in_range(&mut self, batch_idx: usize, row_range: &Range) { + let spilled_batch_array = &self.state.spilled_batches[batch_idx]; + for row_idx in row_range.clone() { + let spilled_result = spilled_batch_array.get_probe_result(row_idx); + self.state.batch_builder.append( + &spilled_result, + RowSelector::FromSpilled { + batch_idx, + row_idx: 0, + }, + ); + } + } + + fn flush_merged_batch( + &mut self, + knn_result_array: Option<&KNNResultArray>, + ) -> Result> { + let spilled_batches = self + .state + .spilled_batches + .iter() + .map(|b| &b.rows) + .collect::>(); + let ingested_array = knn_result_array.map(|a| &a.array); + let batch_opt = match &mut self.state.current_writer { + Some(writer) => { + // Write to spill file + if let Some(spilled_batch) = self + .state + .batch_builder + .build_spilled_batch(ingested_array, &spilled_batches)? + { + writer.write_batch(spilled_batch)?; + } + None + } + None => { + // Produce output batch + self.state + .batch_builder + .build_result_batch(ingested_array, &spilled_batches)? + } + }; + + // Keep only the last spilled batch, since we don't need earlier ones anymore. + let num_batches = self.state.spilled_batches.len(); + if num_batches > 1 { + self.state.spilled_batches.drain(0..num_batches - 1); + } + + Ok(batch_opt) + } +} + +/// Builders for KNN merged result batches or spilled batches. +struct KNNResultBatchBuilder { + spill_schema: SchemaRef, + rows_inner_fields: Fields, + capacity: usize, + unfiltered_dist_array_builder: PrimitiveBuilder, + unfiltered_dist_offsets_builder: OffsetBufferBuilder, + index_array_builder: PrimitiveBuilder, + dist_array_builder: PrimitiveBuilder, + row_array_offsets_builder: OffsetBufferBuilder, + rows_selector: Vec, + /// Scratch space for merging top-k distances + top_k_distances: Vec, + /// Scratch space for sorting row selectors by distance when merging KNN results + row_selector_with_distance: Vec<(RowSelector, f64)>, +} + +/// The source of a merged row in the final KNN result. It can be from either a spilled batch +/// or an ingested batch. +#[derive(Copy, Clone)] +enum RowSelector { + FromSpilled { batch_idx: usize, row_idx: usize }, + FromIngested { row_idx: usize }, +} + +impl RowSelector { + fn with_row_idx(&self, row_idx: usize) -> Self { + match self { + RowSelector::FromSpilled { batch_idx, .. 
} => RowSelector::FromSpilled { + batch_idx: *batch_idx, + row_idx, + }, + RowSelector::FromIngested { .. } => RowSelector::FromIngested { row_idx }, + } + } +} + +impl KNNResultBatchBuilder { + fn new(result_schema: SchemaRef, capacity: usize) -> Self { + let spill_schema = create_spill_schema(Arc::clone(&result_schema)); + let rows_inner_fields = create_rows_inner_fields(&result_schema); + let unfiltered_dist_array_builder = Float64Array::builder(capacity); + let unfiltered_dist_offsets_builder = OffsetBufferBuilder::::new(capacity); + let index_array_builder = UInt64Array::builder(capacity); + let dist_array_builder = Float64Array::builder(capacity); + let row_array_offsets_builder = OffsetBufferBuilder::::new(capacity); + + Self { + spill_schema, + rows_inner_fields, + capacity, + unfiltered_dist_array_builder, + unfiltered_dist_offsets_builder, + index_array_builder, + dist_array_builder, + row_array_offsets_builder, + rows_selector: Vec::with_capacity(capacity), + top_k_distances: Vec::new(), + row_selector_with_distance: Vec::new(), + } + } + + fn is_empty(&self) -> bool { + self.index_array_builder.is_empty() + } + + fn append(&mut self, results: &KNNProbeResult<'_>, row_selector_template: RowSelector) { + for (row_idx, dist) in results.row_range.clone().zip(results.distances.iter()) { + self.rows_selector + .push(row_selector_template.with_row_idx(row_idx)); + self.dist_array_builder.append_value(*dist); + } + + self.row_array_offsets_builder + .push_length(results.row_range.len()); + self.unfiltered_dist_array_builder + .append_slice(results.unfiltered_distances); + self.unfiltered_dist_offsets_builder + .push_length(results.unfiltered_distances.len()); + self.index_array_builder + .append_value(results.probe_row_index as u64); + } + + fn merge_and_append( + &mut self, + spilled_results: &KNNProbeResult<'_>, + spilled_batch_idx: usize, + ingested_results: &KNNProbeResult<'_>, + k: usize, + include_tie_breaker: bool, + ) { + assert_eq!( + spilled_results.probe_row_index, + ingested_results.probe_row_index + ); + + merge_unfiltered_topk( + k, + spilled_results.unfiltered_distances, + ingested_results.unfiltered_distances, + &mut self.top_k_distances, + ); + + let num_kept_rows = self.append_merged_knn_probe_results( + spilled_batch_idx, + spilled_results, + ingested_results, + k, + include_tie_breaker, + ); + + self.row_array_offsets_builder.push_length(num_kept_rows); + self.unfiltered_dist_array_builder + .append_slice(&self.top_k_distances); + self.unfiltered_dist_offsets_builder + .push_length(self.top_k_distances.len()); + self.index_array_builder + .append_value(spilled_results.probe_row_index as u64); + } + + /// Append top K row selectors and distances from `spillled_results` and `ingested_results` that are + /// within `distance_threshold`. + /// Returns the number of values inserted into the [`KNNResultArrayBuilders::rows_selector`] and + /// [`KNNResultArrayBuilders::dist_array_builder`]. 
+ fn append_merged_knn_probe_results( + &mut self, + spilled_batch_idx: usize, + spilled_results: &KNNProbeResult<'_>, + ingested_results: &KNNProbeResult<'_>, + k: usize, + include_tie_breaker: bool, + ) -> usize { + // Sort all distances from both spilled and ingested results + let row_dists = &mut self.row_selector_with_distance; + row_dists.clear(); + row_dists.reserve(spilled_results.distances.len() + ingested_results.distances.len()); + + for (row_idx, dist) in spilled_results + .row_range + .clone() + .zip(spilled_results.distances.iter()) + { + row_dists.push(( + RowSelector::FromSpilled { + batch_idx: spilled_batch_idx, + row_idx, + }, + *dist, + )); + } + for (row_idx, dist) in ingested_results + .row_range + .clone() + .zip(ingested_results.distances.iter()) + { + row_dists.push((RowSelector::FromIngested { row_idx }, *dist)); + } + + truncate_row_selectors_to_top_k(row_dists, &self.top_k_distances, k, include_tie_breaker); + for (row_selector, dist) in row_dists.iter() { + self.rows_selector.push(*row_selector); + self.dist_array_builder.append_value(*dist); + } + row_dists.len() + } + + fn build_spilled_batch( + &mut self, + ingested_results: Option<&StructArray>, + spilled_results: &[&StructArray], + ) -> Result> { + if self.index_array_builder.is_empty() { + return Ok(None); + } + + // index column: UInt64 + let index_array = Arc::new(self.index_array_builder.finish()); + + // rows column: List, dist: Float64>> + let rows_array = interleave_spill_and_ingested_rows( + ingested_results, + spilled_results, + &self.rows_selector, + )?; + self.rows_selector.clear(); + let dist_array = Arc::new(self.dist_array_builder.finish()); + let row_array_offsets_builder = std::mem::replace( + &mut self.row_array_offsets_builder, + OffsetBufferBuilder::::new(self.capacity), + ); + let row_offsets = row_array_offsets_builder.finish(); + let row_dist_array = StructArray::try_new( + self.rows_inner_fields.clone(), + vec![rows_array, dist_array], + None, + )?; + let row_dist_item_field = Arc::new(Field::new( + "item", + DataType::Struct(self.rows_inner_fields.clone()), + false, + )); + let rows_list_array = ListArray::try_new( + row_dist_item_field, + row_offsets, + Arc::new(row_dist_array), + None, + )?; + + // unfiltered_dists column: List + let unfiltered_dist_array = Arc::new(self.unfiltered_dist_array_builder.finish()); + let unfiltered_dist_offsets_builder = std::mem::replace( + &mut self.unfiltered_dist_offsets_builder, + OffsetBufferBuilder::::new(self.capacity), + ); + let unfiltered_offsets = unfiltered_dist_offsets_builder.finish(); + let unfiltered_field = Arc::new(Field::new("item", DataType::Float64, false)); + let unfiltered_list_array = ListArray::try_new( + unfiltered_field, + unfiltered_offsets, + unfiltered_dist_array, + None, + )?; + + Ok(Some(RecordBatch::try_new( + self.spill_schema.clone(), + vec![ + index_array, + Arc::new(rows_list_array), + Arc::new(unfiltered_list_array), + ], + )?)) + } + + fn build_result_batch( + &mut self, + ingested_results: Option<&StructArray>, + spilled_results: &[&StructArray], + ) -> Result> { + if self.index_array_builder.is_empty() { + return Ok(None); + } + + // Reset builders for building columns required by spilled batches. Building these columns seems to be wasted work + // when we only need to produce result batches, but it simplifies the code significantly and the performance impact is minimal. 
+ let _ = std::mem::replace( + &mut self.index_array_builder, + UInt64Array::builder(self.capacity), + ); + let _ = std::mem::replace( + &mut self.dist_array_builder, + Float64Array::builder(self.capacity), + ); + let _ = std::mem::replace( + &mut self.row_array_offsets_builder, + OffsetBufferBuilder::::new(self.capacity), + ); + let _ = std::mem::replace( + &mut self.unfiltered_dist_array_builder, + Float64Array::builder(self.capacity), + ); + let _ = std::mem::replace( + &mut self.unfiltered_dist_offsets_builder, + OffsetBufferBuilder::::new(self.capacity), + ); + + // Build rows StructArray based on rows_selector + if self.rows_selector.is_empty() { + return Ok(None); + } + let rows_array = interleave_spill_and_ingested_rows( + ingested_results, + spilled_results, + &self.rows_selector, + )?; + self.rows_selector.clear(); + + let struct_array = rows_array.as_struct(); + + Ok(Some(RecordBatch::from(struct_array.clone()))) + } +} + +/// Create schema for spilled intermediate KNN results. The schema includes: +/// - index: UInt64 +/// - rows: List, dist: Float64>> +/// - unfiltered_dists: List (top-K unfiltered distances so far) +fn create_spill_schema(result_schema: SchemaRef) -> SchemaRef { + let index_field = Field::new("index", DataType::UInt64, false); + let rows_inner_fields = create_rows_inner_fields(&result_schema); + let row_dist_item_field = Field::new("item", DataType::Struct(rows_inner_fields), false); + let rows_field = Field::new("rows", DataType::List(Arc::new(row_dist_item_field)), false); + let unfiltered_dists_field = Field::new( + "unfiltered_dists", + DataType::List(Arc::new(Field::new("item", DataType::Float64, false))), + false, + ); + Arc::new(Schema::new(vec![ + index_field, + rows_field, + unfiltered_dists_field, + ])) +} + +fn create_rows_inner_fields(result_schema: &Schema) -> Fields { + let row_field = Field::new( + "row", + DataType::Struct(result_schema.fields().clone()), + false, + ); + let dist_field = Field::new("dist", DataType::Float64, false); + vec![row_field, dist_field].into() +} + +fn interleave_spill_and_ingested_rows( + ingested_results: Option<&StructArray>, + spilled_results: &[&StructArray], + rows_selector: &[RowSelector], +) -> Result { + // Build rows StructArray based on rows_selector + let ingested_array_index = spilled_results.len(); + let mut indices = Vec::with_capacity(rows_selector.len()); + for selector in rows_selector { + match selector { + RowSelector::FromSpilled { batch_idx, row_idx } => { + indices.push((*batch_idx, *row_idx)); + } + RowSelector::FromIngested { row_idx } => { + if ingested_results.is_none() { + return sedona_internal_err!( + "Ingested results array is None when trying to access ingested rows" + ); + } + indices.push((ingested_array_index, *row_idx)); + } + } + } + + let mut results_arrays: Vec<&dyn Array> = Vec::with_capacity(ingested_array_index + 1); + for spilled_array in spilled_results { + results_arrays.push(spilled_array); + } + if let Some(ingested_results) = ingested_results { + results_arrays.push(ingested_results); + } + let rows_array = interleave(&results_arrays, &indices).map_err(|e| arrow_datafusion_err!(e))?; + Ok(rows_array) +} + +fn merge_unfiltered_topk(k: usize, prev: &[f64], new: &[f64], top_k: &mut Vec) { + top_k.clear(); + if k == 0 { + return; + } + top_k.reserve(prev.len() + new.len()); + top_k.extend_from_slice(prev); + top_k.extend_from_slice(new); + + // Keep only the K smallest distances, sorted. 
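+ // e.g. k = 3, prev = [1.0, 3.0, 5.0], new = [2.0, 4.0, 6.0] => top_k = [1.0, 2.0, 3.0] (see `test_merge_unfiltered_topk` below).
+ // `select_nth_unstable_by` places the (k-1)-th smallest element in position in O(n) on average, so only the k kept values need the final sort.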
+ if top_k.len() > k { + let kth = k - 1; + top_k.select_nth_unstable_by(kth, |a, b| a.total_cmp(b)); + top_k.truncate(k); + } + top_k.sort_by(|a, b| a.total_cmp(b)); +} + +fn truncate_row_selectors_to_top_k( + row_dist_vec: &mut Vec<(RowSelector, f64)>, + top_k_distances: &[f64], + k: usize, + include_tie_breaker: bool, +) { + let Some(kth_distance) = top_k_distances.last() else { + row_dist_vec.clear(); + return; + }; + + let distance_threshold = if include_tie_breaker { + // The distance threshold is slightly looser when including tie breakers, please + // refer to `SpatialIndex::query_knn` for more details. + *kth_distance + DISTANCE_TOLERANCE + } else { + *kth_distance + }; + + row_dist_vec.sort_unstable_by(|(_, l_dist), (_, r_dist)| l_dist.total_cmp(r_dist)); + + // Keep only the row selectors within distance_threshold + let mut kept_rows = 0; + for (_, dist) in row_dist_vec.iter() { + if kept_rows >= k && !include_tie_breaker { + break; + } + if *dist <= distance_threshold { + kept_rows += 1; + } else { + break; + } + } + + row_dist_vec.truncate(kept_rows); + + // If the last distance D in top_k_distances has N ties, and include_tie_breaker is false, we + // need to make sure that the kept rows with distance D should not exceed N, otherwise we'll + // incorrectly have extra rows kept. + // To fix this, we need to count how many rows have distance equal to the last distance, + // and make sure we only keep that many rows among the kept rows with that distance. + if !include_tie_breaker { + let last_distance = *kth_distance; + let num_ties_in_topk = top_k_distances + .iter() + .rev() + .take_while(|d| **d == last_distance) + .count(); + + let num_ties_in_kept = row_dist_vec + .iter() + .rev() + .take_while(|(_, d)| *d == last_distance) + .count(); + + if num_ties_in_kept > num_ties_in_topk { + let to_remove = num_ties_in_kept - num_ties_in_topk; + let new_len = row_dist_vec.len() - to_remove; + row_dist_vec.truncate(new_len); + } + } +} + +#[cfg(test)] +mod test { + use arrow::compute::take_record_batch; + use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; + use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + use rstest::rstest; + + use super::*; + + #[test] + fn test_knn_results_array_iterator() { + // KNNResultArray with 4 probe rows: P1000, P1001, P1002, P1004. + // P1002 has no filtered results. 
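+ // The five constructor arguments below are: result rows, filtered probe indices, filtered distances, unfiltered probe indices, unfiltered distances — the same ordering the fuzz tests use for `KNNResultsMerger::ingest`.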
+ let array = KNNResultArray::new(UnprocessedKNNResultBatch::new( + StructArray::new_empty_fields(7, None), + vec![1000, 1000, 1001, 1001, 1001, 1004, 1004], + vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + vec![ + 1000, 1000, 1000, 1001, 1001, 1001, 1002, 1002, 1002, 1004, 1004, 1004, + ], + vec![1.0, 2.0, 3.0, 3.0, 4.0, 5.0, 7.0, 8.0, 9.0, 6.0, 7.0, 8.0], + )); + + let mut iter = KNNProbeResultIterator::new(&array); + + let res0 = iter.next().unwrap(); + assert_eq!(res0.probe_row_index, 1000); + assert_eq!(res0.row_range, 0..2); + assert_eq!(res0.distances, &[1.0, 2.0]); + assert_eq!(res0.unfiltered_distances, &[1.0, 2.0, 3.0]); + + let res1 = iter.next().unwrap(); + assert_eq!(res1.probe_row_index, 1001); + assert_eq!(res1.row_range, 2..5); + assert_eq!(res1.distances, &[3.0, 4.0, 5.0]); + assert_eq!(res1.unfiltered_distances, &[3.0, 4.0, 5.0]); + + let res2 = iter.next().unwrap(); + assert_eq!(res2.probe_row_index, 1002); + assert_eq!(res2.row_range, 5..5); + assert!(res2.distances.is_empty()); + assert_eq!(res2.unfiltered_distances, &[7.0, 8.0, 9.0]); + + let res3 = iter.next().unwrap(); + assert_eq!(res3.probe_row_index, 1004); + assert_eq!(res3.row_range, 5..7); + assert_eq!(res3.distances, &[6.0, 7.0]); + assert_eq!(res3.unfiltered_distances, &[6.0, 7.0, 8.0]); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_knn_results_array_iterator_empty() { + let array = KNNResultArray::new(UnprocessedKNNResultBatch::new( + StructArray::new_empty_fields(0, None), + vec![], + vec![], + vec![], + vec![], + )); + + let mut iter = KNNProbeResultIterator::new(&array); + assert!(iter.next().is_none()); + } + + #[test] + fn test_knn_results_array_iterator_no_filtered() { + let array = KNNResultArray::new(UnprocessedKNNResultBatch::new( + StructArray::new_empty_fields(0, None), + vec![], + vec![], + vec![0, 0, 0, 3, 3], + vec![1.0, 2.0, 3.0, 4.0, 5.0], + )); + + let mut iter = KNNProbeResultIterator::new(&array); + + let res0 = iter.next().unwrap(); + assert_eq!(res0.probe_row_index, 0); + assert_eq!(res0.row_range, 0..0); + assert!(res0.distances.is_empty()); + assert_eq!(res0.unfiltered_distances, &[1.0, 2.0, 3.0]); + + let res1 = iter.next().unwrap(); + assert_eq!(res1.probe_row_index, 3); + assert_eq!(res1.row_range, 0..0); + assert!(res1.distances.is_empty()); + assert_eq!(res1.unfiltered_distances, &[4.0, 5.0]); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_knn_results_array_iterator_all_kept() { + let array = KNNResultArray::new(UnprocessedKNNResultBatch::new( + StructArray::new_empty_fields(5, None), + vec![0, 0, 0, 3, 3], + vec![1.0, 2.0, 3.0, 4.0, 5.0], + vec![0, 0, 0, 3, 3], + vec![1.0, 2.0, 3.0, 4.0, 5.0], + )); + + let mut iter = KNNProbeResultIterator::new(&array); + let res0 = iter.next().unwrap(); + assert_eq!(res0.probe_row_index, 0); + assert_eq!(res0.row_range, 0..3); + assert_eq!(res0.distances, &[1.0, 2.0, 3.0]); + assert_eq!(res0.unfiltered_distances, &[1.0, 2.0, 3.0]); + + let res1 = iter.next().unwrap(); + assert_eq!(res1.probe_row_index, 3); + assert_eq!(res1.row_range, 3..5); + assert_eq!(res1.distances, &[4.0, 5.0]); + assert_eq!(res1.unfiltered_distances, &[4.0, 5.0]); + assert!(iter.next().is_none()); + } + + #[test] + fn test_knn_results_array_iterator_no_dup() { + let indices = vec![0, 1, 3, 4, 6]; + let array = KNNResultArray::new(UnprocessedKNNResultBatch::new( + StructArray::new_empty_fields(5, None), + indices.clone(), + vec![0.0, 1.0, 3.0, 4.0, 6.0], + vec![0, 1, 2, 3, 4, 5, 6], + vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + )); + + let mut 
iter = KNNProbeResultIterator::new(&array); + for k in 0..7 { + let res0 = iter.next().unwrap(); + assert_eq!(res0.probe_row_index, k); + assert_eq!(res0.unfiltered_distances, &[k as f64]); + + if let Ok(pos) = indices.binary_search(&k) { + assert_eq!(res0.row_range, pos..(pos + 1)); + assert_eq!(res0.distances, &[k as f64]); + } else { + assert!(res0.row_range.is_empty()); + assert!(res0.distances.is_empty()); + } + } + + assert!(iter.next().is_none()); + } + + #[test] + fn test_spill_index_array_advance_to() { + let mut arr = SpilledBatchIndexArray::new(UInt64Array::from(vec![1, 2, 3, 6, 8, 10])); + + let res = arr.advance_to(0); + assert_eq!(res.skipped_range, 0..0); + assert!(matches!( + res.found_target, + HasFoundIndex::NotFound { + should_load_next_batch: false + } + )); + + let res = arr.advance_to(1); + assert_eq!(res.skipped_range, 0..1); + assert!(matches!(res.found_target, HasFoundIndex::Found)); + + // Repeatedly advance to the same target won't move the cursor, and will return NotFound. + let res = arr.advance_to(1); + assert_eq!(res.skipped_range, 1..1); + assert!(matches!( + res.found_target, + HasFoundIndex::NotFound { + should_load_next_batch: false + } + )); + + let res = arr.advance_to(2); + assert_eq!(res.skipped_range, 1..2); + assert!(matches!(res.found_target, HasFoundIndex::Found)); + + // Advance to a missing target within the array, indexes less than the target are skipped. + // The cursor stops at the first index greater than the target. + let res = arr.advance_to(4); + assert_eq!(res.skipped_range, 2..3); + assert!(matches!( + res.found_target, + HasFoundIndex::NotFound { + should_load_next_batch: false + } + )); + + let res = arr.advance_to(6); + assert_eq!(res.skipped_range, 3..4); + assert!(matches!(res.found_target, HasFoundIndex::Found)); + + let res = arr.advance_to(10); + assert_eq!(res.skipped_range, 4..6); + assert!(matches!(res.found_target, HasFoundIndex::Found)); + + // Advance to a target larger than the last index, the cursor moves to the end, + // and signals to load the next batch. + let res = arr.advance_to(11); + assert_eq!(res.skipped_range, 6..6); + assert!(matches!( + res.found_target, + HasFoundIndex::NotFound { + should_load_next_batch: true + } + )); + } + + #[test] + fn test_spill_index_array_advance_to_skip_all() { + let mut arr = SpilledBatchIndexArray::new(UInt64Array::from(vec![1, 2, 3, 6, 8, 10])); + + let res = arr.advance_to(100); + assert_eq!(res.skipped_range, 0..6); + assert!(matches!( + res.found_target, + HasFoundIndex::NotFound { + should_load_next_batch: true + } + )); + } + + #[test] + fn test_spill_index_array_advance_to_end() { + let mut arr = SpilledBatchIndexArray::new(UInt64Array::from(vec![1, 2, 3, 6, 8, 10])); + + let res = arr.advance_to(3); + assert_eq!(res.skipped_range, 0..3); + assert!(matches!(res.found_target, HasFoundIndex::Found)); + + // Advance to the end by specifying usize::MAX as target. 
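+ // Everything past the cursor (values 6, 8 and 10 at positions 3..6) is skipped and the caller is told to load the next spilled batch.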
+ let res = arr.advance_to(usize::MAX); + assert_eq!(res.skipped_range, 3..6); + assert!(matches!( + res.found_target, + HasFoundIndex::NotFound { + should_load_next_batch: true + } + )); + } + + #[test] + fn test_spill_index_array_advance_empty() { + let mut arr = SpilledBatchIndexArray::new(UInt64Array::from(Vec::::new())); + + let res = arr.advance_to(0); + assert_eq!(res.skipped_range, 0..0); + assert!(matches!( + res.found_target, + HasFoundIndex::NotFound { + should_load_next_batch: true + } + )); + } + + #[test] + fn test_merge_unfiltered_topk() { + let mut top_k = Vec::new(); + + // Normal cases + merge_unfiltered_topk(3, &[1.0, 3.0, 5.0], &[2.0, 4.0, 6.0], &mut top_k); + assert_eq!(top_k, vec![1.0, 2.0, 3.0]); + merge_unfiltered_topk(3, &[5.0, 3.0, 1.0], &[2.0, 6.0, 4.0], &mut top_k); + assert_eq!(top_k, vec![1.0, 2.0, 3.0]); + merge_unfiltered_topk(5, &[1.0, 3.0], &[2.0, 4.0, 5.0, 6.0], &mut top_k); + assert_eq!(top_k, vec![1.0, 2.0, 3.0, 4.0, 5.0]); + merge_unfiltered_topk(5, &[5.0, 3.0, 1.0], &[2.0, 6.0, 4.0], &mut top_k); + assert_eq!(top_k, vec![1.0, 2.0, 3.0, 4.0, 5.0]); + + // k equals total number of distances + merge_unfiltered_topk(6, &[5.0, 3.0, 1.0], &[2.0, 6.0, 4.0], &mut top_k); + assert_eq!(top_k, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + // k larger than total number of distances + merge_unfiltered_topk(10, &[5.0, 3.0, 1.0], &[2.0, 6.0, 4.0], &mut top_k); + assert_eq!(top_k, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + // k is zero (usually this does not happen in practice) + merge_unfiltered_topk(0, &[1.0, 3.0], &[2.0, 4.0], &mut top_k); + assert_eq!(top_k, Vec::::new()); + merge_unfiltered_topk(0, &[], &[], &mut top_k); + assert_eq!(top_k, Vec::::new()); + + // one side is empty + merge_unfiltered_topk(2, &[], &[2.0, 1.0], &mut top_k); + assert_eq!(top_k, vec![1.0, 2.0]); + merge_unfiltered_topk(2, &[2.0, 1.0], &[], &mut top_k); + assert_eq!(top_k, vec![1.0, 2.0]); + } + + fn create_dummy_row_selectors(dists: &[f64]) -> Vec<(RowSelector, f64)> { + dists + .iter() + .enumerate() + .map(|(i, d)| (RowSelector::FromIngested { row_idx: i }, *d)) + .collect() + } + + fn count_dist(v: &[(RowSelector, f64)], target: f64) -> usize { + v.iter().filter(|(_, d)| *d == target).count() + } + + #[test] + fn test_truncate_row_selectors_for_empty_unfiltered_top_k() { + let mut row_dist_vec = create_dummy_row_selectors(&[1.0, 2.0, 3.0]); + truncate_row_selectors_to_top_k(&mut row_dist_vec, &[], 3, false); + assert!(row_dist_vec.is_empty()); + } + + #[test] + fn test_truncate_row_selectors_no_dup() { + // Keep at most K rows within distance threshold. + let k = 3; + let top_k_distances = vec![1.0, 2.0, 3.0]; + let mut row_dist_vec = create_dummy_row_selectors(&[3.0, 2.0, 4.0, 1.0, 5.0]); + + truncate_row_selectors_to_top_k(&mut row_dist_vec, &top_k_distances, k, false); + + assert_eq!(row_dist_vec.len(), 3); + assert!(row_dist_vec.iter().all(|(_, d)| *d <= 3.0)); + assert_eq!(count_dist(&row_dist_vec, 3.0), 1); + } + + #[test] + fn test_truncate_row_selectors_handle_last_ties() { + // top_k_distances has last distance 4.0 with only 1 tie. + // Filtered results can contain more 4.0 rows than allowed; we must trim them. 
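+ // Expected outcome: 1.0, 2.0 and a single 4.0 survive; the two extra 4.0 ties and 10.0 are trimmed.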
+ let k = 5; + let top_k_distances = vec![1.0, 2.0, 3.0, 3.0, 4.0]; + let mut row_dist_vec = create_dummy_row_selectors(&[4.0, 1.0, 4.0, 2.0, 4.0, 10.0]); + + truncate_row_selectors_to_top_k(&mut row_dist_vec, &top_k_distances, k, false); + + assert!(row_dist_vec.iter().all(|(_, d)| *d <= 4.0)); + assert!(row_dist_vec.len() <= k); + assert_eq!(count_dist(&row_dist_vec, 4.0), 1); + assert_eq!(count_dist(&row_dist_vec, 1.0), 1); + assert_eq!(count_dist(&row_dist_vec, 2.0), 1); + + // top_k_distances has last distance 4.0 with 2 ties. + // If we keep more than 2 rows with 4.0, we must discard some from the tail. + let k = 4; + let top_k_distances = vec![1.0, 2.0, 4.0, 4.0]; + let mut row_dist_vec = create_dummy_row_selectors(&[4.0, 4.0, 4.0, 1.0]); + + truncate_row_selectors_to_top_k(&mut row_dist_vec, &top_k_distances, k, false); + + assert!(row_dist_vec.iter().all(|(_, d)| *d <= 4.0)); + assert!(row_dist_vec.len() <= k); + assert_eq!(count_dist(&row_dist_vec, 4.0), 2); + assert_eq!(count_dist(&row_dist_vec, 1.0), 1); + + // Keep fewer ties than in top_k_distances should not trigger any trimming. + let k = 5; + let top_k_distances = vec![1.0, 2.0, 3.0, 5.0, 5.0]; + let mut row_dist_vec = create_dummy_row_selectors(&[5.0, 1.0]); + + truncate_row_selectors_to_top_k(&mut row_dist_vec, &top_k_distances, k, false); + + assert_eq!(row_dist_vec.len(), 2); + assert_eq!(count_dist(&row_dist_vec, 5.0), 1); + assert_eq!(count_dist(&row_dist_vec, 1.0), 1); + } + + #[test] + fn test_truncate_row_selectors_include_tie_breakers() { + let k = 3; + let top_k_distances = vec![1.0, 2.0, 3.0]; + let tol_half = DISTANCE_TOLERANCE / 2.0; + + let mut row_dist_vec = + create_dummy_row_selectors(&[3.0, 1.0, 3.0, 2.0, 3.0 + tol_half, 4.0]); + truncate_row_selectors_to_top_k(&mut row_dist_vec, &top_k_distances, k, true); + + // Should keep all <= 3.0 + DISTANCE_TOLERANCE (i.e. not limited by k). 
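+ // Expected survivors: 1.0, 2.0, both 3.0 rows and 3.0 + tol_half — five rows, more than k — while 4.0 falls outside the tolerance and is dropped.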
+ assert!(row_dist_vec.len() > k); + assert!(row_dist_vec + .iter() + .all(|(_, d)| *d <= 3.0 + DISTANCE_TOLERANCE)); + assert_eq!(count_dist(&row_dist_vec, 1.0), 1); + assert_eq!(count_dist(&row_dist_vec, 2.0), 1); + assert_eq!(count_dist(&row_dist_vec, 3.0), 2); + assert_eq!(count_dist(&row_dist_vec, 3.0 + tol_half), 1); + } + + #[derive(Clone, PartialEq, Debug)] + struct FuzzTestKNNResult { + query_id: usize, + knn_objects: Vec, + } + + #[derive(Clone, PartialEq, Debug)] + struct FuzzKNNResultObject { + object_id: usize, + distance: f64, + is_kept: bool, + } + + fn create_fuzz_test_data_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("query_id", DataType::UInt64, false), + Field::new("object_id", DataType::UInt64, false), + ])) + } + + fn create_fuzz_test_data( + k: usize, + num_rows: usize, + kept_prob: f64, + rng: &mut StdRng, + ) -> Vec { + let mut test_data = Vec::with_capacity(num_rows); + let mut next_object_id = 0; + for query_id in 0..num_rows { + // Generate K objects + let knn_objects = (next_object_id..next_object_id + k) + .map(|object_id| FuzzKNNResultObject { + object_id, + distance: rng.random_range(1.0..10.0), + is_kept: rng.random_bool(kept_prob), + }) + .collect::>(); + next_object_id += k; + + test_data.push(FuzzTestKNNResult { + query_id, + knn_objects, + }); + } + test_data + } + + fn partition_fuzz_test_data( + test_data: &[FuzzTestKNNResult], + num_partitions: usize, + kept_prob: f64, + rng: &mut StdRng, + ) -> Vec> { + let mut partitions: Vec> = vec![Vec::new(); num_partitions]; + let mut next_object_id = test_data + .iter() + .flat_map(|r| r.knn_objects.iter()) + .map(|o| o.object_id) + .max() + .unwrap_or(0) + + 1; + for result in test_data.iter() { + // Split the knn_objects into partitions, randomly mix in some objects with large distances + let distance_threshold = result + .knn_objects + .iter() + .map(|o| o.distance) + .reduce(f64::max) + .unwrap_or(0.0); + + let k = result.knn_objects.len(); + if k == 0 { + for partition in partitions.iter_mut() { + partition.push(FuzzTestKNNResult { + query_id: result.query_id, + knn_objects: Vec::new(), + }); + } + continue; + } + + let mut extended_knn_objects = result.knn_objects.clone(); + for _ in 0..((num_partitions - 1) * k) { + extended_knn_objects.push(FuzzKNNResultObject { + object_id: next_object_id, + distance: distance_threshold + rng.random_range(1..10) as f64, + is_kept: rng.random_bool(kept_prob), + }); + next_object_id += 1; + } + extended_knn_objects.shuffle(rng); + + for (part_idx, chunk) in extended_knn_objects.chunks(k).enumerate() { + partitions[part_idx].push(FuzzTestKNNResult { + query_id: result.query_id, + knn_objects: chunk.to_vec(), + }); + } + } + partitions + } + + fn merge_partitioned_test_data( + partitioned_data: &[Vec], + ) -> Vec { + let num_queries = partitioned_data[0].len(); + let mut merged_results = Vec::with_capacity(num_queries); + for query_idx in 0..num_queries { + let mut knn_objects = Vec::new(); + for partition in partitioned_data.iter() { + knn_objects.extend_from_slice(&partition[query_idx].knn_objects); + } + merged_results.push(FuzzTestKNNResult { + query_id: partitioned_data[0][query_idx].query_id, + knn_objects, + }); + } + merged_results + } + + fn compute_expected_results( + test_data: &[FuzzTestKNNResult], + k: usize, + include_tie_breaker: bool, + ) -> Vec<(usize, Vec)> { + let mut expected_results = Vec::with_capacity(test_data.len()); + for result in test_data.iter() { + let mut knn_objects = result.knn_objects.clone(); + + // Take top K objects 
first + knn_objects.sort_by(|a, b| { + a.distance + .total_cmp(&b.distance) + .then(a.object_id.cmp(&b.object_id)) + }); + if let Some(kth_distance) = knn_objects.get(k.saturating_sub(1)).map(|o| o.distance) { + if include_tie_breaker { + let distance_threshold = kth_distance + DISTANCE_TOLERANCE; + knn_objects.retain(|o| o.distance <= distance_threshold); + } else { + knn_objects.truncate(k); + } + } else { + knn_objects.clear(); + } + + // Filter the results to only kept objects + let kept_objects = knn_objects.into_iter().filter(|o| o.is_kept).collect(); + + expected_results.push((result.query_id, kept_objects)); + } + expected_results + } + + fn is_fuzz_test_data_equivalent( + test_data: &[FuzzTestKNNResult], + partitioned_test_data: &[Vec], + k: usize, + include_tie_breaker: bool, + ) -> bool { + let merged_partitioned_test_data = merge_partitioned_test_data(partitioned_test_data); + let expected_results = compute_expected_results(test_data, k, include_tie_breaker); + let partitioned_results = + compute_expected_results(&merged_partitioned_test_data, k, include_tie_breaker); + expected_results == partitioned_results + } + + fn ingest_partitioned_fuzz_test_data( + knn_result_spiller: &mut KNNResultsMerger, + partitioned_test_data: &[Vec], + query_group_size: usize, + batch_size: usize, + ) -> Result> { + let mut merged_record_batches = Vec::new(); + for (i_partition, partition) in partitioned_test_data.iter().enumerate() { + if i_partition != 0 { + let is_last_partition = i_partition == partitioned_test_data.len() - 1; + knn_result_spiller.rotate(is_last_partition)?; + } + + let mut start_offset = 0; + for partition_chunk in partition.chunks(query_group_size) { + let res_batches = ingest_fuzz_test_data_segment( + knn_result_spiller, + partition_chunk, + start_offset, + batch_size, + )?; + merged_record_batches.extend(res_batches); + start_offset += partition_chunk.len(); + } + + if let Some(batch) = knn_result_spiller.produce_batch_until(start_offset)? 
{ + merged_record_batches.push(batch); + } + } + Ok(merged_record_batches) + } + + fn ingest_fuzz_test_data_segment( + knn_result_spiller: &mut KNNResultsMerger, + test_data: &[FuzzTestKNNResult], + start_offset: usize, + batch_size: usize, + ) -> Result> { + // Assemble the test_data into one RecordBatch + let mut query_id_builder = UInt64Array::builder(test_data.len()); + let mut object_id_builder = UInt64Array::builder(test_data.len()); + let mut indices = Vec::new(); + let mut distances = Vec::new(); + let mut is_kept = Vec::new(); + for (idx, result) in test_data.iter().enumerate() { + for obj in result.knn_objects.iter() { + query_id_builder.append_value(result.query_id as u64); + object_id_builder.append_value(obj.object_id as u64); + indices.push(idx + start_offset); + distances.push(obj.distance); + is_kept.push(obj.is_kept); + } + } + let query_id_array = Arc::new(query_id_builder.finish()); + let object_id_array = Arc::new(object_id_builder.finish()); + let schema = create_fuzz_test_data_schema(); + let knn_result_batch = RecordBatch::try_new(schema, vec![query_id_array, object_id_array])?; + + // Break the record batch into smaller batches and ingest them + let mut merged_record_batches = Vec::new(); + for start in (0..knn_result_batch.num_rows()).step_by(batch_size) { + let end = (start + batch_size).min(knn_result_batch.num_rows()); + let batch = knn_result_batch.slice(start, end - start); + + let unfiltered_distances = distances[start..end].to_vec(); + let unfiltered_indices = indices[start..end].to_vec(); + let is_kept_slice = &is_kept[start..end]; + + // Find local indices for kept rows + let kept_indices_within_batch: Vec = is_kept_slice + .iter() + .enumerate() + .filter_map(|(i, &kept)| if kept { Some(i) } else { None }) + .collect(); + + let kept_indices_array = UInt64Array::from( + kept_indices_within_batch + .iter() + .map(|&i| i as u64) + .collect::>(), + ); + let batch = take_record_batch(&batch, &kept_indices_array).unwrap(); + let filtered_distances: Vec = kept_indices_within_batch + .iter() + .map(|&i| unfiltered_distances[i]) + .collect(); + let filtered_indices = kept_indices_within_batch + .iter() + .map(|&i| unfiltered_indices[i]) + .collect::>(); + + let res = knn_result_spiller.ingest( + batch, + filtered_indices, + filtered_distances, + unfiltered_indices, + unfiltered_distances, + )?; + if let Some(res_batch) = res { + merged_record_batches.push(res_batch); + } + } + + Ok(merged_record_batches) + } + + fn assert_merged_knn_result_is_correct( + batch: &RecordBatch, + partitioned_test_data: &[Vec], + k: usize, + include_tie_breaker: bool, + ) { + let merged_test_data = merge_partitioned_test_data(partitioned_test_data); + let expected_results = compute_expected_results(&merged_test_data, k, include_tie_breaker); + let mut expected_results: Vec<(u64, u64)> = expected_results + .iter() + .flat_map(|(query_id, objects)| { + objects + .iter() + .map(move |obj| (*query_id as u64, obj.object_id as u64)) + }) + .collect(); + expected_results.sort(); + + let mut actual_results: Vec<(u64, u64)> = Vec::new(); + let query_id_array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let object_id_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + actual_results.push((query_id_array.value(i), object_id_array.value(i))); + } + actual_results.sort(); + + assert_eq!(expected_results, actual_results); + } + + #[allow(clippy::too_many_arguments)] + fn fuzz_test_knn_results_merger( + rng: &mut StdRng, + 
num_rows: usize, + num_partitions: usize, + kept_prob: f64, + k: usize, + include_tie_breaker: bool, + query_group_size: usize, + target_batch_size: usize, + ) -> Result<()> { + assert!(num_partitions > 1); + + for _ in 0..10 { + let test_data = create_fuzz_test_data(k, num_rows, kept_prob, rng); + let partitioned_test_data = + partition_fuzz_test_data(&test_data, num_partitions, kept_prob, rng); + assert!(is_fuzz_test_data_equivalent( + &test_data, + &partitioned_test_data, + k, + include_tie_breaker + )); + + fuzz_test_knn_results_merger_using_partitioned_data( + &partitioned_test_data, + k, + include_tie_breaker, + query_group_size, + target_batch_size, + )?; + } + + Ok(()) + } + + fn fuzz_test_knn_results_merger_using_partitioned_data( + partitioned_test_data: &[Vec], + k: usize, + include_tie_breaker: bool, + query_group_size: usize, + target_batch_size: usize, + ) -> Result<()> { + let test_data_schema = create_fuzz_test_data_schema(); + let runtime_env = Arc::new(RuntimeEnv::default()); + let metrics_set = ExecutionPlanMetricsSet::new(); + let spill_metrics = SpillMetrics::new(&metrics_set, 0); + let mut knn_results_merger = KNNResultsMerger::try_new( + k, + include_tie_breaker, + target_batch_size, + runtime_env, + SpillCompression::Uncompressed, + Arc::clone(&test_data_schema), + spill_metrics, + )?; + + let batches = ingest_partitioned_fuzz_test_data( + &mut knn_results_merger, + partitioned_test_data, + query_group_size, + target_batch_size, + )?; + let batch = concat_batches(&test_data_schema, batches.iter()) + .map_err(|e| arrow_datafusion_err!(e))?; + assert_merged_knn_result_is_correct(&batch, partitioned_test_data, k, include_tie_breaker); + Ok(()) + } + + #[rstest] + fn test_knn_results_merger( + #[values(1, 10, 13, 50, 51, 1000)] target_batch_size: usize, + #[values(false, true)] include_tie_breaker: bool, + ) { + let mut rng = StdRng::seed_from_u64(target_batch_size as u64); + fuzz_test_knn_results_merger( + &mut rng, + 100, + 4, + 0.5, + 5, + include_tie_breaker, + 30, + target_batch_size, + ) + .unwrap(); + } + + #[test] + fn test_knn_results_merger_empty_query_side() { + let mut rng = StdRng::seed_from_u64(42); + fuzz_test_knn_results_merger(&mut rng, 0, 3, 1.0, 10, false, 100, 33).unwrap(); + } + + #[test] + fn test_knn_results_merger_all_filtered() { + let mut rng = StdRng::seed_from_u64(42); + fuzz_test_knn_results_merger(&mut rng, 100, 3, 0.0, 10, false, 50, 33).unwrap(); + } + + #[test] + fn test_knn_results_merger_no_knn_results() { + let empty_test_data = (0..100) + .map(|query_id| FuzzTestKNNResult { + query_id, + knn_objects: Vec::new(), + }) + .collect::>(); + let partitioned_test_data = vec![empty_test_data.clone(); 3]; + fuzz_test_knn_results_merger_using_partitioned_data( + &partitioned_test_data, + 5, + false, + 50, + 33, + ) + .unwrap(); + } + + #[test] + fn test_knn_results_merger_k_is_zero() { + let empty_test_data = (0..100) + .map(|query_id| FuzzTestKNNResult { + query_id, + knn_objects: Vec::new(), + }) + .collect::>(); + let partitioned_test_data = vec![empty_test_data.clone(); 3]; + fuzz_test_knn_results_merger_using_partitioned_data( + &partitioned_test_data, + 0, + false, + 50, + 33, + ) + .unwrap(); + } + + #[test] + fn test_knn_result_merger_with_empty_partitions() { + let k = 5; + let include_tie_breaker = false; + let num_rows = 100; + let num_partitions = 3; + let kept_prob = 0.5; + let mut rng = StdRng::seed_from_u64(42); + + let test_data = create_fuzz_test_data(k, num_rows, kept_prob, &mut rng); + let partitioned_test_data = + 
partition_fuzz_test_data(&test_data, num_partitions, kept_prob, &mut rng); + for i in 0..(num_partitions + 1) { + let empty_test_data = (0..num_rows) + .map(|query_id| FuzzTestKNNResult { + query_id, + knn_objects: Vec::new(), + }) + .collect::>(); + + // Insert a partition with no knn results at position i + let mut test_data_with_empty_partition = partitioned_test_data.clone(); + test_data_with_empty_partition.insert(i, empty_test_data); + + assert!(is_fuzz_test_data_equivalent( + &test_data, + &test_data_with_empty_partition, + k, + include_tie_breaker + )); + + let query_group_size = 30; + let target_batch_size = 33; + fuzz_test_knn_results_merger_using_partitioned_data( + &test_data_with_empty_partition, + k, + include_tie_breaker, + query_group_size, + target_batch_size, + ) + .unwrap(); + } + } + + #[rstest] + fn test_knn_results_merger_with_missing_probe_rows( + #[values(1, 10, 13, 50, 51, 1000)] target_batch_size: usize, + ) { + let k = 5; + let include_tie_breaker = true; + let num_rows = 20; + let num_partitions = 3; + let kept_prob = 0.5; + let mut rng = StdRng::seed_from_u64(target_batch_size as u64); + + let mut test_data = create_fuzz_test_data(k, num_rows, kept_prob, &mut rng); + + // Remove the query results of some probe rows randomly + for result in test_data.iter_mut() { + if rng.random_bool(0.1) { + result.knn_objects.clear(); + } + } + + let mut partitioned_test_data = + partition_fuzz_test_data(&test_data, num_partitions, kept_prob, &mut rng); + assert!(is_fuzz_test_data_equivalent( + &test_data, + &partitioned_test_data, + k, + include_tie_breaker + )); + + // Randomly remove some probe rows from each partition + for partition in partitioned_test_data.iter_mut() { + for result in partition.iter_mut() { + if rng.random_bool(0.1) { + result.knn_objects.clear(); + } + } + } + + let query_group_size = 30; + let target_batch_size = 33; + fuzz_test_knn_results_merger_using_partitioned_data( + &partitioned_test_data, + k, + include_tie_breaker, + query_group_size, + target_batch_size, + ) + .unwrap(); + } +} diff --git a/rust/sedona-spatial-join/src/stream.rs b/rust/sedona-spatial-join/src/stream.rs index 177e7b0ff..378ac6c10 100644 --- a/rust/sedona-spatial-join/src/stream.rs +++ b/rust/sedona-spatial-join/src/stream.rs @@ -17,13 +17,16 @@ use arrow::array::BooleanBufferBuilder; use arrow::compute::interleave_record_batch; use arrow_array::{UInt32Array, UInt64Array}; +use datafusion::config::SpillCompression; use datafusion::prelude::SessionConfig; use datafusion_common::{JoinSide, Result}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::JoinType; use datafusion_physical_plan::joins::utils::StatefulStreamResult; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; -use datafusion_physical_plan::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; +use datafusion_physical_plan::metrics::{ + self, ExecutionPlanMetricsSet, MetricBuilder, SpillMetrics, +}; use datafusion_physical_plan::{handle_state, RecordBatchStream, SendableRecordBatchStream}; use futures::future::BoxFuture; use futures::stream::StreamExt; @@ -43,6 +46,7 @@ use crate::index::{partitioned_index_provider::PartitionedIndexProvider, Spatial use crate::operand_evaluator::create_operand_evaluator; use crate::partitioning::SpatialPartition; use crate::prepare::SpatialJoinComponents; +use crate::probe::knn_results_merger::KNNResultsMerger; use crate::probe::partitioned_stream_provider::PartitionedProbeStreamProvider; use crate::probe::ProbeStreamMetrics; use 
crate::spatial_predicate::SpatialPredicate; @@ -82,8 +86,12 @@ pub(crate) struct SpatialJoinStream { runtime_env: Arc, /// Options for the spatial join options: SpatialJoinOptions, + /// Metrics set + metrics_set: ExecutionPlanMetricsSet, /// Target output batch size target_output_batch_size: usize, + /// Spill compression codec + spill_compression: SpillCompression, /// Once future for the shared partitioned index provider once_fut_spatial_join_components: OnceFut, /// Once async for the provider, disposed by the last finished stream @@ -106,6 +114,12 @@ pub(crate) struct SpatialJoinStream { /// This is used for outer joins to ensure that we only emit unmatched rows from the Multi /// partition once, after all regular partitions have been processed. visited_multi_probe_side: Option>>, + /// KNN results merger. Only used for partitioned KNN join. This value is Some when this spatial join stream + /// is for KNN join and the number of partitions is greater than 1, except when in the + /// [SpatialJoinStreamState::ProcessProbeBatch] state. The `knn_results_merger` will be moved into the + /// [SpatialJoinBatchIterator] when processing a probe batch, and moved back to here when the iterator is + /// complete. + knn_results_merger: Option>, /// Current offset in the probe side partition probe_offset: usize, } @@ -133,6 +147,7 @@ impl SpatialJoinStream { once_async_spatial_join_components: Arc>>>, ) -> Self { let target_output_batch_size = session_config.batch_size(); + let spill_compression = session_config.spill_compression(); let sedona_options = session_config .options() .extensions @@ -162,7 +177,9 @@ impl SpatialJoinStream { state: SpatialJoinStreamState::WaitPrepareSpatialJoinComponents, runtime_env, options: sedona_options.spatial_join, + metrics_set: metrics.clone(), target_output_batch_size, + spill_compression, once_fut_spatial_join_components, once_async_spatial_join_components, index_provider: None, @@ -173,6 +190,7 @@ impl SpatialJoinStream { num_regular_partitions: None, spatial_predicate: on.clone(), visited_multi_probe_side: None, + knn_results_merger: None, probe_offset: 0, } } @@ -410,6 +428,20 @@ impl SpatialJoinStream { return Poll::Ready(Ok(StatefulStreamResult::Continue)); } + if num_partitions > 1 { + if let SpatialPredicate::KNearestNeighbors(knn) = &self.spatial_predicate { + self.knn_results_merger = Some(Box::new(KNNResultsMerger::try_new( + knn.k as usize, + self.options.knn_include_tie_breakers, + self.target_output_batch_size, + Arc::clone(&self.runtime_env), + self.spill_compression, + self.schema.clone(), + SpillMetrics::new(&self.metrics_set, self.probe_partition_id), + )?)); + } + } + self.state = SpatialJoinStreamState::WaitBuildIndex(0, true); Poll::Ready(Ok(StatefulStreamResult::Continue)) } @@ -578,6 +610,7 @@ impl SpatialJoinStream { // Check if iterator is complete if iterator.is_complete() { + self.knn_results_merger = iterator.take_knn_results_merger(); self.state = SpatialJoinStreamState::FetchProbeBatch(partition_desc); } else { // Iterator is not complete, continue processing the current probe batch @@ -592,6 +625,7 @@ impl SpatialJoinStream { } None => { // Iterator finished, move to the next probe batch + self.knn_results_merger = iterator.take_knn_results_merger(); self.state = SpatialJoinStreamState::FetchProbeBatch(partition_desc); Poll::Ready(Ok(StatefulStreamResult::Continue)) } @@ -729,6 +763,12 @@ impl SpatialJoinStream { let next_partition_id = current_partition_id + 1; + if let Some(merger) = self.knn_results_merger.as_deref_mut() { + if 
next_partition_id < num_regular_partitions { + merger.rotate(next_partition_id == num_regular_partitions - 1)?; + } + } + if next_partition_id >= num_regular_partitions { if is_last_stream { let mut once_async = self.once_async_spatial_join_components.lock(); @@ -787,6 +827,10 @@ impl SpatialJoinStream { _ => JoinSide::Left, }; + // Move out the knn_results_merger to the iterator, we'll move it back when the iterator is complete + // by calling `SpatialJoinBatchIterator::take_knn_results_merger`. + let knn_results_merger = std::mem::take(&mut self.knn_results_merger); + let iterator = SpatialJoinBatchIterator::new(SpatialJoinBatchIteratorParams { schema: self.schema.clone(), filter: self.filter.clone(), @@ -803,6 +847,7 @@ impl SpatialJoinStream { probe_offset, produce_unmatched_probe_rows: is_last_build_partition, probe_evaluated_batch: Arc::new(probe_evaluated_batch), + knn_results_merger, })?; Ok(Box::new(iterator)) } @@ -896,10 +941,20 @@ struct ProbeProgress { /// Cursor of the position in the `build_batch_positions` and `probe_indices` vectors /// for tracking the progress of producing joined batches pos: usize, + /// KNN-specific progress. Only used for KNN join. + knn: Option, +} + +struct KNNProbeProgress { + /// Accumulated comparable (e.g. squared) distances of the KNN results. + /// Should have the same length as `build_batch_positions`. + distances: Vec, + /// KNN results merger. + knn_results_merger: Box, } /// Type alias for a tuple of build and probe indices slices -type BuildAndProbeIndices<'a> = (&'a [(i32, i32)], &'a [u32]); +type BuildAndProbeIndices<'a> = (&'a [(i32, i32)], &'a [u32], Option<&'a [f64]>); impl ProbeProgress { fn indices_for_next_batch( @@ -943,9 +998,13 @@ impl ProbeProgress { let slice_end = (self.pos + max_batch_size).min(end); let build_indices = &self.build_batch_positions[self.pos..slice_end]; let probe_indices = &self.probe_indices[self.pos..slice_end]; + let distances = self + .knn + .as_ref() + .map(|knn| &knn.distances[self.pos..slice_end]); self.pos = slice_end; - Some((build_indices, probe_indices)) + Some((build_indices, probe_indices, distances)) } fn next_probe_range(&mut self, probe_indices: &[u32]) -> Range { @@ -1011,10 +1070,12 @@ pub(crate) struct SpatialJoinBatchIteratorParams { pub probe_offset: usize, /// Whether to emit unmatched probe rows (used for right outer joins). pub produce_unmatched_probe_rows: bool, + /// The KNN result merger for merging KNN results across partitions. + /// Only available when running KNN join with multiple build partitions. + pub knn_results_merger: Option>, } impl SpatialJoinBatchIterator { - /// Create a new iterator for a single probe-side evaluated batch. pub(crate) fn new(params: SpatialJoinBatchIteratorParams) -> Result { Ok(Self { schema: params.schema, @@ -1038,11 +1099,14 @@ impl SpatialJoinBatchIterator { build_batch_positions: Vec::new(), probe_indices: Vec::new(), pos: 0, + knn: params.knn_results_merger.map(|merger| KNNProbeProgress { + distances: Vec::new(), + knn_results_merger: merger, + }), }), }) } - /// Produce the next joined output batch, or `Ok(None)` when this probe batch is fully processed. 
pub async fn next_batch(&mut self) -> Result> { let progress_opt = std::mem::take(&mut self.progress); let mut progress = progress_opt.expect("Progress should be available"); @@ -1149,12 +1213,13 @@ impl SpatialJoinBatchIterator { let use_spheroid = knn_predicate.use_spheroid; let include_tie_breakers = self.options.knn_include_tie_breakers; - let join_result_metrics = self.spatial_index.query_knn( + let join_result_metrics = self.spatial_index.query_knn_with_distance( wkb, k, use_spheroid, include_tie_breakers, &mut progress.build_batch_positions, + progress.knn.as_mut().map(|knn| &mut knn.distances), )?; progress.probe_indices.extend(std::iter::repeat_n( @@ -1176,6 +1241,12 @@ impl SpatialJoinBatchIterator { progress.probe_indices.len() == progress.build_batch_positions.len(), "Probe indices and build batch positions length should match" ); + if let Some(knn) = &progress.knn { + assert!( + knn.distances.len() == progress.probe_indices.len(), + "Probe indices and distances length should match" + ); + } progress.current_probe_idx += 1; // Early exit if we have enough results @@ -1188,31 +1259,66 @@ impl SpatialJoinBatchIterator { } fn produce_result_batch(&self, progress: &mut ProbeProgress) -> Result> { - let Some((build_indices, probe_indices)) = + let need_merge_knn_results = progress.knn.is_some(); + + let Some((build_indices, probe_indices, distances)) = progress.indices_for_next_batch(self.build_side, self.join_type, self.max_batch_size) else { // No more results to produce return Ok(None); }; - let (build_partial_batch, build_indices_array, probe_indices_array) = - self.produce_filtered_indices(build_indices, probe_indices.to_vec())?; + let (build_partial_batch, build_indices_array, probe_indices_array, filtered_distances) = + self.produce_filtered_indices(build_indices, probe_indices.to_vec(), distances)?; + + // Prepare unfiltered indices and distances for KNN joins. This has to be done before calling + // progress.next_probe_range to make the borrow checker happy. + let (unfiltered_probe_indices, unfiltered_distances) = if need_merge_knn_results { + (probe_indices.to_vec(), distances.map(|v| v.to_vec())) + } else { + (Vec::new(), None) + }; // Produce the final joined batch - if probe_indices_array.is_empty() { + let batch = if !probe_indices_array.is_empty() { + let probe_indices = probe_indices_array.values().as_ref(); + let probe_range = progress.next_probe_range(probe_indices); + self.build_joined_batch( + &build_partial_batch, + build_indices_array, + probe_indices_array.clone(), + probe_range, + )? + } else if need_merge_knn_results { + // For KNN joins, it's possible that after filtering there is no matched result. + // In this case, we still need to call merge.ingest to update the K-nearest-so-far distances. 
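+ // The empty batch below contributes no output rows, but the unfiltered probe indices and distances handed to `ingest` still advance the per-probe top-K bookkeeping.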
+ RecordBatch::new_empty(self.schema.clone()) + } else { return Ok(None); - } - let probe_indices = probe_indices_array.values().as_ref(); - let probe_range = progress.next_probe_range(probe_indices); - let batch = self.build_joined_batch( - &build_partial_batch, - build_indices_array, - probe_indices_array.clone(), - probe_range, - )?; + }; - if batch.num_rows() > 0 { - Ok(Some(batch)) + let batch_opt = if let Some(knn) = progress.knn.as_mut() { + let probe_indices_slice = probe_indices_array.values().as_ref(); + let unfiltered_distances = unfiltered_distances.unwrap_or(Vec::new()); + knn.knn_results_merger.ingest( + batch, + probe_indices_slice + .iter() + .map(|i| (*i as usize) + self.offset_in_partition) + .collect(), + filtered_distances.unwrap_or(Vec::new()), + unfiltered_probe_indices + .iter() + .map(|i| (*i as usize) + self.offset_in_partition) + .collect(), + unfiltered_distances, + )? + } else { + Some(batch) + }; + + if batch_opt.iter().any(|b| b.num_rows() > 0) { + Ok(batch_opt) } else { Ok(None) } @@ -1230,6 +1336,20 @@ impl SpatialJoinBatchIterator { assert_eq!(progress.current_probe_idx, num_rows); assert_eq!(progress.pos, progress.probe_indices.len()); + // For partitioned KNN joins, flush any pending buffered probe index first. + // If this produces a batch, return it and let the caller poll again. + if let Some(knn) = progress.knn.as_mut() { + let end_offset_in_partition = self.offset_in_partition + num_rows; + if let Some(batch) = knn + .knn_results_merger + .produce_batch_until(end_offset_in_partition)? + { + if batch.num_rows() > 0 { + return Ok(Some(batch)); + } + } + } + let Some(probe_range) = progress.last_probe_range(num_rows) else { return Ok(None); }; @@ -1252,6 +1372,9 @@ impl SpatialJoinBatchIterator { // Move everything after `pos` to the front progress.build_batch_positions.drain(0..progress.pos); progress.probe_indices.drain(0..progress.pos); + if let Some(knn) = &mut progress.knn { + knn.distances.drain(0..progress.pos); + } progress.pos = 0; } @@ -1264,6 +1387,15 @@ impl SpatialJoinBatchIterator { self.is_complete_inner(progress) } + pub fn take_knn_results_merger(&mut self) -> Option> { + assert!(self.is_complete(), "Iterator should be complete"); + let progress = self + .progress + .as_mut() + .expect("Progress should be available"); + progress.knn.take().map(|knn| knn.knn_results_merger) + } + fn is_complete_inner(&self, progress: &ProbeProgress) -> bool { progress.last_produced_probe_idx >= self.probe_evaluated_batch.batch.num_rows() as i64 } @@ -1272,7 +1404,8 @@ impl SpatialJoinBatchIterator { &self, build_indices: &[(i32, i32)], probe_indices: Vec, - ) -> Result<(RecordBatch, UInt64Array, UInt32Array)> { + distances: Option<&[f64]>, + ) -> Result<(RecordBatch, UInt64Array, UInt32Array, Option>)> { let PartialBuildBatch { batch: partial_build_batch, indices: build_indices, @@ -1280,16 +1413,17 @@ impl SpatialJoinBatchIterator { } = self.assemble_partial_build_batch(build_indices)?; let probe_indices = UInt32Array::from(probe_indices); - let (build_indices, probe_indices) = match &self.filter { + let (build_indices, probe_indices, filtered_distances) = match &self.filter { Some(filter) => apply_join_filter_to_indices( &partial_build_batch, &self.probe_evaluated_batch.batch, build_indices, probe_indices, + distances, filter, self.build_side, )?, - None => (build_indices, probe_indices), + None => (build_indices, probe_indices, distances.map(|d| d.to_vec())), }; // set the build side bitmap @@ -1303,7 +1437,12 @@ impl SpatialJoinBatchIterator { } } - 
Ok((partial_build_batch, build_indices, probe_indices)) + Ok(( + partial_build_batch, + build_indices, + probe_indices, + filtered_distances, + )) } fn build_joined_batch( @@ -1919,10 +2058,11 @@ mod tests { build_batch_positions, probe_indices: probe_indices.to_vec(), pos: 0, + knn: None, }; let mut produced_probe_indices: Vec = Vec::new(); loop { - let Some((_, probe_indices)) = + let Some((_, probe_indices, _)) = progress.indices_for_next_batch(JoinSide::Left, join_type, max_batch_size) else { break; diff --git a/rust/sedona-spatial-join/src/utils/join_utils.rs b/rust/sedona-spatial-join/src/utils/join_utils.rs index fcfa576fc..bb6b3790e 100644 --- a/rust/sedona-spatial-join/src/utils/join_utils.rs +++ b/rust/sedona-spatial-join/src/utils/join_utils.rs @@ -275,11 +275,12 @@ pub(crate) fn apply_join_filter_to_indices( probe_batch: &RecordBatch, build_indices: UInt64Array, probe_indices: UInt32Array, + distances: Option<&[f64]>, filter: &JoinFilter, build_side: JoinSide, -) -> Result<(UInt64Array, UInt32Array)> { +) -> Result<(UInt64Array, UInt32Array, Option>)> { if build_indices.is_empty() && probe_indices.is_empty() { - return Ok((build_indices, probe_indices)); + return Ok((build_indices, probe_indices, distances.map(|_| Vec::new()))); }; let intermediate_batch = build_batch_from_indices( @@ -300,9 +301,28 @@ pub(crate) fn apply_join_filter_to_indices( let left_filtered = compute::filter(&build_indices, mask)?; let right_filtered = compute::filter(&probe_indices, mask)?; + + let filtered_distances = if let Some(distances) = distances { + debug_assert_eq!( + distances.len(), + build_indices.len(), + "distances length should match indices length" + ); + let dist_array = arrow_array::Float64Array::from(distances.to_vec()); + let filtered = compute::filter(&dist_array, mask)?; + let filtered = filtered + .as_any() + .downcast_ref::() + .expect("filtered distance array should be Float64Array"); + Some(filtered.values().to_vec()) + } else { + None + }; + Ok(( downcast_array(left_filtered.as_ref()), downcast_array(right_filtered.as_ref()), + filtered_distances, )) }
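The snippet below is not part of the patch; it is a minimal standalone sketch (with illustrative values) of the idea behind the new `distances` parameter of `apply_join_filter_to_indices`: the distance column is filtered with the same boolean mask as the build/probe index arrays, so the surviving distances stay aligned with the surviving index pairs.

use arrow::compute;
use arrow_array::{Array, BooleanArray, Float64Array};

fn main() -> Result<(), arrow::error::ArrowError> {
    // Join-filter mask evaluated for three candidate pairs (illustrative values).
    let mask = BooleanArray::from(vec![true, false, true]);
    // Distances computed for the same three pairs, in the same order.
    let dists = Float64Array::from(vec![1.5, 2.5, 3.5]);

    // Filtering with the same mask keeps the distance column aligned with the
    // filtered build/probe index arrays.
    let filtered = compute::filter(&dists, &mask)?;
    let filtered = filtered
        .as_any()
        .downcast_ref::<Float64Array>()
        .expect("filter preserves the input array type");
    assert_eq!(filtered.values().as_ref(), &[1.5, 3.5]);
    Ok(())
}

Keeping the columns aligned this way is what lets the KNN merger later assume that the i-th filtered distance belongs to the i-th surviving (build, probe) pair.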