Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/sedona-spatial-join/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ futures = { workspace = true }
pin-project-lite = { workspace = true }
once_cell = { workspace = true }
parking_lot = { workspace = true }
tokio = { workspace = true }
geo = { workspace = true }
sedona-geo-generic-alg = { workspace = true }
geo-traits = { workspace = true, features = ["geo-types"] }
Expand Down
1 change: 0 additions & 1 deletion rust/sedona-spatial-join/src/build_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ pub async fn build_index(
sedona_options.spatial_join,
join_type,
probe_threads_count,
Arc::clone(memory_pool),
SpatialJoinBuildMetrics::new(0, &metrics),
)?;
index_builder.add_partitions(build_partitions).await?;
Expand Down
50 changes: 27 additions & 23 deletions rust/sedona-spatial-join/src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ use parking_lot::Mutex;
use sedona_common::SpatialJoinOptions;

use crate::{
build_index::build_index,
index::SpatialIndex,
prepare::{SpatialJoinComponents, SpatialJoinComponentsBuilder},
spatial_predicate::{KNNPredicate, SpatialPredicate},
stream::{SpatialJoinProbeMetrics, SpatialJoinStream},
utils::join_utils::{asymmetric_join_output_partitioning, boundedness_from_children},
utils::once_fut::OnceAsync,
utils::{
join_utils::{asymmetric_join_output_partitioning, boundedness_from_children},
once_fut::OnceAsync,
},
SedonaOptions,
};

Expand Down Expand Up @@ -132,9 +133,10 @@ pub struct SpatialJoinExec {
column_indices: Vec<ColumnIndex>,
/// Cache holding plan properties like equivalences, output partitioning etc.
cache: PlanProperties,
/// Spatial index built asynchronously on first execute() call and shared across all partitions.
/// Uses OnceAsync for lazy initialization coordinated via async runtime.
once_async_spatial_index: Arc<Mutex<Option<OnceAsync<SpatialIndex>>>>,
/// Once future for creating the partitioned index provider shared by all probe partitions.
/// This future runs only once before probing starts, and can be disposed by the last finished
/// stream so the provider does not outlive the execution plan unnecessarily.
once_async_spatial_join_components: Arc<Mutex<Option<OnceAsync<SpatialJoinComponents>>>>,
/// Indicates if this SpatialJoin was converted from a HashJoin
/// When true, we preserve HashJoin's equivalence properties and partitioning
converted_from_hash_join: bool,
Expand Down Expand Up @@ -203,7 +205,7 @@ impl SpatialJoinExec {
projection,
metrics: Default::default(),
cache,
once_async_spatial_index: Arc::new(Mutex::new(None)),
once_async_spatial_join_components: Arc::new(Mutex::new(None)),
converted_from_hash_join,
seed,
})
Expand Down Expand Up @@ -431,7 +433,7 @@ impl ExecutionPlan for SpatialJoinExec {
projection: self.projection.clone(),
metrics: Default::default(),
cache: self.cache.clone(),
once_async_spatial_index: Arc::new(Mutex::new(None)),
once_async_spatial_join_components: Arc::new(Mutex::new(None)),
converted_from_hash_join: self.converted_from_hash_join,
seed: self.seed,
}))
Expand Down Expand Up @@ -463,8 +465,8 @@ impl ExecutionPlan for SpatialJoinExec {
let (build_plan, probe_plan) = (&self.left, &self.right);

// Build the spatial index using shared OnceAsync
let once_fut_spatial_index = {
let mut once_async = self.once_async_spatial_index.lock();
let once_fut_spatial_join_components = {
let mut once_async = self.once_async_spatial_join_components.lock();
once_async
.get_or_insert(OnceAsync::default())
.try_once(|| {
Expand All @@ -479,16 +481,16 @@ impl ExecutionPlan for SpatialJoinExec {

let probe_thread_count =
self.right.output_partitioning().partition_count();
Ok(build_index(
let spatial_join_components_builder = SpatialJoinComponentsBuilder::new(
Arc::clone(&context),
build_side.schema(),
build_streams,
self.on.clone(),
self.join_type,
probe_thread_count,
self.metrics.clone(),
self.seed,
))
);
Ok(spatial_join_components_builder.build(build_streams))
})?
};

Expand All @@ -508,6 +510,7 @@ impl ExecutionPlan for SpatialJoinExec {
self.maintains_input_order()[1] && self.right.output_ordering().is_some();

Ok(Box::pin(SpatialJoinStream::new(
partition,
self.schema(),
&self.on,
self.filter.clone(),
Expand All @@ -518,8 +521,8 @@ impl ExecutionPlan for SpatialJoinExec {
join_metrics,
sedona_options.spatial_join,
target_output_batch_size,
once_fut_spatial_index,
Arc::clone(&self.once_async_spatial_index),
once_fut_spatial_join_components,
Arc::clone(&self.once_async_spatial_join_components),
)))
}
}
Expand Down Expand Up @@ -556,8 +559,8 @@ impl SpatialJoinExec {
let actual_probe_plan_is_left = std::ptr::eq(probe_plan.as_ref(), self.left.as_ref());

// Build the spatial index
let once_fut_spatial_index = {
let mut once_async = self.once_async_spatial_index.lock();
let once_fut_spatial_join_components = {
let mut once_async = self.once_async_spatial_join_components.lock();
once_async
.get_or_insert(OnceAsync::default())
.try_once(|| {
Expand All @@ -571,16 +574,16 @@ impl SpatialJoinExec {
}

let probe_thread_count = probe_plan.output_partitioning().partition_count();
Ok(build_index(
let spatial_join_components_builder = SpatialJoinComponentsBuilder::new(
Arc::clone(&context),
build_side.schema(),
build_streams,
self.on.clone(),
self.join_type,
probe_thread_count,
self.metrics.clone(),
self.seed,
))
);
Ok(spatial_join_components_builder.build(build_streams))
})?
};

Expand All @@ -605,6 +608,7 @@ impl SpatialJoinExec {
};

Ok(Box::pin(SpatialJoinStream::new(
partition,
self.schema(),
&self.on,
self.filter.clone(),
Expand All @@ -615,8 +619,8 @@ impl SpatialJoinExec {
join_metrics,
sedona_options.spatial_join,
target_output_batch_size,
once_fut_spatial_index,
Arc::clone(&self.once_async_spatial_index),
once_fut_spatial_join_components,
Arc::clone(&self.once_async_spatial_join_components),
)))
}
}
Expand Down
2 changes: 2 additions & 0 deletions rust/sedona-spatial-join/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

pub(crate) mod build_side_collector;
mod knn_adapter;
pub(crate) mod memory_plan;
pub(crate) mod partitioned_index_provider;
pub(crate) mod spatial_index;
pub(crate) mod spatial_index_builder;

Expand Down
29 changes: 20 additions & 9 deletions rust/sedona-spatial-join/src/index/build_side_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ pub(crate) struct BuildPartition {
/// The size of this reservation will be used to determine the maximum size of
/// each spatial partition, as well as how many spatial partitions to create.
pub reservation: MemoryReservation,

/// Metrics collected during the build side collection phase
pub metrics: CollectBuildSideMetrics,
}

/// A collector for evaluating the spatial expression on build side batches and collect
Expand Down Expand Up @@ -112,6 +115,10 @@ impl CollectBuildSideMetrics {
spill_metrics: SpillMetrics::new(metrics, partition),
}
}

pub fn spill_metrics(&self) -> SpillMetrics {
self.spill_metrics.clone()
}
}

impl BuildSideBatchesCollector {
Expand Down Expand Up @@ -147,7 +154,7 @@ impl BuildSideBatchesCollector {
mut stream: SendableEvaluatedBatchStream,
mut reservation: MemoryReservation,
mut bbox_sampler: BoundingBoxSampler,
metrics: &CollectBuildSideMetrics,
metrics: CollectBuildSideMetrics,
) -> Result<BuildPartition> {
let mut spill_writer_opt = None;
let mut in_mem_batches: Vec<EvaluatedBatch> = Vec::new();
Expand Down Expand Up @@ -200,7 +207,7 @@ impl BuildSideBatchesCollector {
e,
);
spill_writer_opt =
self.spill_in_mem_batches(&mut in_mem_batches, metrics)?;
self.spill_in_mem_batches(&mut in_mem_batches, &metrics)?;
}
}
Some(spill_writer) => {
Expand Down Expand Up @@ -236,7 +243,7 @@ impl BuildSideBatchesCollector {
"Force spilling enabled. Spilling {} in-memory batches to disk.",
in_mem_batches.len()
);
spill_writer_opt = self.spill_in_mem_batches(&mut in_mem_batches, metrics)?;
spill_writer_opt = self.spill_in_mem_batches(&mut in_mem_batches, &metrics)?;
}

let build_side_batch_stream: SendableEvaluatedBatchStream = match spill_writer_opt {
Expand Down Expand Up @@ -266,6 +273,7 @@ impl BuildSideBatchesCollector {
bbox_samples: bbox_sampler.into_samples(),
estimated_spatial_index_memory_usage,
reservation,
metrics,
})
}

Expand Down Expand Up @@ -329,7 +337,7 @@ impl BuildSideBatchesCollector {
let evaluated_stream =
create_evaluated_build_stream(stream, evaluator, metrics.time_taken.clone());
let result = collector
.collect(evaluated_stream, reservation, bbox_sampler, &metrics)
.collect(evaluated_stream, reservation, bbox_sampler, metrics)
.await;
(partition_id, result)
});
Expand Down Expand Up @@ -378,7 +386,7 @@ impl BuildSideBatchesCollector {
let evaluated_stream =
create_evaluated_build_stream(stream, evaluator, metrics.time_taken.clone());
let result = self
.collect(evaluated_stream, reservation, bbox_sampler, &metrics)
.collect(evaluated_stream, reservation, bbox_sampler, metrics)
.await?;
results.push(result);
}
Expand Down Expand Up @@ -534,11 +542,12 @@ mod tests {
let metrics = CollectBuildSideMetrics::new(0, &metrics_set);

let partition = collector
.collect(stream, reservation, sampler, &metrics)
.collect(stream, reservation, sampler, metrics)
.await?;
let stream = partition.build_side_batch_stream;
let is_external = stream.is_external();
let batches: Vec<EvaluatedBatch> = stream.try_collect().await?;
let metrics = &partition.metrics;
assert!(!is_external, "Expected in-memory batches");
assert_eq!(collect_ids(&batches), vec![0, 1, 2]);
assert_eq!(partition.num_rows, 3);
Expand All @@ -564,14 +573,15 @@ mod tests {
let metrics = CollectBuildSideMetrics::new(0, &metrics_set);

let partition = collector
.collect(stream, reservation, sampler, &metrics)
.collect(stream, reservation, sampler, metrics)
.await?;
let stream = partition.build_side_batch_stream;
let is_external = stream.is_external();
let batches: Vec<EvaluatedBatch> = stream.try_collect().await?;
let metrics = &partition.metrics;
assert!(is_external, "Expected batches to spill to disk");
assert_eq!(collect_ids(&batches), vec![10, 11, 12]);
let spill_metrics = metrics.spill_metrics;
let spill_metrics = metrics.spill_metrics();
assert!(spill_metrics.spill_file_count.value() >= 1);
assert!(spill_metrics.spilled_rows.value() >= 1);
Ok(())
Expand All @@ -587,12 +597,13 @@ mod tests {
let metrics = CollectBuildSideMetrics::new(0, &metrics_set);

let partition = collector
.collect(stream, reservation, sampler, &metrics)
.collect(stream, reservation, sampler, metrics)
.await?;
assert_eq!(partition.num_rows, 0);
let stream = partition.build_side_batch_stream;
let is_external = stream.is_external();
let batches: Vec<EvaluatedBatch> = stream.try_collect().await?;
let metrics = &partition.metrics;
assert!(!is_external);
assert!(batches.is_empty());
assert_eq!(metrics.num_batches.value(), 0);
Expand Down
Loading