diff --git a/rust/sedona-spatial-join/Cargo.toml b/rust/sedona-spatial-join/Cargo.toml index 322ec5721..d34f7a6cb 100644 --- a/rust/sedona-spatial-join/Cargo.toml +++ b/rust/sedona-spatial-join/Cargo.toml @@ -48,6 +48,7 @@ futures = { workspace = true } pin-project-lite = { workspace = true } once_cell = { workspace = true } parking_lot = { workspace = true } +tokio = { workspace = true } geo = { workspace = true } sedona-geo-generic-alg = { workspace = true } geo-traits = { workspace = true, features = ["geo-types"] } diff --git a/rust/sedona-spatial-join/src/build_index.rs b/rust/sedona-spatial-join/src/build_index.rs index f3cbb34b1..0e2923695 100644 --- a/rust/sedona-spatial-join/src/build_index.rs +++ b/rust/sedona-spatial-join/src/build_index.rs @@ -105,7 +105,6 @@ pub async fn build_index( sedona_options.spatial_join, join_type, probe_threads_count, - Arc::clone(memory_pool), SpatialJoinBuildMetrics::new(0, &metrics), )?; index_builder.add_partitions(build_partitions).await?; diff --git a/rust/sedona-spatial-join/src/exec.rs b/rust/sedona-spatial-join/src/exec.rs index 50cbd171d..495518ea0 100644 --- a/rust/sedona-spatial-join/src/exec.rs +++ b/rust/sedona-spatial-join/src/exec.rs @@ -36,12 +36,13 @@ use parking_lot::Mutex; use sedona_common::SpatialJoinOptions; use crate::{ - build_index::build_index, - index::SpatialIndex, + prepare::{SpatialJoinComponents, SpatialJoinComponentsBuilder}, spatial_predicate::{KNNPredicate, SpatialPredicate}, stream::{SpatialJoinProbeMetrics, SpatialJoinStream}, - utils::join_utils::{asymmetric_join_output_partitioning, boundedness_from_children}, - utils::once_fut::OnceAsync, + utils::{ + join_utils::{asymmetric_join_output_partitioning, boundedness_from_children}, + once_fut::OnceAsync, + }, SedonaOptions, }; @@ -132,9 +133,10 @@ pub struct SpatialJoinExec { column_indices: Vec, /// Cache holding plan properties like equivalences, output partitioning etc. 
cache: PlanProperties, - /// Spatial index built asynchronously on first execute() call and shared across all partitions. - /// Uses OnceAsync for lazy initialization coordinated via async runtime. - once_async_spatial_index: Arc>>>, + /// Once future for creating the partitioned index provider shared by all probe partitions. + /// This future runs only once before probing starts, and can be disposed by the last finished + /// stream so the provider does not outlive the execution plan unnecessarily. + once_async_spatial_join_components: Arc>>>, /// Indicates if this SpatialJoin was converted from a HashJoin /// When true, we preserve HashJoin's equivalence properties and partitioning converted_from_hash_join: bool, @@ -203,7 +205,7 @@ impl SpatialJoinExec { projection, metrics: Default::default(), cache, - once_async_spatial_index: Arc::new(Mutex::new(None)), + once_async_spatial_join_components: Arc::new(Mutex::new(None)), converted_from_hash_join, seed, }) @@ -431,7 +433,7 @@ impl ExecutionPlan for SpatialJoinExec { projection: self.projection.clone(), metrics: Default::default(), cache: self.cache.clone(), - once_async_spatial_index: Arc::new(Mutex::new(None)), + once_async_spatial_join_components: Arc::new(Mutex::new(None)), converted_from_hash_join: self.converted_from_hash_join, seed: self.seed, })) @@ -463,8 +465,8 @@ impl ExecutionPlan for SpatialJoinExec { let (build_plan, probe_plan) = (&self.left, &self.right); // Build the spatial index using shared OnceAsync - let once_fut_spatial_index = { - let mut once_async = self.once_async_spatial_index.lock(); + let once_fut_spatial_join_components = { + let mut once_async = self.once_async_spatial_join_components.lock(); once_async .get_or_insert(OnceAsync::default()) .try_once(|| { @@ -479,16 +481,16 @@ impl ExecutionPlan for SpatialJoinExec { let probe_thread_count = self.right.output_partitioning().partition_count(); - Ok(build_index( + let spatial_join_components_builder = SpatialJoinComponentsBuilder::new( 
Arc::clone(&context), build_side.schema(), - build_streams, self.on.clone(), self.join_type, probe_thread_count, self.metrics.clone(), self.seed, - )) + ); + Ok(spatial_join_components_builder.build(build_streams)) })? }; @@ -508,6 +510,7 @@ impl ExecutionPlan for SpatialJoinExec { self.maintains_input_order()[1] && self.right.output_ordering().is_some(); Ok(Box::pin(SpatialJoinStream::new( + partition, self.schema(), &self.on, self.filter.clone(), @@ -518,8 +521,8 @@ impl ExecutionPlan for SpatialJoinExec { join_metrics, sedona_options.spatial_join, target_output_batch_size, - once_fut_spatial_index, - Arc::clone(&self.once_async_spatial_index), + once_fut_spatial_join_components, + Arc::clone(&self.once_async_spatial_join_components), ))) } } @@ -556,8 +559,8 @@ impl SpatialJoinExec { let actual_probe_plan_is_left = std::ptr::eq(probe_plan.as_ref(), self.left.as_ref()); // Build the spatial index - let once_fut_spatial_index = { - let mut once_async = self.once_async_spatial_index.lock(); + let once_fut_spatial_join_components = { + let mut once_async = self.once_async_spatial_join_components.lock(); once_async .get_or_insert(OnceAsync::default()) .try_once(|| { @@ -571,16 +574,16 @@ impl SpatialJoinExec { } let probe_thread_count = probe_plan.output_partitioning().partition_count(); - Ok(build_index( + let spatial_join_components_builder = SpatialJoinComponentsBuilder::new( Arc::clone(&context), build_side.schema(), - build_streams, self.on.clone(), self.join_type, probe_thread_count, self.metrics.clone(), self.seed, - )) + ); + Ok(spatial_join_components_builder.build(build_streams)) })? 
}; @@ -605,6 +608,7 @@ impl SpatialJoinExec { }; Ok(Box::pin(SpatialJoinStream::new( + partition, self.schema(), &self.on, self.filter.clone(), @@ -615,8 +619,8 @@ impl SpatialJoinExec { join_metrics, sedona_options.spatial_join, target_output_batch_size, - once_fut_spatial_index, - Arc::clone(&self.once_async_spatial_index), + once_fut_spatial_join_components, + Arc::clone(&self.once_async_spatial_join_components), ))) } } diff --git a/rust/sedona-spatial-join/src/index.rs b/rust/sedona-spatial-join/src/index.rs index 55df23d56..af31b8af5 100644 --- a/rust/sedona-spatial-join/src/index.rs +++ b/rust/sedona-spatial-join/src/index.rs @@ -17,6 +17,8 @@ pub(crate) mod build_side_collector; mod knn_adapter; +pub(crate) mod memory_plan; +pub(crate) mod partitioned_index_provider; pub(crate) mod spatial_index; pub(crate) mod spatial_index_builder; diff --git a/rust/sedona-spatial-join/src/index/build_side_collector.rs b/rust/sedona-spatial-join/src/index/build_side_collector.rs index 646c6be21..d888680f1 100644 --- a/rust/sedona-spatial-join/src/index/build_side_collector.rs +++ b/rust/sedona-spatial-join/src/index/build_side_collector.rs @@ -68,6 +68,9 @@ pub(crate) struct BuildPartition { /// The size of this reservation will be used to determine the maximum size of /// each spatial partition, as well as how many spatial partitions to create. 
pub reservation: MemoryReservation, + + /// Metrics collected during the build side collection phase + pub metrics: CollectBuildSideMetrics, } /// A collector for evaluating the spatial expression on build side batches and collect @@ -112,6 +115,10 @@ impl CollectBuildSideMetrics { spill_metrics: SpillMetrics::new(metrics, partition), } } + + pub fn spill_metrics(&self) -> SpillMetrics { + self.spill_metrics.clone() + } } impl BuildSideBatchesCollector { @@ -147,7 +154,7 @@ impl BuildSideBatchesCollector { mut stream: SendableEvaluatedBatchStream, mut reservation: MemoryReservation, mut bbox_sampler: BoundingBoxSampler, - metrics: &CollectBuildSideMetrics, + metrics: CollectBuildSideMetrics, ) -> Result { let mut spill_writer_opt = None; let mut in_mem_batches: Vec = Vec::new(); @@ -200,7 +207,7 @@ impl BuildSideBatchesCollector { e, ); spill_writer_opt = - self.spill_in_mem_batches(&mut in_mem_batches, metrics)?; + self.spill_in_mem_batches(&mut in_mem_batches, &metrics)?; } } Some(spill_writer) => { @@ -236,7 +243,7 @@ impl BuildSideBatchesCollector { "Force spilling enabled. 
Spilling {} in-memory batches to disk.", in_mem_batches.len() ); - spill_writer_opt = self.spill_in_mem_batches(&mut in_mem_batches, metrics)?; + spill_writer_opt = self.spill_in_mem_batches(&mut in_mem_batches, &metrics)?; } let build_side_batch_stream: SendableEvaluatedBatchStream = match spill_writer_opt { @@ -266,6 +273,7 @@ impl BuildSideBatchesCollector { bbox_samples: bbox_sampler.into_samples(), estimated_spatial_index_memory_usage, reservation, + metrics, }) } @@ -329,7 +337,7 @@ impl BuildSideBatchesCollector { let evaluated_stream = create_evaluated_build_stream(stream, evaluator, metrics.time_taken.clone()); let result = collector - .collect(evaluated_stream, reservation, bbox_sampler, &metrics) + .collect(evaluated_stream, reservation, bbox_sampler, metrics) .await; (partition_id, result) }); @@ -378,7 +386,7 @@ impl BuildSideBatchesCollector { let evaluated_stream = create_evaluated_build_stream(stream, evaluator, metrics.time_taken.clone()); let result = self - .collect(evaluated_stream, reservation, bbox_sampler, &metrics) + .collect(evaluated_stream, reservation, bbox_sampler, metrics) .await?; results.push(result); } @@ -534,11 +542,12 @@ mod tests { let metrics = CollectBuildSideMetrics::new(0, &metrics_set); let partition = collector - .collect(stream, reservation, sampler, &metrics) + .collect(stream, reservation, sampler, metrics) .await?; let stream = partition.build_side_batch_stream; let is_external = stream.is_external(); let batches: Vec = stream.try_collect().await?; + let metrics = &partition.metrics; assert!(!is_external, "Expected in-memory batches"); assert_eq!(collect_ids(&batches), vec![0, 1, 2]); assert_eq!(partition.num_rows, 3); @@ -564,14 +573,15 @@ mod tests { let metrics = CollectBuildSideMetrics::new(0, &metrics_set); let partition = collector - .collect(stream, reservation, sampler, &metrics) + .collect(stream, reservation, sampler, metrics) .await?; let stream = partition.build_side_batch_stream; let is_external = 
stream.is_external(); let batches: Vec = stream.try_collect().await?; + let metrics = &partition.metrics; assert!(is_external, "Expected batches to spill to disk"); assert_eq!(collect_ids(&batches), vec![10, 11, 12]); - let spill_metrics = metrics.spill_metrics; + let spill_metrics = metrics.spill_metrics(); assert!(spill_metrics.spill_file_count.value() >= 1); assert!(spill_metrics.spilled_rows.value() >= 1); Ok(()) @@ -587,12 +597,13 @@ mod tests { let metrics = CollectBuildSideMetrics::new(0, &metrics_set); let partition = collector - .collect(stream, reservation, sampler, &metrics) + .collect(stream, reservation, sampler, metrics) .await?; assert_eq!(partition.num_rows, 0); let stream = partition.build_side_batch_stream; let is_external = stream.is_external(); let batches: Vec = stream.try_collect().await?; + let metrics = &partition.metrics; assert!(!is_external); assert!(batches.is_empty()); assert_eq!(metrics.num_batches.value(), 0); diff --git a/rust/sedona-spatial-join/src/index/memory_plan.rs b/rust/sedona-spatial-join/src/index/memory_plan.rs new file mode 100644 index 000000000..24a25c892 --- /dev/null +++ b/rust/sedona-spatial-join/src/index/memory_plan.rs @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cmp::max; + +use datafusion_common::{DataFusionError, Result}; + +use super::BuildPartition; + +/// The memory accounting summary of a build side partition. This is collected +/// during the build side collection phase and used to estimate the memory usage for +/// running spatial join. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) struct PartitionMemorySummary { + /// Number of rows in the partition. + pub num_rows: usize, + /// The total memory reserved when collecting this build side partition. + pub reserved_memory: usize, + /// The estimated memory usage for building the spatial index for all the data in + /// this build side partition. + pub estimated_index_memory_usage: usize, +} + +impl From<&BuildPartition> for PartitionMemorySummary { + fn from(partition: &BuildPartition) -> Self { + Self { + num_rows: partition.num_rows, + reserved_memory: partition.reservation.size(), + estimated_index_memory_usage: partition.estimated_spatial_index_memory_usage, + } + } +} + +/// A detailed plan for memory usage during spatial join execution. The spatial join +/// could be spatial-partitioned if the reserved memory is not sufficient to hold the +/// entire spatial index. +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct MemoryPlan { + /// The total number of rows in the build side. + pub num_rows: usize, + /// The total memory reserved for the build side. + pub reserved_memory: usize, + /// The estimated memory usage for building the spatial index for the entire build side. + /// It could be larger than [`Self::reserved_memory`], and in that case we need to + /// partition the build side using spatial partitioning. + pub estimated_index_memory_usage: usize, + /// The memory budget for holding the spatial index. If the spatial join is partitioned, + /// this is the memory budget for holding the spatial index of a single partition. 
+ pub memory_for_spatial_index: usize, + /// The memory budget for intermittent usage, such as buffering data during repartitioning. + pub memory_for_intermittent_usage: usize, + /// The number of spatial partitions to split the build side into. + pub num_partitions: usize, +} + +/// Compute the memory plan for running spatial join based on the memory summaries of +/// build side partitions. +pub(crate) fn compute_memory_plan(partition_summaries: I) -> Result +where + I: IntoIterator, +{ + let mut num_rows = 0; + let mut reserved_memory = 0; + let mut estimated_index_memory_usage = 0; + + for summary in partition_summaries { + num_rows += summary.num_rows; + reserved_memory += summary.reserved_memory; + estimated_index_memory_usage += summary.estimated_index_memory_usage; + } + + if reserved_memory == 0 && num_rows > 0 { + return Err(DataFusionError::ResourcesExhausted( + "Insufficient memory for spatial join".to_string(), + )); + } + + // Use 80% of reserved memory for holding the spatial index. The other 20% are reserved for + // intermittent usage like repartitioning buffers. + let memory_for_spatial_index = + calculate_memory_for_spatial_index(reserved_memory, estimated_index_memory_usage); + let memory_for_intermittent_usage = reserved_memory - memory_for_spatial_index; + + let num_partitions = if num_rows > 0 { + max( + 1, + estimated_index_memory_usage.div_ceil(memory_for_spatial_index), + ) + } else { + 1 + }; + + Ok(MemoryPlan { + num_rows, + reserved_memory, + estimated_index_memory_usage, + memory_for_spatial_index, + memory_for_intermittent_usage, + num_partitions, + }) +} + +fn calculate_memory_for_spatial_index( + reserved_memory: usize, + estimated_index_memory_usage: usize, +) -> usize { + if reserved_memory >= estimated_index_memory_usage { + // Reserved memory is sufficient to hold the entire spatial index. Make sure that + // the memory for spatial index is enough for holding the entire index. The rest + // can be used for intermittent usage. 
+ estimated_index_memory_usage + } else { + // Reserved memory is not sufficient to hold the entire spatial index, We need to + // partition the dataset using spatial partitioning. Use 80% of reserved memory + // for holding the partitioned spatial index. The rest is used for intermittent usage. + let reserved_portion = reserved_memory.saturating_mul(80) / 100; + if reserved_portion == 0 { + reserved_memory + } else { + reserved_portion + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn summary( + num_rows: usize, + reserved_memory: usize, + estimated_usage: usize, + ) -> PartitionMemorySummary { + PartitionMemorySummary { + num_rows, + reserved_memory, + estimated_index_memory_usage: estimated_usage, + } + } + + #[test] + fn memory_plan_errors_when_no_memory_but_rows_exist() { + let err = compute_memory_plan(vec![summary(10, 0, 512)]).unwrap_err(); + assert!(matches!( + err, + DataFusionError::ResourcesExhausted(msg) if msg.contains("Insufficient memory") + )); + } + + #[test] + fn memory_plan_partitions_large_jobs() { + let plan = + compute_memory_plan(vec![summary(100, 2_000, 1_500), summary(150, 1_000, 3_500)]) + .expect("plan should succeed"); + + assert_eq!(plan.num_rows, 250); + assert_eq!(plan.reserved_memory, 3_000); + assert_eq!(plan.memory_for_spatial_index, 2_400); + assert_eq!(plan.memory_for_intermittent_usage, 600); + assert_eq!(plan.num_partitions, 3); + } + + #[test] + fn memory_plan_handles_zero_rows() { + let plan = compute_memory_plan(vec![summary(0, 0, 0)]).expect("plan should succeed"); + assert_eq!(plan.num_partitions, 1); + assert_eq!(plan.memory_for_spatial_index, 0); + assert_eq!(plan.memory_for_intermittent_usage, 0); + } + + #[test] + fn memory_plan_uses_entire_reservation_when_fraction_rounds_down() { + let plan = compute_memory_plan(vec![summary(10, 1, 1)]).expect("plan should succeed"); + assert_eq!(plan.memory_for_spatial_index, 1); + assert_eq!(plan.memory_for_intermittent_usage, 0); + } +} diff --git 
a/rust/sedona-spatial-join/src/index/partitioned_index_provider.rs b/rust/sedona-spatial-join/src/index/partitioned_index_provider.rs new file mode 100644 index 000000000..f9aeb893e --- /dev/null +++ b/rust/sedona-spatial-join/src/index/partitioned_index_provider.rs @@ -0,0 +1,602 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_schema::SchemaRef; +use datafusion_common::{DataFusionError, Result, SharedResult}; +use datafusion_common_runtime::JoinSet; +use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_expr::JoinType; +use futures::StreamExt; +use parking_lot::Mutex; +use sedona_common::{sedona_internal_err, SpatialJoinOptions}; +use std::ops::DerefMut; +use std::sync::Arc; +use tokio::sync::mpsc; + +use crate::evaluated_batch::evaluated_batch_stream::external::ExternalEvaluatedBatchStream; +use crate::index::BuildPartition; +use crate::partitioning::stream_repartitioner::{SpilledPartition, SpilledPartitions}; +use crate::utils::disposable_async_cell::DisposableAsyncCell; +use crate::{ + index::{SpatialIndex, SpatialIndexBuilder, SpatialJoinBuildMetrics}, + partitioning::SpatialPartition, + spatial_predicate::SpatialPredicate, +}; + +pub(crate) struct PartitionedIndexProvider { + schema: SchemaRef, + spatial_predicate: SpatialPredicate, + options: SpatialJoinOptions, + join_type: JoinType, + probe_threads_count: usize, + metrics: SpatialJoinBuildMetrics, + + /// Data on the build side to build index for + data: BuildSideData, + + /// Async cells for indexes, one per regular partition + index_cells: Vec>>>, + + /// The memory reserved in the build side collection phase. We'll hold them until + /// we don't need to build spatial indexes. 
+ _reservations: Vec, +} + +pub(crate) enum BuildSideData { + SinglePartition(Mutex>>), + MultiPartition(Mutex), +} + +impl PartitionedIndexProvider { + #[allow(clippy::too_many_arguments)] + pub fn new_multi_partition( + schema: SchemaRef, + spatial_predicate: SpatialPredicate, + options: SpatialJoinOptions, + join_type: JoinType, + probe_threads_count: usize, + partitioned_spill_files: SpilledPartitions, + metrics: SpatialJoinBuildMetrics, + reservations: Vec, + ) -> Self { + let num_partitions = partitioned_spill_files.num_regular_partitions(); + let index_cells = (0..num_partitions) + .map(|_| DisposableAsyncCell::new()) + .collect(); + Self { + schema, + spatial_predicate, + options, + join_type, + probe_threads_count, + metrics, + data: BuildSideData::MultiPartition(Mutex::new(partitioned_spill_files)), + index_cells, + _reservations: reservations, + } + } + + #[allow(clippy::too_many_arguments)] + pub fn new_single_partition( + schema: SchemaRef, + spatial_predicate: SpatialPredicate, + options: SpatialJoinOptions, + join_type: JoinType, + probe_threads_count: usize, + mut build_partitions: Vec, + metrics: SpatialJoinBuildMetrics, + ) -> Self { + let reservations = build_partitions + .iter_mut() + .map(|p| p.reservation.take()) + .collect(); + let index_cells = vec![DisposableAsyncCell::new()]; + Self { + schema, + spatial_predicate, + options, + join_type, + probe_threads_count, + metrics, + data: BuildSideData::SinglePartition(Mutex::new(Some(build_partitions))), + index_cells, + _reservations: reservations, + } + } + + pub fn new_empty( + schema: SchemaRef, + spatial_predicate: SpatialPredicate, + options: SpatialJoinOptions, + join_type: JoinType, + probe_threads_count: usize, + metrics: SpatialJoinBuildMetrics, + ) -> Self { + let build_partitions = Vec::new(); + Self::new_single_partition( + schema, + spatial_predicate, + options, + join_type, + probe_threads_count, + build_partitions, + metrics, + ) + } + + pub fn num_regular_partitions(&self) -> 
usize { + self.index_cells.len() + } + + pub async fn build_or_wait_for_index( + &self, + partition_id: u32, + ) -> Option>> { + let cell = match self.index_cells.get(partition_id as usize) { + Some(cell) => cell, + None => { + return Some(sedona_internal_err!( + "partition_id {} exceeds {} partitions", + partition_id, + self.index_cells.len() + )) + } + }; + if !cell.is_empty() { + return get_index_from_cell(cell).await; + } + + let res_index = { + let opt_res_index = self.maybe_build_index(partition_id).await; + match opt_res_index { + Some(res_index) => res_index, + None => { + // The build side data for building the index has already been consumed by someone else, + // we just need to wait for the task consumed the data to finish building the index. + return get_index_from_cell(cell).await; + } + } + }; + + match res_index { + Ok(idx) => { + if let Err(e) = cell.set(Ok(Arc::clone(&idx))) { + // This is probably because the cell has been disposed. No one + // will get the index from the cell so this failure is not a big deal. + log::debug!("Cannot set the index into the async cell: {:?}", e); + } + Some(Ok(idx)) + } + Err(err) => { + let err_arc = Arc::new(err); + if let Err(e) = cell.set(Err(Arc::clone(&err_arc))) { + log::debug!( + "Cannot set the index build error into the async cell: {:?}", + e + ); + } + Some(Err(DataFusionError::Shared(err_arc))) + } + } + } + + async fn maybe_build_index(&self, partition_id: u32) -> Option>> { + match &self.data { + BuildSideData::SinglePartition(build_partition_opt) => { + if partition_id != 0 { + return Some(sedona_internal_err!( + "partition_id for single-partition index is not 0" + )); + } + + // consume the build side data for building the index + let build_partition_opt = { + let mut locked = build_partition_opt.lock(); + std::mem::take(locked.deref_mut()) + }; + + let Some(build_partition) = build_partition_opt else { + // already consumed by previous attempts, the result should be present in the channel. 
+ return None; + }; + Some(self.build_index_for_single_partition(build_partition).await) + } + BuildSideData::MultiPartition(partitioned_spill_files) => { + // consume this partition of build side data for building index + let spilled_partition = { + let mut locked = partitioned_spill_files.lock(); + let partition = SpatialPartition::Regular(partition_id); + if !locked.can_take_spilled_partition(partition) { + // already consumed by previous attempts, the result should be present in the channel. + return None; + } + match locked.take_spilled_partition(partition) { + Ok(spilled_partition) => spilled_partition, + Err(e) => return Some(Err(e)), + } + }; + Some( + self.build_index_for_spilled_partition(spilled_partition) + .await, + ) + } + } + } + + #[cfg(test)] + pub async fn wait_for_index(&self, partition_id: u32) -> Option>> { + let cell = match self.index_cells.get(partition_id as usize) { + Some(cell) => cell, + None => { + return Some(sedona_internal_err!( + "partition_id {} exceeds {} partitions", + partition_id, + self.index_cells.len() + )) + } + }; + + get_index_from_cell(cell).await + } + + pub fn dispose_index(&self, partition_id: u32) { + if let Some(cell) = self.index_cells.get(partition_id as usize) { + cell.dispose(); + } + } + + pub fn num_loaded_indexes(&self) -> usize { + self.index_cells + .iter() + .filter(|index_cell| index_cell.is_set()) + .count() + } + + async fn build_index_for_single_partition( + &self, + build_partitions: Vec, + ) -> Result> { + let mut index_builder = SpatialIndexBuilder::new( + Arc::clone(&self.schema), + self.spatial_predicate.clone(), + self.options.clone(), + self.join_type, + self.probe_threads_count, + self.metrics.clone(), + )?; + + for build_partition in build_partitions { + let stream = build_partition.build_side_batch_stream; + let geo_statistics = build_partition.geo_statistics; + index_builder.add_stream(stream, geo_statistics).await?; + } + + let index = index_builder.finish()?; + Ok(Arc::new(index)) + } + + 
async fn build_index_for_spilled_partition( + &self, + spilled_partition: SpilledPartition, + ) -> Result> { + let mut index_builder = SpatialIndexBuilder::new( + Arc::clone(&self.schema), + self.spatial_predicate.clone(), + self.options.clone(), + self.join_type, + self.probe_threads_count, + self.metrics.clone(), + )?; + + // Spawn tasks to load indexed batches from spilled files concurrently + let (spill_files, geo_statistics, _) = spilled_partition.into_inner(); + let mut join_set: JoinSet> = JoinSet::new(); + let (tx, mut rx) = mpsc::channel(spill_files.len() * 2 + 1); + for spill_file in spill_files { + let tx = tx.clone(); + join_set.spawn(async move { + let result = async { + let mut stream = ExternalEvaluatedBatchStream::try_from_spill_file(spill_file)?; + while let Some(batch) = stream.next().await { + let indexed_batch = batch?; + if tx.send(Ok(indexed_batch)).await.is_err() { + return Ok(()); + } + } + Ok::<(), DataFusionError>(()) + } + .await; + if let Err(e) = result { + let _ = tx.send(Err(e)).await; + } + Ok(()) + }); + } + drop(tx); + + // Collect the loaded indexed batches and add them to the index builder + while let Some(res) = rx.recv().await { + let batch = res?; + index_builder.add_batch(batch)?; + } + + // Ensure all tasks completed successfully + while let Some(res) = join_set.join_next().await { + if let Err(e) = res { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } + return Err(DataFusionError::External(Box::new(e))); + } + } + + index_builder.merge_stats(geo_statistics); + + let index = index_builder.finish()?; + Ok(Arc::new(index)) + } +} + +async fn get_index_from_cell( + cell: &DisposableAsyncCell>>, +) -> Option>> { + match cell.get().await { + Some(Ok(index)) => Some(Ok(index)), + Some(Err(shared_err)) => Some(Err(DataFusionError::Shared(shared_err))), + None => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::operand_evaluator::EvaluatedGeometryArray; + use 
crate::partitioning::partition_slots::PartitionSlots; + use crate::utils::bbox_sampler::BoundingBoxSamples; + use crate::{ + evaluated_batch::{ + evaluated_batch_stream::{ + in_mem::InMemoryEvaluatedBatchStream, SendableEvaluatedBatchStream, + }, + EvaluatedBatch, + }, + index::CollectBuildSideMetrics, + }; + use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use datafusion::config::SpillCompression; + use datafusion_common::{DataFusionError, Result}; + use datafusion_execution::{ + memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool}, + runtime_env::RuntimeEnv, + }; + use datafusion_expr::JoinType; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, SpillMetrics}; + use sedona_expr::statistics::GeoStatistics; + use sedona_functions::st_analyze_agg::AnalyzeAccumulator; + use sedona_geometry::analyze::analyze_geometry; + use sedona_schema::datatypes::WKB_GEOMETRY; + + use crate::evaluated_batch::spill::EvaluatedBatchSpillWriter; + use crate::partitioning::stream_repartitioner::{SpilledPartition, SpilledPartitions}; + use crate::spatial_predicate::{RelationPredicate, SpatialRelationType}; + + fn sample_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("geom", DataType::Binary, true), + Field::new("id", DataType::Int32, false), + ])) + } + + fn point_wkb(x: f64, y: f64) -> Vec { + let mut buf = vec![1u8, 1, 0, 0, 0]; + buf.extend_from_slice(&x.to_le_bytes()); + buf.extend_from_slice(&y.to_le_bytes()); + buf + } + + fn sample_batch(ids: &[i32], wkbs: Vec>>) -> Result { + assert_eq!(ids.len(), wkbs.len()); + let geom_values: Vec> = wkbs + .iter() + .map(|opt| opt.as_ref().map(|wkb| wkb.as_slice())) + .collect(); + let geom_array: ArrayRef = Arc::new(BinaryArray::from_opt_vec(geom_values)); + let id_array: ArrayRef = Arc::new(Int32Array::from(ids.to_vec())); + let batch = 
RecordBatch::try_new(sample_schema(), vec![geom_array.clone(), id_array])?; + let geom = EvaluatedGeometryArray::try_new(geom_array, &WKB_GEOMETRY)?; + Ok(EvaluatedBatch { + batch, + geom_array: geom, + }) + } + + fn predicate() -> SpatialPredicate { + SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("geom", 0)), + Arc::new(Column::new("geom", 0)), + SpatialRelationType::Intersects, + )) + } + + fn geo_stats_from_batches(batches: &[EvaluatedBatch]) -> Result { + let mut analyzer = AnalyzeAccumulator::new(WKB_GEOMETRY, WKB_GEOMETRY); + for batch in batches { + for wkb in batch.geom_array.wkbs().iter().flatten() { + let summary = + analyze_geometry(wkb).map_err(|e| DataFusionError::External(Box::new(e)))?; + analyzer.ingest_geometry_summary(&summary); + } + } + Ok(analyzer.finish()) + } + + fn new_reservation(memory_pool: Arc) -> MemoryReservation { + let consumer = MemoryConsumer::new("PartitionedIndexProviderTest"); + consumer.register(&memory_pool) + } + + fn build_partition_from_batches( + memory_pool: Arc, + batches: Vec, + ) -> Result { + let schema = batches + .first() + .map(|batch| batch.schema()) + .unwrap_or_else(|| Arc::new(Schema::empty())); + let geo_statistics = geo_stats_from_batches(&batches)?; + let num_rows = batches.iter().map(|batch| batch.num_rows()).sum(); + let mut estimated_usage = 0; + for batch in &batches { + estimated_usage += batch.in_mem_size()?; + } + let stream: SendableEvaluatedBatchStream = + Box::pin(InMemoryEvaluatedBatchStream::new(schema, batches)); + Ok(BuildPartition { + num_rows, + build_side_batch_stream: stream, + geo_statistics, + bbox_samples: BoundingBoxSamples::empty(), + estimated_spatial_index_memory_usage: estimated_usage, + reservation: new_reservation(memory_pool), + metrics: CollectBuildSideMetrics::new(0, &ExecutionPlanMetricsSet::new()), + }) + } + + fn spill_partition_from_batches( + runtime_env: Arc, + batches: Vec, + ) -> Result { + if batches.is_empty() { + return 
Ok(SpilledPartition::empty()); + } + let schema = batches[0].schema(); + let sedona_type = batches[0].geom_array.sedona_type.clone(); + let mut writer = EvaluatedBatchSpillWriter::try_new( + runtime_env, + schema, + &sedona_type, + "partitioned-index-provider-test", + SpillCompression::Uncompressed, + SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0), + None, + )?; + let mut num_rows = 0; + for batch in &batches { + num_rows += batch.num_rows(); + writer.append(batch)?; + } + let geo_statistics = geo_stats_from_batches(&batches)?; + let spill_file = writer.finish()?; + Ok(SpilledPartition::new( + vec![Arc::new(spill_file)], + geo_statistics, + num_rows, + )) + } + + fn make_spilled_partitions( + runtime_env: Arc, + partitions: Vec>, + ) -> Result { + let slots = PartitionSlots::new(partitions.len()); + let mut spilled = Vec::with_capacity(slots.total_slots()); + for partition_batches in partitions { + spilled.push(spill_partition_from_batches( + Arc::clone(&runtime_env), + partition_batches, + )?); + } + spilled.push(SpilledPartition::empty()); + spilled.push(SpilledPartition::empty()); + Ok(SpilledPartitions::new(slots, spilled)) + } + + #[tokio::test] + async fn single_partition_builds_once_and_is_cached() -> Result<()> { + let memory_pool: Arc = Arc::new(GreedyMemoryPool::new(1 << 20)); + let batches = vec![sample_batch( + &[1, 2], + vec![Some(point_wkb(10.0, 10.0)), Some(point_wkb(20.0, 20.0))], + )?]; + let build_partition = build_partition_from_batches(Arc::clone(&memory_pool), batches)?; + let metrics = ExecutionPlanMetricsSet::new(); + let provider = PartitionedIndexProvider::new_single_partition( + sample_schema(), + predicate(), + SpatialJoinOptions::default(), + JoinType::Inner, + 1, + vec![build_partition], + SpatialJoinBuildMetrics::new(0, &metrics), + ); + + let first_index = provider + .build_or_wait_for_index(0) + .await + .expect("partition exists")?; + assert_eq!(first_index.indexed_batches.len(), 1); + 
assert_eq!(provider.num_loaded_indexes(), 1); + + let cached_index = provider + .wait_for_index(0) + .await + .expect("cached value must remain accessible")?; + assert!(Arc::ptr_eq(&first_index, &cached_index)); + Ok(()) + } + + #[tokio::test] + async fn multi_partition_concurrent_requests_share_indexes() -> Result<()> { + let memory_pool: Arc = Arc::new(GreedyMemoryPool::new(1 << 20)); + let runtime_env = Arc::new(RuntimeEnv::default()); + let partition_batches = vec![ + vec![sample_batch(&[10], vec![Some(point_wkb(0.0, 0.0))])?], + vec![sample_batch(&[20], vec![Some(point_wkb(50.0, 50.0))])?], + ]; + let spilled_partitions = make_spilled_partitions(runtime_env, partition_batches)?; + let metrics = ExecutionPlanMetricsSet::new(); + let provider = Arc::new(PartitionedIndexProvider::new_multi_partition( + sample_schema(), + predicate(), + SpatialJoinOptions::default(), + JoinType::Inner, + 1, + spilled_partitions, + SpatialJoinBuildMetrics::new(0, &metrics), + vec![new_reservation(Arc::clone(&memory_pool))], + )); + + let (idx_one, idx_two) = tokio::join!( + provider.build_or_wait_for_index(0), + provider.build_or_wait_for_index(0) + ); + let idx_one = idx_one.expect("partition exists")?; + let idx_two = idx_two.expect("partition exists")?; + assert!(Arc::ptr_eq(&idx_one, &idx_two)); + assert_eq!(idx_one.indexed_batches.len(), 1); + + let second_partition = provider + .build_or_wait_for_index(1) + .await + .expect("second partition exists")?; + assert_eq!(second_partition.indexed_batches.len(), 1); + assert_eq!(provider.num_loaded_indexes(), 2); + Ok(()) + } +} diff --git a/rust/sedona-spatial-join/src/index/spatial_index.rs b/rust/sedona-spatial-join/src/index/spatial_index.rs index 9364920ad..bff7895df 100644 --- a/rust/sedona-spatial-join/src/index/spatial_index.rs +++ b/rust/sedona-spatial-join/src/index/spatial_index.rs @@ -27,7 +27,6 @@ use arrow_array::RecordBatch; use arrow_schema::SchemaRef; use datafusion_common::{DataFusionError, Result}; use 
datafusion_common_runtime::JoinSet; -use datafusion_execution::memory_pool::MemoryReservation; use float_next_after::NextAfter; use geo::BoundingRect; use geo_index::rtree::{ @@ -95,11 +94,6 @@ pub struct SpatialIndex { /// Shared KNN components (distance metrics and geometry cache) for efficient KNN queries pub(crate) knn_components: Option, - - /// Memory reservation for tracking the memory usage of the spatial index - /// Cleared on `SpatialIndex` drop - #[expect(dead_code)] - pub(crate) reservation: MemoryReservation, } impl SpatialIndex { @@ -108,7 +102,6 @@ impl SpatialIndex { schema: SchemaRef, options: SpatialJoinOptions, probe_threads_counter: AtomicUsize, - reservation: MemoryReservation, ) -> Self { let evaluator = create_operand_evaluator(&spatial_predicate, options.clone()); let refiner = create_refiner( @@ -133,7 +126,6 @@ impl SpatialIndex { visited_build_side: None, probe_threads_counter, knn_components, - reservation, } } @@ -681,7 +673,6 @@ mod tests { use arrow_array::RecordBatch; use arrow_schema::{DataType, Field}; use datafusion_common::JoinSide; - use datafusion_execution::memory_pool::GreedyMemoryPool; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; use geo_traits::Dimensions; @@ -692,7 +683,6 @@ mod tests { #[test] fn test_spatial_index_builder_empty() { - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -711,7 +701,6 @@ mod tests { options, JoinType::Inner, 4, - memory_pool, metrics, ) .unwrap(); @@ -724,7 +713,6 @@ mod tests { #[test] fn test_spatial_index_builder_add_batch() { - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -750,7 +738,6 @@ mod tests { options, JoinType::Inner, 4, - memory_pool, metrics, ) .unwrap(); @@ -779,7 +766,6 @@ mod tests { #[test] fn 
test_knn_query_execution_with_sample_data() { // Create a spatial index with sample geometry data - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -807,7 +793,6 @@ mod tests { options, JoinType::Inner, 4, - memory_pool, metrics, ) .unwrap(); @@ -878,7 +863,6 @@ mod tests { #[test] fn test_knn_query_execution_with_different_k_values() { // Create spatial index with more data points - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -905,7 +889,6 @@ mod tests { options, JoinType::Inner, 4, - memory_pool, metrics, ) .unwrap(); @@ -969,7 +952,6 @@ mod tests { #[test] fn test_knn_query_execution_with_spheroid_distance() { // Create spatial index - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -996,7 +978,6 @@ mod tests { options, JoinType::Inner, 4, - memory_pool, metrics, ) .unwrap(); @@ -1066,7 +1047,6 @@ mod tests { #[test] fn test_knn_query_execution_edge_cases() { // Create spatial index - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -1093,7 +1073,6 @@ mod tests { options, JoinType::Inner, 4, - memory_pool, metrics, ) .unwrap(); @@ -1159,7 +1138,6 @@ mod tests { #[test] fn test_knn_query_execution_empty_index() { // Create empty spatial index - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -1181,7 +1159,6 @@ mod tests { options, JoinType::Inner, 4, - memory_pool, metrics, ) .unwrap(); @@ -1207,7 +1184,6 @@ mod tests { #[test] fn 
test_knn_query_execution_with_tie_breakers() { // Create a spatial index with sample geometry data - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -1234,7 +1210,6 @@ mod tests { options, JoinType::Inner, 1, // probe_threads_count - memory_pool.clone(), metrics, ) .unwrap(); @@ -1322,7 +1297,6 @@ mod tests { #[test] fn test_query_knn_with_geometry_distance() { // Create a spatial index with sample geometry data - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -1350,7 +1324,6 @@ mod tests { options, JoinType::Inner, 4, - memory_pool, metrics, ) .unwrap(); @@ -1407,7 +1380,6 @@ mod tests { fn test_query_knn_with_mixed_geometries() { // Create a spatial index with complex geometries where geometry-based // distance should differ from centroid-based distance - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -1435,7 +1407,6 @@ mod tests { options, JoinType::Inner, 4, - memory_pool, metrics, ) .unwrap(); @@ -1489,7 +1460,6 @@ mod tests { #[test] fn test_query_knn_with_tie_breakers_geometry_distance() { // Create a spatial index with geometries that have identical distances for tie-breaker testing - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -1516,7 +1486,6 @@ mod tests { options, JoinType::Inner, 4, - memory_pool, metrics, ) .unwrap(); @@ -1610,7 +1579,6 @@ mod tests { #[test] fn test_knn_query_with_empty_geometry() { // Create a spatial index with sample geometry data like other tests - let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024)); let options = SpatialJoinOptions { 
execution_mode: ExecutionMode::PrepareBuild, ..Default::default() @@ -1638,7 +1606,6 @@ mod tests { options, JoinType::Inner, 1, // probe_threads_count - memory_pool.clone(), metrics, ) .unwrap(); @@ -1687,7 +1654,6 @@ mod tests { build_geoms: &[Option<&str>], options: SpatialJoinOptions, ) -> Arc { - let memory_pool = Arc::new(GreedyMemoryPool::new(100 * 1024 * 1024)); let metrics = SpatialJoinBuildMetrics::default(); let spatial_predicate = SpatialPredicate::Relation(RelationPredicate::new( Arc::new(Column::new("left", 0)), @@ -1706,7 +1672,6 @@ mod tests { options, JoinType::Inner, 1, - memory_pool, metrics, ) .unwrap(); diff --git a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs index 9d97b539c..ca2b00880 100644 --- a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs +++ b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs @@ -22,16 +22,15 @@ use sedona_common::SpatialJoinOptions; use sedona_expr::statistics::GeoStatistics; use datafusion_common::{utils::proxy::VecAllocExt, Result}; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_expr::JoinType; use futures::StreamExt; use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder, RTreeIndex}; use parking_lot::Mutex; -use std::sync::{atomic::AtomicUsize, Arc}; +use std::sync::atomic::AtomicUsize; use crate::{ - evaluated_batch::EvaluatedBatch, - index::{knn_adapter::KnnComponents, spatial_index::SpatialIndex, BuildPartition}, + evaluated_batch::{evaluated_batch_stream::SendableEvaluatedBatchStream, EvaluatedBatch}, + index::{knn_adapter::KnnComponents, spatial_index::SpatialIndex}, operand_evaluator::create_operand_evaluator, refine::create_refiner, spatial_predicate::SpatialPredicate, @@ -63,8 +62,6 @@ pub struct SpatialIndexBuilder { /// Batches to be indexed indexed_batches: Vec, - /// Memory reservation for tracking the memory usage of the spatial index - 
reservation: MemoryReservation, /// Statistics for indexed geometries stats: GeoStatistics, @@ -99,12 +96,8 @@ impl SpatialIndexBuilder { options: SpatialJoinOptions, join_type: JoinType, probe_threads_count: usize, - memory_pool: Arc, metrics: SpatialJoinBuildMetrics, ) -> Result { - let consumer = MemoryConsumer::new("SpatialJoinIndex"); - let reservation = consumer.register(&memory_pool); - Ok(Self { schema, spatial_predicate, @@ -113,7 +106,6 @@ impl SpatialIndexBuilder { probe_threads_count, metrics, indexed_batches: Vec::new(), - reservation, stats: GeoStatistics::empty(), memory_used: 0, }) @@ -258,7 +250,6 @@ impl SpatialIndexBuilder { self.schema, self.options, AtomicUsize::new(self.probe_threads_count), - self.reservation, )); } @@ -297,6 +288,10 @@ impl SpatialIndexBuilder { } }; + log::debug!( + "Estimated memory used by spatial index: {}", + self.memory_used + ); Ok(SpatialIndex { schema: self.schema, options: self.options, @@ -309,26 +304,19 @@ impl SpatialIndexBuilder { visited_build_side, probe_threads_counter: AtomicUsize::new(self.probe_threads_count), knn_components: knn_components_opt, - reservation: self.reservation, }) } - pub async fn add_partitions(&mut self, partitions: Vec) -> Result<()> { - for partition in partitions { - self.add_partition(partition).await?; - } - Ok(()) - } - - pub async fn add_partition(&mut self, mut partition: BuildPartition) -> Result<()> { - let mut stream = partition.build_side_batch_stream; + pub async fn add_stream( + &mut self, + mut stream: SendableEvaluatedBatchStream, + geo_statistics: GeoStatistics, + ) -> Result<()> { while let Some(batch) = stream.next().await { let indexed_batch = batch?; self.add_batch(indexed_batch)?; } - self.merge_stats(partition.geo_statistics); - let mem_bytes = partition.reservation.free(); - self.reservation.try_grow(mem_bytes)?; + self.merge_stats(geo_statistics); Ok(()) } diff --git a/rust/sedona-spatial-join/src/lib.rs b/rust/sedona-spatial-join/src/lib.rs index 
94af3f225..2abaf3c43 100644 --- a/rust/sedona-spatial-join/src/lib.rs +++ b/rust/sedona-spatial-join/src/lib.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. -mod build_index; pub mod evaluated_batch; pub mod exec; mod index; pub mod operand_evaluator; pub mod optimizer; pub mod partitioning; +mod prepare; pub mod refine; pub mod spatial_predicate; mod stream; @@ -31,7 +31,6 @@ pub use exec::SpatialJoinExec; pub use optimizer::register_spatial_join_optimizer; // Re-export types needed for external usage (e.g., in Comet) -pub use build_index::build_index; pub use index::{SpatialIndex, SpatialJoinBuildMetrics}; pub use spatial_predicate::SpatialPredicate; diff --git a/rust/sedona-spatial-join/src/partitioning/kdb.rs b/rust/sedona-spatial-join/src/partitioning/kdb.rs index 32ac3a4c3..c09e98ff8 100644 --- a/rust/sedona-spatial-join/src/partitioning/kdb.rs +++ b/rust/sedona-spatial-join/src/partitioning/kdb.rs @@ -43,7 +43,9 @@ use std::sync::Arc; use crate::partitioning::{ - util::{bbox_to_geo_rect, rect_contains_point, rect_intersection_area, rects_intersect}, + util::{ + bbox_to_geo_rect, make_rect, rect_contains_point, rect_intersection_area, rects_intersect, + }, SpatialPartition, SpatialPartitioner, }; use datafusion_common::Result; @@ -126,9 +128,12 @@ impl KDBTree { if max_items_per_node == 0 { return sedona_internal_err!("max_items_per_node must be greater than 0"); } - let Some(extent_rect) = bbox_to_geo_rect(&extent)? else { - return sedona_internal_err!("KDBTree extent cannot be empty"); - }; + + // extent_rect is a sentinel rect if the bounding box is empty. In that case, + // almost all insertions will be ignored. We are free to partition the data + // arbitrarily when the extent is empty. 
+ let extent_rect = bbox_to_geo_rect(&extent)?.unwrap_or(make_rect(0.0, 0.0, 0.0, 0.0)); + Ok(Self::new_with_level( max_items_per_node, max_levels, @@ -507,6 +512,13 @@ impl KDBPartitioner { } Ok(()) } + + /// Return the tree structure in human-readable format for debugging purposes. + pub fn debug_str(&self) -> String { + let mut output = String::new(); + let _ = self.debug_print(&mut output); + output + } } impl SpatialPartitioner for KDBPartitioner { @@ -966,4 +978,19 @@ mod tests { SpatialPartition::None ); } + + #[test] + fn test_kdb_partitioner_empty_extent() { + let extent = BoundingBox::empty(); + let bboxes = vec![ + BoundingBox::xy((0.0, 10.0), (0.0, 10.0)), + BoundingBox::xy((1.0, 10.0), (1.0, 10.0)), + ]; + let partitioner = KDBPartitioner::build(bboxes.clone().into_iter(), 10, 4, extent).unwrap(); + + // Partition calls should succeed + for test_bbox in bboxes { + assert!(partitioner.partition(&test_bbox).is_ok()); + } + } } diff --git a/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs b/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs index 445911075..038530b1e 100644 --- a/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs +++ b/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs @@ -280,6 +280,13 @@ impl SpilledPartitions { } Ok(()) } + + /// Return debug info for this spilled partitions as a string. + pub fn debug_str(&self) -> String { + let mut output = String::new(); + let _ = self.debug_print(&mut output); + output + } } /// Incremental (stateful) repartitioner for an [`EvaluatedBatch`] stream. diff --git a/rust/sedona-spatial-join/src/prepare.rs b/rust/sedona-spatial-join/src/prepare.rs new file mode 100644 index 000000000..76e825b3c --- /dev/null +++ b/rust/sedona-spatial-join/src/prepare.rs @@ -0,0 +1,514 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{mem, sync::Arc}; + +use arrow_schema::SchemaRef; +use datafusion_common::Result; +use datafusion_common_runtime::JoinSet; +use datafusion_execution::{ + disk_manager::RefCountedTempFile, memory_pool::MemoryConsumer, SendableRecordBatchStream, + TaskContext, +}; +use datafusion_expr::JoinType; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use fastrand::Rng; +use sedona_common::{sedona_internal_err, NumSpatialPartitionsConfig, SedonaOptions}; +use sedona_expr::statistics::GeoStatistics; +use sedona_geometry::bounding_box::BoundingBox; + +use crate::{ + index::{ + memory_plan::{compute_memory_plan, MemoryPlan, PartitionMemorySummary}, + partitioned_index_provider::PartitionedIndexProvider, + BuildPartition, BuildSideBatchesCollector, CollectBuildSideMetrics, + SpatialJoinBuildMetrics, + }, + partitioning::{ + kdb::KDBPartitioner, + stream_repartitioner::{SpilledPartition, SpilledPartitions, StreamRepartitioner}, + PartitionedSide, SpatialPartition, SpatialPartitioner, + }, + spatial_predicate::SpatialPredicate, + utils::bbox_sampler::BoundingBoxSamples, +}; + +pub(crate) struct SpatialJoinComponents { + pub partitioned_index_provider: Arc, +} + +/// Builder for constructing `SpatialJoinComponents` from build-side streams. 
+/// +/// Calling `build(...)` performs the full preparation flow: +/// - collect (and spill if needed) build-side batches, +/// - compute memory plan and pick single- or multi-partition mode, +/// - repartition the build side into spatial partitions in multi-partition mode, +/// - create the appropriate `PartitionedIndexProvider` for creating spatial indexes. +pub(crate) struct SpatialJoinComponentsBuilder { + context: Arc, + build_schema: SchemaRef, + spatial_predicate: SpatialPredicate, + join_type: JoinType, + probe_threads_count: usize, + metrics: ExecutionPlanMetricsSet, + seed: u64, + sedona_options: SedonaOptions, +} + +impl SpatialJoinComponentsBuilder { + /// Create a new builder capturing the execution context and configuration + /// required to produce `SpatialJoinComponents` from build-side streams. + pub fn new( + context: Arc, + build_schema: SchemaRef, + spatial_predicate: SpatialPredicate, + join_type: JoinType, + probe_threads_count: usize, + metrics: ExecutionPlanMetricsSet, + seed: u64, + ) -> Self { + let session_config = context.session_config(); + let sedona_options = session_config + .options() + .extensions + .get::() + .cloned() + .unwrap_or_default(); + Self { + context, + build_schema, + spatial_predicate, + join_type, + probe_threads_count, + metrics, + seed, + sedona_options, + } + } + + /// Prepare and return `SpatialJoinComponents` for the given build-side + /// streams. This drives the end-to-end preparation flow and returns a + /// ready-to-use `SpatialJoinComponents` for the spatial join operator. + pub async fn build( + mut self, + build_streams: Vec, + ) -> Result { + let num_partitions = build_streams.len(); + if num_partitions == 0 { + log::debug!("Build side has no data. 
Creating empty spatial index."); + let partitioned_index_provider = PartitionedIndexProvider::new_empty( + self.build_schema, + self.spatial_predicate, + self.sedona_options.spatial_join, + self.join_type, + self.probe_threads_count, + SpatialJoinBuildMetrics::new(0, &self.metrics), + ); + return Ok(SpatialJoinComponents { + partitioned_index_provider: Arc::new(partitioned_index_provider), + }); + } + + let mut rng = Rng::with_seed(self.seed); + let mut build_partitions = self + .collect_build_partitions(build_streams, rng.u64(0..0xFFFF)) + .await?; + + // Determine the number of spatial partitions based on the memory reserved and the estimated amount of + // memory required for loading the entire build side into a spatial index + let memory_plan = + compute_memory_plan(build_partitions.iter().map(PartitionMemorySummary::from))?; + log::debug!("Computed memory plan for spatial join:\n{:#?}", memory_plan); + let num_partitions = match self + .sedona_options + .spatial_join + .debug + .num_spatial_partitions + { + NumSpatialPartitionsConfig::Auto => memory_plan.num_partitions, + NumSpatialPartitionsConfig::Fixed(n) => { + log::debug!("Override number of spatial partitions to {}", n); + n + } + }; + + if num_partitions == 1 { + log::debug!("Running single-partitioned in-memory spatial join"); + let partitioned_index_provider = PartitionedIndexProvider::new_single_partition( + self.build_schema, + self.spatial_predicate, + self.sedona_options.spatial_join, + self.join_type, + self.probe_threads_count, + build_partitions, + SpatialJoinBuildMetrics::new(0, &self.metrics), + ); + Ok(SpatialJoinComponents { + partitioned_index_provider: Arc::new(partitioned_index_provider), + }) + } else { + // Collect all memory reservations grown during build side collection + let mut reservations = Vec::with_capacity(build_partitions.len()); + for partition in &mut build_partitions { + reservations.push(partition.reservation.take()); + } + + // Partition the build side into multiple 
spatial partitions, each partition can be fully + // loaded into an in-memory spatial index + let build_partitioner = self.build_spatial_partitioner( + num_partitions, + &mut build_partitions, + rng.u64(0..0xFFFF), + )?; + let partitioned_spill_files_vec = self + .repartition_build_side(build_partitions, build_partitioner, &memory_plan) + .await?; + + let merged_spilled_partitions = merge_spilled_partitions(partitioned_spill_files_vec)?; + log::debug!( + "Build side spatial partitions:\n{}", + merged_spilled_partitions.debug_str() + ); + + // Sanity check: Multi and None partitions must be empty. All the geometries in the build side + // should fall into regular partitions + for partition in [SpatialPartition::None, SpatialPartition::Multi] { + let spilled_partition = merged_spilled_partitions.spilled_partition(partition)?; + if !spilled_partition.spill_files().is_empty() { + return sedona_internal_err!( + "Build side spatial partitions {:?} should be empty", + partition + ); + } + } + + let partitioned_index_provider = PartitionedIndexProvider::new_multi_partition( + self.build_schema, + self.spatial_predicate, + self.sedona_options.spatial_join, + self.join_type, + self.probe_threads_count, + merged_spilled_partitions, + SpatialJoinBuildMetrics::new(0, &self.metrics), + reservations, + ); + + Ok(SpatialJoinComponents { + partitioned_index_provider: Arc::new(partitioned_index_provider), + }) + } + } + + /// Collect build-side batches from the provided streams and return a + /// vector of `BuildPartition` entries representing the collected data. + /// The collector may spill to disk according to the configured options. 
+ async fn collect_build_partitions( + &mut self, + build_streams: Vec, + seed: u64, + ) -> Result> { + let runtime_env = self.context.runtime_env(); + let session_config = self.context.session_config(); + let spill_compression = session_config.spill_compression(); + + let num_partitions = build_streams.len(); + let mut collect_metrics_vec = Vec::with_capacity(num_partitions); + let mut reservations = Vec::with_capacity(num_partitions); + let memory_pool = self.context.memory_pool(); + for k in 0..num_partitions { + let consumer = MemoryConsumer::new(format!("SpatialJoinCollectBuildSide[{k}]")) + .with_can_spill(true); + let reservation = consumer.register(memory_pool); + reservations.push(reservation); + collect_metrics_vec.push(CollectBuildSideMetrics::new(k, &self.metrics)); + } + + let collector = BuildSideBatchesCollector::new( + self.spatial_predicate.clone(), + self.sedona_options.spatial_join.clone(), + Arc::clone(&runtime_env), + spill_compression, + ); + let build_partitions = collector + .collect_all( + build_streams, + reservations, + collect_metrics_vec.clone(), + self.sedona_options + .spatial_join + .concurrent_build_side_collection, + seed, + ) + .await?; + + Ok(build_partitions) + } + + /// Construct a `SpatialPartitioner` (e.g. KDB) from collected samples so + /// the build and probe sides can be partitioned spatially across + /// `num_partitions`. + fn build_spatial_partitioner( + &self, + num_partitions: usize, + build_partitions: &mut Vec, + seed: u64, + ) -> Result> { + if matches!( + self.spatial_predicate, + SpatialPredicate::KNearestNeighbors(..) + ) { + return sedona_internal_err!("Partitioned KNN join is not supported yet"); + } + + let build_partitioner: Arc = { + // Use spatial partitioners to partition the build side and the probe side, this will + // reduce the amount of work needed for probing each partitioned index. + // The KDB partitioner is built using the collected bounding box samples. 
+ let mut bbox_samples = BoundingBoxSamples::empty(); + let mut geo_stats = GeoStatistics::empty(); + let mut rng = Rng::with_seed(seed); + for partition in build_partitions { + let samples = mem::take(&mut partition.bbox_samples); + bbox_samples = bbox_samples.combine(samples, &mut rng); + geo_stats.merge(&partition.geo_statistics); + } + + let extent = geo_stats.bbox().cloned().unwrap_or(BoundingBox::empty()); + let mut samples = bbox_samples.take_samples(); + let max_items_per_node = 1.max(samples.len() / num_partitions); + let max_levels = num_partitions; + + log::debug!( + "Number of samples: {}, max_items_per_node: {}, max_levels: {}", + samples.len(), + max_items_per_node, + max_levels + ); + rng.shuffle(&mut samples); + let kdb_partitioner = + KDBPartitioner::build(samples.into_iter(), max_items_per_node, max_levels, extent)?; + log::debug!( + "Built KDB spatial partitioner with {} partitions", + num_partitions + ); + log::debug!( + "KDB partitioner debug info:\n{}", + kdb_partitioner.debug_str() + ); + + Arc::new(kdb_partitioner) + }; + + Ok(build_partitioner) + } + + /// Repartition the collected build-side partitions using the provided + /// `SpatialPartitioner`. Returns the spilled partitions for each spatial partition. 
+ async fn repartition_build_side( + &self, + build_partitions: Vec, + build_partitioner: Arc, + memory_plan: &MemoryPlan, + ) -> Result> { + // Spawn each task for each build partition to repartition the data using the spatial partitioner for + // the build/indexed side + let runtime_env = self.context.runtime_env(); + let session_config = self.context.session_config(); + let target_batch_size = session_config.batch_size(); + let spill_compression = session_config.spill_compression(); + let spilled_batch_in_memory_size_threshold = if self + .sedona_options + .spatial_join + .spilled_batch_in_memory_size_threshold + == 0 + { + None + } else { + Some( + self.sedona_options + .spatial_join + .spilled_batch_in_memory_size_threshold, + ) + }; + + let memory_for_intermittent_usage = match self + .sedona_options + .spatial_join + .debug + .memory_for_intermittent_usage + { + Some(value) => { + log::debug!("Override memory for intermittent usage to {}", value); + value + } + None => memory_plan.memory_for_intermittent_usage, + }; + + let mut join_set = JoinSet::new(); + let buffer_bytes_threshold = memory_for_intermittent_usage / build_partitions.len(); + for partition in build_partitions { + let stream = partition.build_side_batch_stream; + let metrics = &partition.metrics; + let spill_metrics = metrics.spill_metrics(); + let runtime_env = Arc::clone(&runtime_env); + let partitioner = Arc::clone(&build_partitioner); + join_set.spawn(async move { + let partitioned_spill_files = StreamRepartitioner::builder( + runtime_env, + partitioner, + PartitionedSide::BuildSide, + spill_metrics, + ) + .spill_compression(spill_compression) + .buffer_bytes_threshold(buffer_bytes_threshold) + .target_batch_size(target_batch_size) + .spilled_batch_in_memory_size_threshold(spilled_batch_in_memory_size_threshold) + .build() + .repartition_stream(stream) + .await; + partitioned_spill_files + }); + } + + let results = join_set.join_all().await; + let partitioned_spill_files_vec = 
results.into_iter().collect::>>()?; + Ok(partitioned_spill_files_vec) + } +} + +/// Aggregate the spill files and bounds of each spatial partition collected from all build partitions +fn merge_spilled_partitions( + spilled_partitions_vec: Vec, +) -> Result { + let Some(first) = spilled_partitions_vec.first() else { + return sedona_internal_err!("spilled_partitions_vec cannot be empty"); + }; + + let slots = first.slots(); + let total_slots = slots.total_slots(); + let mut merged_spill_files: Vec>> = + (0..total_slots).map(|_| Vec::new()).collect(); + let mut partition_geo_stats: Vec = + (0..total_slots).map(|_| GeoStatistics::empty()).collect(); + let mut partition_num_rows: Vec = (0..total_slots).map(|_| 0).collect(); + + for spilled_partitions in spilled_partitions_vec { + let partitions = spilled_partitions.into_spilled_partitions()?; + for (slot_idx, partition) in partitions.into_iter().enumerate() { + let (spill_files, geo_stats, num_rows) = partition.into_inner(); + partition_geo_stats[slot_idx].merge(&geo_stats); + merged_spill_files[slot_idx].extend(spill_files); + partition_num_rows[slot_idx] += num_rows; + } + } + + let merged_partitions = merged_spill_files + .into_iter() + .zip(partition_geo_stats) + .zip(partition_num_rows) + .map(|((spill_files, geo_stats), num_rows)| { + SpilledPartition::new(spill_files, geo_stats, num_rows) + }) + .collect(); + + Ok(SpilledPartitions::new(slots, merged_partitions)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::partitioning::partition_slots::PartitionSlots; + use datafusion_execution::runtime_env::RuntimeEnv; + use sedona_geometry::interval::IntervalTrait; + + fn sample_geo_stats(bbox: (f64, f64, f64, f64), total_geometries: i64) -> GeoStatistics { + GeoStatistics::empty() + .with_bbox(Some(BoundingBox::xy((bbox.0, bbox.1), (bbox.2, bbox.3)))) + .with_total_geometries(total_geometries) + } + + fn sample_partition( + env: &Arc, + labels: &[&str], + bbox: (f64, f64, f64, f64), + total_geometries: i64, + 
) -> Result { + let mut files = Vec::with_capacity(labels.len()); + for label in labels { + files.push(Arc::new(env.disk_manager.create_tmp_file(label)?)); + } + Ok(SpilledPartition::new( + files, + sample_geo_stats(bbox, total_geometries), + total_geometries as usize, + )) + } + + #[test] + fn merge_spilled_partitions_combines_files_and_stats() -> Result<()> { + let runtime_env = Arc::new(RuntimeEnv::default()); + let slots = PartitionSlots::new(2); + + let partitions_a = vec![ + sample_partition(&runtime_env, &["r0_a"], (0.0, 1.0, 0.0, 1.0), 10)?, + sample_partition(&runtime_env, &["r1_a"], (10.0, 11.0, -1.0, 1.0), 5)?, + sample_partition(&runtime_env, &["none_a"], (-5.0, -4.0, -5.0, -4.0), 2)?, + SpilledPartition::empty(), + ]; + let first = SpilledPartitions::new(slots, partitions_a); + + let partitions_b = vec![ + sample_partition(&runtime_env, &["r0_b1", "r0_b2"], (5.0, 6.0, 5.0, 6.0), 20)?, + sample_partition(&runtime_env, &[], (12.0, 13.0, 2.0, 3.0), 8)?, + SpilledPartition::empty(), + sample_partition(&runtime_env, &["multi_b"], (50.0, 51.0, 50.0, 51.0), 1)?, + ]; + let second = SpilledPartitions::new(slots, partitions_b); + + let merged = merge_spilled_partitions(vec![first, second])?; + + assert_eq!(merged.spill_file_count(), 6); + + let regular0 = merged.spilled_partition(SpatialPartition::Regular(0))?; + assert_eq!(regular0.spill_files().len(), 3); + assert_eq!(regular0.geo_statistics().total_geometries(), Some(30)); + let bbox0 = regular0.geo_statistics().bbox().unwrap(); + assert_eq!(bbox0.x().lo(), 0.0); + assert_eq!(bbox0.x().hi(), 6.0); + assert_eq!(bbox0.y().lo(), 0.0); + assert_eq!(bbox0.y().hi(), 6.0); + + let regular1 = merged.spilled_partition(SpatialPartition::Regular(1))?; + assert_eq!(regular1.spill_files().len(), 1); + assert_eq!(regular1.geo_statistics().total_geometries(), Some(13)); + let bbox1 = regular1.geo_statistics().bbox().unwrap(); + assert_eq!(bbox1.x().lo(), 10.0); + assert_eq!(bbox1.x().hi(), 13.0); + 
assert_eq!(bbox1.y().lo(), -1.0); + assert_eq!(bbox1.y().hi(), 3.0); + + let none_partition = merged.spilled_partition(SpatialPartition::None)?; + assert_eq!(none_partition.spill_files().len(), 1); + assert_eq!(none_partition.geo_statistics().total_geometries(), Some(2)); + + let multi_partition = merged.spilled_partition(SpatialPartition::Multi)?; + assert_eq!(multi_partition.spill_files().len(), 1); + assert_eq!(multi_partition.geo_statistics().total_geometries(), Some(1)); + + Ok(()) + } +} diff --git a/rust/sedona-spatial-join/src/stream.rs b/rust/sedona-spatial-join/src/stream.rs index 8451ff2d4..edbb41dd1 100644 --- a/rust/sedona-spatial-join/src/stream.rs +++ b/rust/sedona-spatial-join/src/stream.rs @@ -38,8 +38,10 @@ use std::sync::Arc; use crate::evaluated_batch::evaluated_batch_stream::evaluate::create_evaluated_probe_stream; use crate::evaluated_batch::evaluated_batch_stream::SendableEvaluatedBatchStream; use crate::evaluated_batch::EvaluatedBatch; +use crate::index::partitioned_index_provider::PartitionedIndexProvider; use crate::index::SpatialIndex; use crate::operand_evaluator::create_operand_evaluator; +use crate::prepare::SpatialJoinComponents; use crate::spatial_predicate::SpatialPredicate; use crate::utils::join_utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, @@ -52,6 +54,8 @@ use sedona_common::option::SpatialJoinOptions; /// Stream for producing spatial join result batches. pub(crate) struct SpatialJoinStream { + /// The partition id of the probe side stream + probe_partition_id: usize, /// Schema of joined results schema: Arc, /// join filter @@ -73,13 +77,18 @@ pub(crate) struct SpatialJoinStream { options: SpatialJoinOptions, /// Target output batch size target_output_batch_size: usize, - /// Once future for the spatial index - once_fut_spatial_index: OnceFut, - /// Once async for the spatial index, will be manually disposed by the last finished stream - /// to avoid unnecessary memory usage. 
- once_async_spatial_index: Arc>>>, + /// Once future for the shared partitioned index provider + once_fut_spatial_join_components: OnceFut, + /// Once async for the provider, disposed by the last finished stream + once_async_spatial_join_components: Arc>>>, + /// Cached index provider reference after it becomes available + index_provider: Option>, /// The spatial index spatial_index: Option>, + /// Pending future for building or waiting on a partitioned index + pending_index_future: Option>>>>, + /// Total number of regular partitions produced by the provider + num_regular_partitions: Option, /// The spatial predicate being evaluated spatial_predicate: SpatialPredicate, } @@ -87,6 +96,7 @@ pub(crate) struct SpatialJoinStream { impl SpatialJoinStream { #[allow(clippy::too_many_arguments)] pub(crate) fn new( + probe_partition_id: usize, schema: Arc, on: &SpatialPredicate, filter: Option, @@ -97,8 +107,8 @@ impl SpatialJoinStream { join_metrics: SpatialJoinProbeMetrics, options: SpatialJoinOptions, target_output_batch_size: usize, - once_fut_spatial_index: OnceFut, - once_async_spatial_index: Arc>>>, + once_fut_spatial_join_components: OnceFut, + once_async_spatial_join_components: Arc>>>, ) -> Self { let evaluator = create_operand_evaluator(on, options.clone()); let probe_stream = create_evaluated_probe_stream( @@ -107,6 +117,7 @@ impl SpatialJoinStream { join_metrics.join_time.clone(), ); Self { + probe_partition_id, schema, filter, join_type, @@ -114,12 +125,15 @@ impl SpatialJoinStream { column_indices, probe_side_ordered, join_metrics, - state: SpatialJoinStreamState::WaitBuildIndex, + state: SpatialJoinStreamState::WaitPrepareSpatialJoinComponents, options, target_output_batch_size, - once_fut_spatial_index, - once_async_spatial_index, + once_fut_spatial_join_components, + once_async_spatial_join_components, + index_provider: None, spatial_index: None, + pending_index_future: None, + num_regular_partitions: None, spatial_predicate: on.clone(), } } @@ -169,6 
+183,8 @@ impl SpatialJoinProbeMetrics { /// This enumeration represents various states of the nested loop join algorithm. #[allow(clippy::large_enum_variant)] pub(crate) enum SpatialJoinStreamState { + /// The initial mode: waiting for the spatial join components to become available + WaitPrepareSpatialJoinComponents, /// The initial mode: waiting for the spatial index to be built WaitBuildIndex, /// Indicates that build-side has been collected, and stream is ready for @@ -193,6 +209,9 @@ impl SpatialJoinStream { ) -> Poll>> { loop { return match &mut self.state { + SpatialJoinStreamState::WaitPrepareSpatialJoinComponents => { + handle_state!(ready!(self.wait_create_spatial_join_components(cx))) + } SpatialJoinStreamState::WaitBuildIndex => { handle_state!(ready!(self.wait_build_index(cx))) } @@ -213,16 +232,97 @@ impl SpatialJoinStream { } } - fn wait_build_index( + fn wait_create_spatial_join_components( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>>> { - let index = ready!(self.once_fut_spatial_index.get_shared(cx))?; - self.spatial_index = Some(index); - self.state = SpatialJoinStreamState::FetchProbeBatch; + if self.index_provider.is_none() { + let spatial_join_components = + ready!(self.once_fut_spatial_join_components.get_shared(cx))?; + let provider = Arc::clone(&spatial_join_components.partitioned_index_provider); + self.num_regular_partitions = Some(provider.num_regular_partitions() as u32); + self.index_provider = Some(provider); + } + + let num_partitions = self + .num_regular_partitions + .expect("num_regular_partitions should be available"); + if num_partitions == 0 { + // Usually does not happen. The indexed side should have at least 1 partition. 
+ self.state = SpatialJoinStreamState::Completed; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + + if num_partitions > 1 { + return Poll::Ready(sedona_internal_err!( + "Multi-partitioned spatial join is not supported yet" + )); + } + + self.state = SpatialJoinStreamState::WaitBuildIndex; Poll::Ready(Ok(StatefulStreamResult::Continue)) } + fn wait_build_index( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + let num_partitions = self + .num_regular_partitions + .expect("num_regular_partitions should be available"); + let partition_id = 0; + if partition_id >= num_partitions { + self.state = SpatialJoinStreamState::Completed; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + + if self.pending_index_future.is_none() { + let provider = Arc::clone( + self.index_provider + .as_ref() + .expect("Partitioned index provider should be available"), + ); + let future = { + log::debug!( + "[Partition {}] Building index for spatial partition {}", + self.probe_partition_id, + partition_id + ); + async move { provider.build_or_wait_for_index(partition_id).await }.boxed() + }; + self.pending_index_future = Some(future); + } + + let future = self + .pending_index_future + .as_mut() + .expect("pending future must exist"); + + match future.poll_unpin(cx) { + Poll::Ready(Some(Ok(index))) => { + self.pending_index_future = None; + self.spatial_index = Some(index); + log::debug!( + "[Partition {}] Start probing spatial partition {}", + self.probe_partition_id, + partition_id + ); + self.state = SpatialJoinStreamState::FetchProbeBatch; + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + Poll::Ready(Some(Err(err))) => { + self.pending_index_future = None; + Poll::Ready(Err(err)) + } + Poll::Ready(None) => { + self.pending_index_future = None; + self.state = SpatialJoinStreamState::Completed; + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + Poll::Pending => Poll::Pending, + } + } + fn fetch_probe_batch( &mut self, cx: &mut 
std::task::Context<'_>, @@ -318,8 +418,13 @@ impl SpatialJoinStream { // Drop the once async to avoid holding a long-living reference to the spatial index. // The spatial index will be dropped when this stream is dropped. - let mut once_async = self.once_async_spatial_index.lock(); + let mut once_async = self.once_async_spatial_join_components.lock(); once_async.take(); + + if let Some(provider) = self.index_provider.as_ref() { + provider.dispose_index(0); + assert!(provider.num_loaded_indexes() == 0); + } } // Initial setup for processing unmatched build batches diff --git a/rust/sedona-spatial-join/src/utils.rs b/rust/sedona-spatial-join/src/utils.rs index 42a257f0a..4d73a0024 100644 --- a/rust/sedona-spatial-join/src/utils.rs +++ b/rust/sedona-spatial-join/src/utils.rs @@ -17,6 +17,7 @@ pub(crate) mod arrow_utils; pub(crate) mod bbox_sampler; +pub(crate) mod disposable_async_cell; pub(crate) mod init_once_array; pub(crate) mod join_utils; pub(crate) mod once_fut; diff --git a/rust/sedona-spatial-join/src/utils/bbox_sampler.rs b/rust/sedona-spatial-join/src/utils/bbox_sampler.rs index 498f3863a..99280162d 100644 --- a/rust/sedona-spatial-join/src/utils/bbox_sampler.rs +++ b/rust/sedona-spatial-join/src/utils/bbox_sampler.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -#![allow(unused)] use datafusion_common::{DataFusionError, Result}; use fastrand::Rng; use sedona_geometry::bounding_box::BoundingBox; diff --git a/rust/sedona-spatial-join/src/utils/disposable_async_cell.rs b/rust/sedona-spatial-join/src/utils/disposable_async_cell.rs new file mode 100644 index 000000000..e738e0347 --- /dev/null +++ b/rust/sedona-spatial-join/src/utils/disposable_async_cell.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::fmt;

use parking_lot::Mutex;
use tokio::sync::Notify;

/// Error returned when writing to a [`DisposableAsyncCell`] fails.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum CellSetError {
    /// The cell has already been disposed, so new values are rejected.
    Disposed,

    /// The cell already has a value.
    AlreadySet,
}

/// An asynchronous cell that can be set at most once before either being
/// disposed or read by any number of waiters.
///
/// This is used as a lightweight one-shot coordination primitive in the spatial
/// join implementation. For example, `PartitionedIndexProvider` keeps one
/// `DisposableAsyncCell` per regular partition to publish either a successfully
/// built `SpatialIndex` (or the build error) exactly once. Concurrent
/// `SpatialJoinStream`s racing to probe the same partition can then await the
/// same shared result instead of building duplicate indexes.
///
/// When an index is no longer needed (e.g. the last stream finishes a
/// partition), the cell can be disposed to free resources.
///
/// Awaiters calling [`DisposableAsyncCell::get`] will park until a value is set
/// or the cell is disposed. Once disposed, `get` returns `None` and `set`
/// returns [`CellSetError::Disposed`].
pub(crate) struct DisposableAsyncCell<T> {
    /// Current state of the cell. Guarded by a synchronous (parking_lot) mutex;
    /// the lock is only ever held for short, non-awaiting critical sections.
    state: Mutex<CellState<T>>,
    /// Wakes tasks parked in [`DisposableAsyncCell::get`] whenever the state
    /// transitions to `Value` or `Disposed`.
    notify: Notify,
}

impl<T> fmt::Debug for DisposableAsyncCell<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "DisposableAsyncCell")
    }
}

impl<T> Default for DisposableAsyncCell<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> DisposableAsyncCell<T> {
    /// Creates a new empty cell with no stored value.
    pub(crate) fn new() -> Self {
        Self {
            state: Mutex::new(CellState::Empty),
            notify: Notify::new(),
        }
    }

    /// Marks the cell as disposed and wakes every waiter.
    ///
    /// Disposing is terminal: any value that would have been published later is
    /// rejected with [`CellSetError::Disposed`], and pending/future `get` calls
    /// resolve to `None`.
    pub(crate) fn dispose(&self) {
        {
            let mut state = self.state.lock();
            *state = CellState::Disposed;
        }
        // Wake waiters only after the lock is released so they can immediately
        // observe the `Disposed` state.
        self.notify.notify_waiters();
    }

    /// Check whether the cell has a value or not.
    pub(crate) fn is_set(&self) -> bool {
        let state = self.state.lock();
        matches!(*state, CellState::Value(_))
    }

    /// Check whether the cell is empty (not set or disposed)
    pub(crate) fn is_empty(&self) -> bool {
        let state = self.state.lock();
        matches!(*state, CellState::Empty)
    }
}

impl<T: Clone> DisposableAsyncCell<T> {
    /// Waits until a value is set or the cell is disposed.
    /// Returns `None` if the cell is disposed without a value.
    ///
    /// Any number of concurrent callers may wait; each receives a clone of the
    /// published value.
    pub(crate) async fn get(&self) -> Option<T> {
        // Pin a single `Notified` future and re-arm it on every iteration.
        let notified = self.notify.notified();
        tokio::pin!(notified);
        loop {
            // Register this future as a waiter *before* inspecting the state.
            // `Notify::notify_waiters` only wakes futures that are already
            // registered (first-polled or enabled); without this call, a
            // `set`/`dispose` happening between releasing the lock below and
            // the first poll of `notified` would be a lost wakeup and this
            // task would hang forever. See the `Notify::notify_waiters` docs.
            notified.as_mut().enable();
            {
                let state = self.state.lock();
                match &*state {
                    CellState::Value(val) => return Some(val.clone()),
                    CellState::Disposed => return None,
                    CellState::Empty => {}
                }
            }
            notified.as_mut().await;
            // A `Notified` future is consumed once it completes; arm a fresh
            // one so the next loop iteration (which re-checks the state) can
            // wait again if the wakeup turns out to be spurious.
            notified.set(self.notify.notified());
        }
    }

    /// Stores the provided value if the cell is still empty.
    /// Fails if a value already exists or the cell has been disposed.
    pub(crate) fn set(&self, value: T) -> std::result::Result<(), CellSetError> {
        {
            let mut state = self.state.lock();
            match &mut *state {
                CellState::Empty => *state = CellState::Value(value),
                CellState::Disposed => return Err(CellSetError::Disposed),
                CellState::Value(_) => return Err(CellSetError::AlreadySet),
            }
        }

        // Wake waiters outside the critical section; they re-acquire the lock
        // and observe the freshly stored value.
        self.notify.notify_waiters();
        Ok(())
    }
}

/// Internal tri-state of the cell. `Disposed` is terminal.
enum CellState<T> {
    Empty,
    Value(T),
    Disposed,
}

#[cfg(test)]
mod tests {
    use super::{CellSetError, DisposableAsyncCell};
    use std::sync::Arc;
    use tokio::task;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn get_returns_value_once_set() {
        let cell = DisposableAsyncCell::new();
        cell.set(42).expect("set succeeds");
        assert_eq!(Some(42), cell.get().await);
    }

    #[tokio::test]
    async fn multiple_waiters_receive_same_value() {
        let cell = Arc::new(DisposableAsyncCell::new());
        let cloned = Arc::clone(&cell);
        let waiter_one = task::spawn(async move { cloned.get().await });
        let cloned = Arc::clone(&cell);
        let waiter_two = task::spawn(async move { cloned.get().await });

        cell.set(String::from("value")).expect("set succeeds");
        assert_eq!(Some("value".to_string()), waiter_one.await.unwrap());
        assert_eq!(Some("value".to_string()), waiter_two.await.unwrap());
    }

    #[tokio::test]
    async fn dispose_unblocks_waiters() {
        let cell = Arc::new(DisposableAsyncCell::<i32>::new());
        let waiter = tokio::spawn({
            let cloned = Arc::clone(&cell);
            async move { cloned.get().await }
        });

        cell.dispose();
        assert_eq!(None, waiter.await.unwrap());
    }

    #[tokio::test]
    async fn set_after_dispose_fails() {
        let cell = DisposableAsyncCell::new();
        cell.dispose();
        assert_eq!(Err(CellSetError::Disposed), cell.set(5));
    }

    #[tokio::test]
    async fn set_twice_rejects_second_value() {
        let cell = DisposableAsyncCell::new();
        cell.set("first").expect("initial set succeeds");
        assert_eq!(Err(CellSetError::AlreadySet), cell.set("second"));
        assert_eq!(Some("first"), cell.get().await);
    }

    #[tokio::test]
    async fn get_waits_until_value_is_set() {
        let cell = Arc::new(DisposableAsyncCell::new());
        let cloned = Arc::clone(&cell);
        let waiter = tokio::spawn(async move { cloned.get().await });

        sleep(Duration::from_millis(20)).await;
        assert!(!waiter.is_finished());

        cell.set(99).expect("set succeeds");
        assert_eq!(Some(99), waiter.await.unwrap());
    }
}