diff --git a/rust/sedona-common/src/option.rs b/rust/sedona-common/src/option.rs index fcd692fb4..bc74acf74 100644 --- a/rust/sedona-common/src/option.rs +++ b/rust/sedona-common/src/option.rs @@ -70,6 +70,13 @@ config_namespace! { /// Include tie-breakers in KNN join results when there are tied distances pub knn_include_tie_breakers: bool, default = false + + /// The minimum number of geometry pairs per chunk required to enable parallel + /// refinement during the spatial join operation. When the refinement phase has + /// fewer geometry pairs than this threshold, it will run sequentially instead + /// of spawning parallel tasks. Higher values reduce parallelization overhead + /// for small datasets, while lower values enable more fine-grained parallelism. + pub parallel_refinement_chunk_size: usize, default = 8192 } } diff --git a/rust/sedona-spatial-join/src/exec.rs b/rust/sedona-spatial-join/src/exec.rs index 5cdea16de..43b73290c 100644 --- a/rust/sedona-spatial-join/src/exec.rs +++ b/rust/sedona-spatial-join/src/exec.rs @@ -1135,6 +1135,24 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_parallel_refinement_for_large_candidate_set() -> Result<()> { + let ((left_schema, left_partitions), (right_schema, right_partitions)) = + create_test_data_with_size_range((1.0, 50.0), WKB_GEOMETRY)?; + + for max_batch_size in [10, 30, 100] { + let options = SpatialJoinOptions { + execution_mode: ExecutionMode::PrepareNone, + parallel_refinement_chunk_size: 10, + ..Default::default() + }; + test_spatial_join_query(&left_schema, &right_schema, left_partitions.clone(), right_partitions.clone(), &options, max_batch_size, + "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY L.id, R.id").await?; + } + + Ok(()) + } + async fn test_with_join_types(join_type: JoinType) -> Result { let ((left_schema, left_partitions), (right_schema, right_partitions)) = create_test_data_with_empty_partitions()?; diff --git a/rust/sedona-spatial-join/src/index/spatial_index.rs b/rust/sedona-spatial-join/src/index/spatial_index.rs index 83a1a754d..6f3e00d02 100644 --- a/rust/sedona-spatial-join/src/index/spatial_index.rs +++ b/rust/sedona-spatial-join/src/index/spatial_index.rs @@ -15,14 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, +use std::{ + ops::Range, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, }; use arrow_array::RecordBatch; use arrow_schema::SchemaRef; -use datafusion_common::Result; +use datafusion_common::{DataFusionError, Result}; +use datafusion_common_runtime::JoinSet; use datafusion_execution::memory_pool::{MemoryPool, MemoryReservation}; use float_next_after::NextAfter; use geo::BoundingRect; @@ -44,7 +48,7 @@ use crate::{ knn_adapter::{KnnComponents, SedonaKnnAdapter}, IndexQueryResult, QueryResultMetrics, }, - operand_evaluator::{create_operand_evaluator, OperandEvaluator}, + operand_evaluator::{create_operand_evaluator, distance_value_at, OperandEvaluator}, refine::{create_refiner, IndexQueryResultRefiner}, spatial_predicate::SpatialPredicate, utils::concurrent_reservation::ConcurrentReservation, @@ -54,6 +58,7 @@ use sedona_common::{option::SpatialJoinOptions, sedona_internal_err, ExecutionMo pub struct SpatialIndex { pub(crate) schema: SchemaRef, + pub(crate) options: SpatialJoinOptions, /// The spatial predicate evaluator for the spatial predicate. pub(crate) evaluator: Arc, @@ -125,6 +130,7 @@ impl SpatialIndex { .then(|| KnnComponents::new(0, &[], memory_pool.clone()).unwrap()); Self { schema, + options, evaluator, refiner, refiner_reservation, @@ -178,6 +184,7 @@ impl SpatialIndex { /// # Returns /// * `JoinResultMetrics` containing the number of actual matches (`count`) and the number /// of candidates from the filter phase (`candidate_count`) + #[allow(unused)] pub(crate) fn query( &self, probe_wkb: &Wkb, @@ -409,6 +416,179 @@ impl SpatialIndex { }) } + /// Query the spatial index with a batch of probe geometries to find matching build-side geometries. + /// + /// This method iterates over the probe geometries in the given range of the evaluated batch. + /// For each probe geometry, it performs the two-phase spatial join query: + /// 1. **Filter phase**: Uses the R-tree index with the probe geometry's bounding rectangle + /// to quickly identify candidate geometries. + /// 2. **Refinement phase**: Evaluates the exact spatial predicate on candidates to determine + /// actual matches. + /// + /// # Arguments + /// * `evaluated_batch` - The batch containing probe geometries and their bounding rectangles + /// * `range` - The range of rows in the evaluated batch to process. + /// * `max_result_size` - The maximum number of results to collect before stopping. If the + /// number of results exceeds this limit, the method returns early. + /// * `build_batch_positions` - Output vector that will be populated with (batch_idx, row_idx) + /// pairs for each matching build-side geometry. + /// * `probe_indices` - Output vector that will be populated with the probe row index (in + /// `evaluated_batch`) for each match appended to `build_batch_positions`. + /// This means the probe index is repeated `N` times when a probe geometry produces `N` matches, + /// keeping `probe_indices.len()` in sync with `build_batch_positions.len()`. + /// + /// # Returns + /// * A tuple containing: + /// - `QueryResultMetrics`: Aggregated metrics (total matches and candidates) for the processed rows + /// - `usize`: The index of the next row to process (exclusive end of the processed range) + pub(crate) async fn query_batch( + self: &Arc, + evaluated_batch: &Arc, + range: Range, + max_result_size: usize, + build_batch_positions: &mut Vec<(i32, i32)>, + probe_indices: &mut Vec, + ) -> Result<(QueryResultMetrics, usize)> { + if range.is_empty() { + return Ok(( + QueryResultMetrics { + count: 0, + candidate_count: 0, + }, + range.start, + )); + } + + let rects = evaluated_batch.rects(); + let dist = evaluated_batch.distance(); + let mut total_candidates_count = 0; + let mut total_count = 0; + let mut current_row_idx = range.start; + for row_idx in range { + current_row_idx = row_idx; + let Some(probe_rect) = rects[row_idx] else { + continue; + }; + + let min = probe_rect.min(); + let max = probe_rect.max(); + let mut candidates = self.rtree.search(min.x, min.y, max.x, max.y); + if candidates.is_empty() { + continue; + } + + let Some(probe_wkb) = evaluated_batch.wkb(row_idx) else { + return sedona_internal_err!( + "Failed to get WKB for row {} in evaluated batch", + row_idx + ); + }; + + // Sort and dedup candidates to avoid duplicate results when we index one geometry + // using several boxes. + candidates.sort_unstable(); + candidates.dedup(); + + let distance = match dist { + Some(dist_array) => distance_value_at(dist_array, row_idx)?, + None => None, + }; + + // Refine the candidates retrieved from the r-tree index by evaluating the actual spatial predicate + let refine_chunk_size = self.options.parallel_refinement_chunk_size; + if refine_chunk_size == 0 || candidates.len() < refine_chunk_size * 2 { + // For small candidate sets, use refine synchronously + let metrics = + self.refine(probe_wkb, &candidates, &distance, build_batch_positions)?; + probe_indices.extend(std::iter::repeat_n(row_idx as u32, metrics.count)); + total_count += metrics.count; + total_candidates_count += metrics.candidate_count; + } else { + // For large candidate sets, spawn several tasks to parallelize refinement + let (metrics, positions) = self + .refine_concurrently( + evaluated_batch, + row_idx, + &candidates, + distance, + refine_chunk_size, + ) + .await?; + build_batch_positions.extend(positions); + probe_indices.extend(std::iter::repeat_n(row_idx as u32, metrics.count)); + total_count += metrics.count; + total_candidates_count += metrics.candidate_count; + } + + if total_count >= max_result_size { + break; + } + } + + let end_idx = current_row_idx + 1; + Ok(( + QueryResultMetrics { + count: total_count, + candidate_count: total_candidates_count, + }, + end_idx, + )) + } + + async fn refine_concurrently( + self: &Arc, + evaluated_batch: &Arc, + row_idx: usize, + candidates: &[u32], + distance: Option, + refine_chunk_size: usize, + ) -> Result<(QueryResultMetrics, Vec<(i32, i32)>)> { + let mut join_set = JoinSet::new(); + for (i, chunk) in candidates.chunks(refine_chunk_size).enumerate() { + let cloned_evaluated_batch = Arc::clone(evaluated_batch); + let chunk = chunk.to_vec(); + let index_ref = Arc::clone(self); + join_set.spawn(async move { + let Some(probe_wkb) = cloned_evaluated_batch.wkb(row_idx) else { + return ( + i, + sedona_internal_err!( + "Failed to get WKB for row {} in evaluated batch", + row_idx + ), + ); + }; + let mut local_positions: Vec<(i32, i32)> = Vec::with_capacity(chunk.len()); + let res = index_ref.refine(probe_wkb, &chunk, &distance, &mut local_positions); + (i, res.map(|r| (r, local_positions))) + }); + } + + // Collect the results in order + let mut refine_results = Vec::with_capacity(join_set.len()); + refine_results.resize_with(join_set.len(), || None); + while let Some(res) = join_set.join_next().await { + let (chunk_idx, refine_res) = + res.map_err(|e| DataFusionError::External(Box::new(e)))?; + let (metrics, positions) = refine_res?; + refine_results[chunk_idx] = Some((metrics, positions)); + } + + let mut total_metrics = QueryResultMetrics { + count: 0, + candidate_count: 0, + }; + let mut all_positions = Vec::with_capacity(candidates.len()); + for res in refine_results { + let (metrics, positions) = res.expect("All chunks should be processed"); + total_metrics.count += metrics.count; + total_metrics.candidate_count += metrics.candidate_count; + all_positions.extend(positions); + } + + Ok((total_metrics, all_positions)) + } + fn refine( &self, probe_wkb: &Wkb, @@ -1232,9 +1412,6 @@ mod tests { assert!(build_positions.len() <= 3); assert!(result.count > 0); assert!(result.count <= 3); - - println!("KNN Geometry test - found {} results", result.count); - println!("Result positions: {build_positions:?}"); } #[test] @@ -1316,8 +1493,6 @@ mod tests { // Should return results assert!(!build_positions.is_empty()); - println!("KNN with mixed geometries: {build_positions:?}"); - // Should work with mixed geometry types assert!(result.count > 0); } @@ -1518,4 +1693,291 @@ mod tests { assert_eq!(result.candidate_count, 0); assert!(build_positions.is_empty()); } + + async fn setup_index_for_batch_test( + build_geoms: &[Option<&str>], + options: SpatialJoinOptions, + ) -> Arc { + let memory_pool = Arc::new(GreedyMemoryPool::new(100 * 1024 * 1024)); + let metrics = SpatialJoinBuildMetrics::default(); + let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("left", 0)), + Arc::new(Column::new("right", 0)), + SpatialRelationType::Intersects, + )); + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])); + + let mut builder = SpatialIndexBuilder::new( + schema, + spatial_predicate, + options, + JoinType::Inner, + 1, + memory_pool, + metrics, + ) + .unwrap(); + + let geom_array = create_array(build_geoms, &WKB_GEOMETRY); + let batch = RecordBatch::try_new( + Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])), + vec![Arc::new(geom_array.clone())], + ) + .unwrap(); + let evaluated_batch = EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_array, &WKB_GEOMETRY).unwrap(), + }; + + builder.add_batch(evaluated_batch).unwrap(); + Arc::new(builder.finish().unwrap()) + } + + fn create_probe_batch(probe_geoms: &[Option<&str>]) -> Arc { + let geom_array = create_array(probe_geoms, &WKB_GEOMETRY); + let batch = RecordBatch::try_new( + Arc::new(arrow_schema::Schema::new(vec![Field::new( + "geom", + DataType::Binary, + true, + )])), + vec![Arc::new(geom_array.clone())], + ) + .unwrap(); + Arc::new(EvaluatedBatch { + batch, + geom_array: EvaluatedGeometryArray::try_new(geom_array, &WKB_GEOMETRY).unwrap(), + }) + } + + #[tokio::test] + async fn test_query_batch_empty_results() { + let build_geoms = &[Some("POINT (0 0)"), Some("POINT (1 1)")]; + let index = setup_index_for_batch_test(build_geoms, SpatialJoinOptions::default()).await; + + // Probe with geometries that don't intersect + let probe_geoms = &[Some("POINT (10 10)"), Some("POINT (20 20)")]; + let probe_batch = create_probe_batch(probe_geoms); + + let mut build_batch_positions = Vec::new(); + let mut probe_indices = Vec::new(); + let (metrics, next_idx) = index + .query_batch( + &probe_batch, + 0..2, + usize::MAX, + &mut build_batch_positions, + &mut probe_indices, + ) + .await + .unwrap(); + + assert_eq!(metrics.count, 0); + assert_eq!(build_batch_positions.len(), 0); + assert_eq!(probe_indices.len(), 0); + assert_eq!(next_idx, 2); + } + + #[tokio::test] + async fn test_query_batch_max_result_size() { + let build_geoms = &[ + Some("POINT (0 0)"), + Some("POINT (0 0)"), + Some("POINT (0 0)"), + ]; + let index = setup_index_for_batch_test(build_geoms, SpatialJoinOptions::default()).await; + + // Probe with geometry that intersects all 3 + let probe_geoms = &[Some("POINT (0 0)"), Some("POINT (0 0)")]; + let probe_batch = create_probe_batch(probe_geoms); + + // Case 1: Max result size is large enough + let mut build_batch_positions = Vec::new(); + let mut probe_indices = Vec::new(); + let (metrics, next_idx) = index + .query_batch( + &probe_batch, + 0..2, + 10, + &mut build_batch_positions, + &mut probe_indices, + ) + .await + .unwrap(); + assert_eq!(metrics.count, 6); // 2 probes * 3 matches + assert_eq!(next_idx, 2); + assert_eq!(probe_indices, vec![0, 0, 0, 1, 1, 1]); + + // Case 2: Max result size is small (stops after first probe) + let mut build_batch_positions = Vec::new(); + let mut probe_indices = Vec::new(); + let (metrics, next_idx) = index + .query_batch( + &probe_batch, + 0..2, + 2, // Stop after 2 results + &mut build_batch_positions, + &mut probe_indices, + ) + .await + .unwrap(); + + // It should process the first probe, find 3 matches. + // Since 3 >= 2, it should stop. + assert_eq!(metrics.count, 3); + assert_eq!(next_idx, 1); // Only processed 1 probe + assert_eq!(probe_indices, vec![0, 0, 0]); + } + + #[tokio::test] + async fn test_query_batch_parallel_refinement() { + // Create enough build geometries to trigger parallel refinement + // We need candidates.len() >= chunk_size * 2 + // Let's set chunk_size = 2, so we need >= 4 candidates. + let build_geoms = vec![Some("POINT (0 0)"); 10]; + let options = SpatialJoinOptions { + parallel_refinement_chunk_size: 2, + ..Default::default() + }; + + let index = setup_index_for_batch_test(&build_geoms, options).await; + + // Probe with a geometry that intersects all build geometries + let probe_geoms = &[Some("POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))")]; + let probe_batch = create_probe_batch(probe_geoms); + + let mut build_batch_positions = Vec::new(); + let mut probe_indices = Vec::new(); + let (metrics, next_idx) = index + .query_batch( + &probe_batch, + 0..1, + usize::MAX, + &mut build_batch_positions, + &mut probe_indices, + ) + .await + .unwrap(); + + assert_eq!(metrics.count, 10); + assert_eq!(build_batch_positions.len(), 10); + assert_eq!(probe_indices, vec![0; 10]); + assert_eq!(next_idx, 1); + } + + #[tokio::test] + async fn test_query_batch_empty_range() { + let build_geoms = &[Some("POINT (0 0)")]; + let index = setup_index_for_batch_test(build_geoms, SpatialJoinOptions::default()).await; + let probe_geoms = &[Some("POINT (0 0)"), Some("POINT (0 0)")]; + let probe_batch = create_probe_batch(probe_geoms); + + let mut build_batch_positions = Vec::new(); + let mut probe_indices = Vec::new(); + + // Query with empty range + for empty_ranges in [0..0, 1..1, 2..2] { + let (metrics, next_idx) = index + .query_batch( + &probe_batch, + empty_ranges.clone(), + usize::MAX, + &mut build_batch_positions, + &mut probe_indices, + ) + .await + .unwrap(); + + assert_eq!(metrics.count, 0); + assert_eq!(next_idx, empty_ranges.end); + } + } + + #[tokio::test] + async fn test_query_batch_range_offset() { + let build_geoms = &[Some("POINT (0 0)"), Some("POINT (1 1)")]; + let index = setup_index_for_batch_test(build_geoms, SpatialJoinOptions::default()).await; + + // Probe with 3 geometries: + // 0: POINT (0 0) - matches build[0] (should be skipped) + // 1: POINT (0 0) - matches build[0] + // 2: POINT (1 1) - matches build[1] + let probe_geoms = &[ + Some("POINT (0 0)"), + Some("POINT (0 0)"), + Some("POINT (1 1)"), + ]; + let probe_batch = create_probe_batch(probe_geoms); + + let mut build_batch_positions = Vec::new(); + let mut probe_indices = Vec::new(); + + // Query with range 1..3 (skipping the first probe) + let (metrics, next_idx) = index + .query_batch( + &probe_batch, + 1..3, + usize::MAX, + &mut build_batch_positions, + &mut probe_indices, + ) + .await + .unwrap(); + + assert_eq!(metrics.count, 2); + assert_eq!(next_idx, 3); + + // probe_indices should contain indices relative to the batch start (1 and 2) + assert_eq!(probe_indices, vec![1, 2]); + + // build_batch_positions should contain matches for probe 1 and probe 2 + // probe 1 matches build 0 (0, 0) + // probe 2 matches build 1 (0, 1) + // Note: build_batch_positions contains (batch_idx, row_idx) + // Since we have 1 batch, batch_idx is 0. + assert_eq!(build_batch_positions, vec![(0, 0), (0, 1)]); + } + + #[tokio::test] + async fn test_query_batch_zero_parallel_refinement_chunk_size() { + let build_geoms = &[ + Some("POINT (0 0)"), + Some("POINT (0 0)"), + Some("POINT (0 0)"), + ]; + let options = SpatialJoinOptions { + // force synchronous refinement + parallel_refinement_chunk_size: 0, + ..Default::default() + }; + + let index = setup_index_for_batch_test(build_geoms, options).await; + let probe_geoms = &[Some("POINT (0 0)")]; + let probe_batch = create_probe_batch(probe_geoms); + + let mut build_batch_positions = Vec::new(); + let mut probe_indices = Vec::new(); + + let result = index + .query_batch( + &probe_batch, + 0..1, + 10, + &mut build_batch_positions, + &mut probe_indices, + ) + .await; + + assert!(result.is_ok()); + let (metrics, _) = result.unwrap(); + assert_eq!(metrics.count, 3); + } } diff --git a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs index 41d7fbd6f..a9b08d7a9 100644 --- a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs +++ b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs @@ -274,6 +274,7 @@ impl SpatialIndexBuilder { Ok(SpatialIndex { schema: self.schema, + options: self.options, evaluator, refiner, refiner_reservation, diff --git a/rust/sedona-spatial-join/src/stream.rs b/rust/sedona-spatial-join/src/stream.rs index f4b182445..4a01e6ef8 100644 --- a/rust/sedona-spatial-join/src/stream.rs +++ b/rust/sedona-spatial-join/src/stream.rs @@ -23,7 +23,9 @@ use datafusion_physical_plan::joins::utils::StatefulStreamResult; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion_physical_plan::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; use datafusion_physical_plan::{handle_state, RecordBatchStream, SendableRecordBatchStream}; +use futures::future::BoxFuture; use futures::stream::StreamExt; +use futures::FutureExt; use futures::{ready, task::Poll}; use parking_lot::Mutex; use sedona_common::sedona_internal_err; @@ -37,7 +39,7 @@ use crate::evaluated_batch::evaluated_batch_stream::evaluate::create_evaluated_p use crate::evaluated_batch::evaluated_batch_stream::SendableEvaluatedBatchStream; use crate::evaluated_batch::EvaluatedBatch; use crate::index::SpatialIndex; -use crate::operand_evaluator::{create_operand_evaluator, distance_value_at}; +use crate::operand_evaluator::create_operand_evaluator; use crate::spatial_predicate::SpatialPredicate; use crate::utils::join_utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, @@ -50,7 +52,7 @@ use sedona_common::option::SpatialJoinOptions; /// Stream for producing spatial join result batches. pub(crate) struct SpatialJoinStream { - /// Input schema + /// Schema of joined results schema: Arc, /// join filter filter: Option, @@ -165,7 +167,6 @@ impl SpatialJoinProbeMetrics { } /// This enumeration represents various states of the nested loop join algorithm. -#[derive(Debug)] #[allow(clippy::large_enum_variant)] pub(crate) enum SpatialJoinStreamState { /// The initial mode: waiting for the spatial index to be built @@ -174,7 +175,9 @@ pub(crate) enum SpatialJoinStreamState { /// fetching probe-side FetchProbeBatch, /// Indicates that we're processing a probe batch using the batch iterator - ProcessProbeBatch(SpatialJoinBatchIterator), + ProcessProbeBatch( + BoxFuture<'static, (Box, Result>)>, + ), /// Indicates that probe-side has been fully processed ExhaustedProbeSide, /// Indicates that we're processing unmatched build-side batches using an iterator @@ -197,7 +200,7 @@ impl SpatialJoinStream { handle_state!(ready!(self.fetch_probe_batch(cx))) } SpatialJoinStreamState::ProcessProbeBatch(_) => { - handle_state!(ready!(self.process_probe_batch())) + handle_state!(ready!(self.process_probe_batch(cx))) } SpatialJoinStreamState::ExhaustedProbeSide => { handle_state!(ready!(self.setup_unmatched_build_batch_processing())) @@ -227,8 +230,13 @@ impl SpatialJoinStream { let result = self.probe_stream.poll_next_unpin(cx); match result { Poll::Ready(Some(Ok(batch))) => match self.create_spatial_join_iterator(batch) { - Ok(iterator) => { - self.state = SpatialJoinStreamState::ProcessProbeBatch(iterator); + Ok(mut iterator) => { + let future = async move { + let result = iterator.next_batch().await; + (iterator, result) + } + .boxed(); + self.state = SpatialJoinStreamState::ProcessProbeBatch(future); Poll::Ready(Ok(StatefulStreamResult::Continue)) } Err(e) => Poll::Ready(Err(e)), @@ -242,54 +250,51 @@ impl SpatialJoinStream { } } - fn process_probe_batch(&mut self) -> Poll>>> { - let timer = self.join_metrics.join_time.timer(); + fn process_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + let _timer = self.join_metrics.join_time.timer(); // Extract the necessary data first to avoid borrowing conflicts - let (batch_opt, is_complete) = match &mut self.state { - SpatialJoinStreamState::ProcessProbeBatch(iterator) => { - // For KNN joins, we may have swapped build/probe sides, so build_side might be Right; - // For regular joins, build_side is always Left. - let build_side = match &self.spatial_predicate { - SpatialPredicate::KNearestNeighbors(knn) => knn.probe_side.negate(), - _ => JoinSide::Left, - }; - - let batch_opt = match iterator.next_batch( - &self.schema, - self.filter.as_ref(), - self.join_type, - &self.column_indices, - build_side, - ) { - Ok(opt) => opt, - Err(e) => { - return Poll::Ready(Err(e)); - } - }; - let is_complete = iterator.is_complete(); - (batch_opt, is_complete) - } + let (mut iterator, batch_opt) = match &mut self.state { + SpatialJoinStreamState::ProcessProbeBatch(future) => match future.poll_unpin(cx) { + Poll::Ready((iterator, result)) => { + let batch_opt = match result { + Ok(opt) => opt, + Err(e) => { + return Poll::Ready(Err(e)); + } + }; + (iterator, batch_opt) + } + Poll::Pending => return Poll::Pending, + }, _ => unreachable!(), }; - let result = match batch_opt { + match batch_opt { Some(batch) => { // Check if iterator is complete - if is_complete { + if iterator.is_complete() { self.state = SpatialJoinStreamState::FetchProbeBatch; + } else { + // Iterator is not complete, continue processing the current probe batch + let future = async move { + let result = iterator.next_batch().await; + (iterator, result) + } + .boxed(); + self.state = SpatialJoinStreamState::ProcessProbeBatch(future); } - batch + Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) } None => { // Iterator finished, move to next probe batch self.state = SpatialJoinStreamState::FetchProbeBatch; - return Poll::Ready(Ok(StatefulStreamResult::Continue)); + Poll::Ready(Ok(StatefulStreamResult::Continue)) } - }; - - timer.done(); - Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result)))) + } } fn setup_unmatched_build_batch_processing( @@ -391,7 +396,7 @@ impl SpatialJoinStream { fn create_spatial_join_iterator( &self, probe_evaluated_batch: EvaluatedBatch, - ) -> Result { + ) -> Result> { let num_rows = probe_evaluated_batch.num_rows(); self.join_metrics.probe_input_batches.add(1); self.join_metrics.probe_input_rows.add(num_rows); @@ -414,15 +419,28 @@ impl SpatialJoinStream { spatial_index.merge_probe_stats(stats); } - SpatialJoinBatchIterator::new(SpatialJoinBatchIteratorParams { + // For KNN joins, we may have swapped build/probe sides, so build_side might be Right; + // For regular joins, build_side is always Left. + let build_side = match &self.spatial_predicate { + SpatialPredicate::KNearestNeighbors(knn) => knn.probe_side.negate(), + _ => JoinSide::Left, + }; + + let iterator = SpatialJoinBatchIterator::new(SpatialJoinBatchIteratorParams { + schema: self.schema.clone(), + filter: self.filter.clone(), + join_type: self.join_type, + column_indices: self.column_indices.clone(), + build_side, spatial_index: spatial_index.clone(), - probe_evaluated_batch, + probe_evaluated_batch: Arc::new(probe_evaluated_batch), join_metrics: self.join_metrics.clone(), max_batch_size: self.target_output_batch_size, probe_side_ordered: self.probe_side_ordered, spatial_predicate: self.spatial_predicate.clone(), options: self.options.clone(), - }) + })?; + Ok(Box::new(iterator)) } } @@ -454,10 +472,20 @@ struct PartialBuildBatch { /// Iterator that processes spatial join results in configurable batch sizes pub(crate) struct SpatialJoinBatchIterator { + /// Schema of the output record batches + schema: SchemaRef, + /// Optional join filter to be applied to the join results + filter: Option, + /// Type of the join operation + join_type: JoinType, + /// Information of index and left / right placement of columns + column_indices: Vec, + /// The side of the build stream, either Left or Right + build_side: JoinSide, /// The spatial index reference spatial_index: Arc, /// The probe side batch being processed - probe_evaluated_batch: EvaluatedBatch, + probe_evaluated_batch: Arc, /// Current probe row index being processed current_probe_idx: usize, /// Join metrics for tracking performance @@ -480,8 +508,13 @@ pub(crate) struct SpatialJoinBatchIterator { /// Parameters for creating a SpatialJoinBatchIterator pub(crate) struct SpatialJoinBatchIteratorParams { + pub schema: SchemaRef, + pub filter: Option, + pub join_type: JoinType, + pub column_indices: Vec, + pub build_side: JoinSide, pub spatial_index: Arc, - pub probe_evaluated_batch: EvaluatedBatch, + pub probe_evaluated_batch: Arc, pub join_metrics: SpatialJoinProbeMetrics, pub max_batch_size: usize, pub probe_side_ordered: bool, @@ -492,6 +525,11 @@ pub(crate) struct SpatialJoinBatchIteratorParams { impl SpatialJoinBatchIterator { pub(crate) fn new(params: SpatialJoinBatchIteratorParams) -> Result { Ok(Self { + schema: params.schema, + filter: params.filter, + join_type: params.join_type, + column_indices: params.column_indices, + build_side: params.build_side, spatial_index: params.spatial_index, probe_evaluated_batch: params.probe_evaluated_batch, current_probe_idx: 0, @@ -506,28 +544,50 @@ impl SpatialJoinBatchIterator { }) } - pub fn next_batch( - &mut self, - schema: &Schema, - filter: Option<&JoinFilter>, - join_type: JoinType, - column_indices: &[ColumnIndex], - build_side: JoinSide, - ) -> Result> { - // Process probe rows incrementally until we have enough results or finish - let initial_size = self.build_batch_positions.len(); + pub async fn next_batch(&mut self) -> Result> { + if self.is_complete { + return Ok(None); + } - let geom_array = &self.probe_evaluated_batch.geom_array; - let wkbs = geom_array.wkbs(); - let rects = &geom_array.rects; - let distance = &geom_array.distance; + let last_probe_idx = self.current_probe_idx; + match &self.spatial_predicate { + SpatialPredicate::KNearestNeighbors(_) => self.probe_knn()?, + _ => self.probe_range().await?, + }; - let num_rows = wkbs.len(); + // Check if we've finished processing all probe rows + if self.current_probe_idx >= self.probe_evaluated_batch.num_rows() { + self.is_complete = true; + } - let last_probe_idx = self.current_probe_idx; + if self.current_probe_idx > last_probe_idx { + // Process the joined indices to create a RecordBatch + let probe_indices = std::mem::take(&mut self.probe_indices); + let batch = self.process_joined_indices_to_batch( + &self.build_batch_positions, + probe_indices, + &self.schema, + self.filter.as_ref(), + self.join_type, + &self.column_indices, + self.build_side, + last_probe_idx..self.current_probe_idx, + )?; + + self.build_batch_positions.clear(); + Ok(Some(batch)) + } else { + Ok(None) + } + } + + fn probe_knn(&mut self) -> Result<()> { + let geom_array = &self.probe_evaluated_batch.geom_array; + let wkbs = geom_array.wkbs(); // Process from current position until we hit batch size limit or complete - while self.current_probe_idx < num_rows && !self.is_complete { + let num_rows = wkbs.len(); + while self.current_probe_idx < num_rows { // Get WKB for current probe index let wkb_opt = &wkbs[self.current_probe_idx]; @@ -537,65 +597,40 @@ impl SpatialJoinBatchIterator { continue; }; - let dist = match distance { - Some(dist) => distance_value_at(dist, self.current_probe_idx)?, - None => None, - }; - // Handle KNN queries differently from regular spatial joins - match &self.spatial_predicate { - SpatialPredicate::KNearestNeighbors(knn_predicate) => { - // For KNN, call query_knn only once per probe geometry (not per rect) - let k = knn_predicate.k; - let use_spheroid = knn_predicate.use_spheroid; - let include_tie_breakers = self.options.knn_include_tie_breakers; - - let join_result_metrics = self.spatial_index.query_knn( - wkb, - k, - use_spheroid, - include_tie_breakers, - &mut self.build_batch_positions, - )?; - - self.probe_indices.extend(std::iter::repeat_n( - self.current_probe_idx as u32, - join_result_metrics.count, - )); - - self.join_metrics - .join_result_candidates - .add(join_result_metrics.candidate_count); - self.join_metrics - .join_result_count - .add(join_result_metrics.count); - } - _ => { - // Regular spatial join: process all rects for this probe index - let rect_opt = &rects[self.current_probe_idx]; - if let Some(rect) = rect_opt { - let join_result_metrics = self.spatial_index.query( - wkb, - rect, - &dist, - &mut self.build_batch_positions, - )?; - - self.probe_indices.extend(std::iter::repeat_n( - self.current_probe_idx as u32, - join_result_metrics.count, - )); - - self.join_metrics - .join_result_candidates - .add(join_result_metrics.candidate_count); - self.join_metrics - .join_result_count - .add(join_result_metrics.count); - } - } + if let SpatialPredicate::KNearestNeighbors(knn_predicate) = &self.spatial_predicate { + // For KNN, call query_knn only once per probe geometry (not per rect) + let k = knn_predicate.k; + let use_spheroid = knn_predicate.use_spheroid; + let include_tie_breakers = self.options.knn_include_tie_breakers; + + let join_result_metrics = self.spatial_index.query_knn( + wkb, + k, + use_spheroid, + include_tie_breakers, + &mut self.build_batch_positions, + )?; + + self.probe_indices.extend(std::iter::repeat_n( + self.current_probe_idx as u32, + join_result_metrics.count, + )); + + self.join_metrics + .join_result_candidates + .add(join_result_metrics.candidate_count); + self.join_metrics + .join_result_count + .add(join_result_metrics.count); + } else { + unreachable!("probe_knn called for non-KNN predicate"); } + assert!( + self.probe_indices.len() == self.build_batch_positions.len(), + "Probe indices and build batch positions length should match" + ); self.current_probe_idx += 1; // Early exit if we have enough results @@ -604,31 +639,37 @@ impl SpatialJoinBatchIterator { } } - // Check if we've finished processing all probe rows - if self.current_probe_idx >= num_rows { - self.is_complete = true; - } + Ok(()) + } - // Return accumulated results if we have any new ones or if we're complete - if self.build_batch_positions.len() > initial_size || self.is_complete { - // Process the joined indices to create a RecordBatch - let probe_indices = std::mem::take(&mut self.probe_indices); - let batch = self.process_joined_indices_to_batch( - &self.build_batch_positions, - probe_indices, - schema, - filter, - join_type, - column_indices, - build_side, - last_probe_idx..self.current_probe_idx, - )?; + async fn probe_range(&mut self) -> Result<()> { + let num_rows = self.probe_evaluated_batch.num_rows(); + let range = self.current_probe_idx..num_rows; - self.build_batch_positions.clear(); - Ok(Some(batch)) - } else { - Ok(None) - } + let (metrics, next_row_idx) = self + .spatial_index + .query_batch( + &self.probe_evaluated_batch, + range, + self.max_batch_size, + &mut self.build_batch_positions, + &mut self.probe_indices, + ) + .await?; + + self.current_probe_idx = next_row_idx; + + self.join_metrics + .join_result_candidates + .add(metrics.candidate_count); + self.join_metrics.join_result_count.add(metrics.count); + + assert!( + self.probe_indices.len() == self.build_batch_positions.len(), + "Probe indices and build batch positions length should match" + ); + + Ok(()) } /// Check if the iterator has finished processing