diff --git a/Cargo.lock b/Cargo.lock index c265c6593..17aaea60f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5390,6 +5390,7 @@ dependencies = [ "arrow", "arrow-array", "arrow-schema", + "async-trait", "criterion", "datafusion", "datafusion-common", @@ -5398,6 +5399,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "datafusion-physical-plan", + "env_logger 0.11.8", "fastrand", "float_next_after", "futures", diff --git a/python/sedonadb/tests/test_sjoin.py b/python/sedonadb/tests/test_sjoin.py index ae7bb27ba..438d6f99b 100644 --- a/python/sedonadb/tests/test_sjoin.py +++ b/python/sedonadb/tests/test_sjoin.py @@ -391,6 +391,7 @@ def test_spatial_join_with_pandas_metadata(con): SELECT p.idx FROM points AS p, polygons AS poly WHERE ST_Intersects(p.geometry, poly.geometry) + ORDER BY p.idx """ res = con.sql(query).to_pandas() diff --git a/rust/sedona-spatial-join/Cargo.toml b/rust/sedona-spatial-join/Cargo.toml index d34f7a6cb..7a193c32a 100644 --- a/rust/sedona-spatial-join/Cargo.toml +++ b/rust/sedona-spatial-join/Cargo.toml @@ -34,6 +34,7 @@ result_large_err = "allow" backtrace = ["datafusion-common/backtrace"] [dependencies] +async-trait = { workspace = true } arrow = { workspace = true } arrow-schema = { workspace = true } arrow-array = { workspace = true } @@ -77,6 +78,7 @@ sedona-testing = { workspace = true} wkt = { workspace = true } tokio = { workspace = true, features = ["macros"] } rand = { workspace = true } +env_logger = { workspace = true } [[bench]] name = "kdb" diff --git a/rust/sedona-spatial-join/src/exec.rs b/rust/sedona-spatial-join/src/exec.rs index 2ed90d735..dac7422cb 100644 --- a/rust/sedona-spatial-join/src/exec.rs +++ b/rust/sedona-spatial-join/src/exec.rs @@ -19,28 +19,28 @@ use std::{fmt::Formatter, sync::Arc}; use arrow_schema::SchemaRef; use datafusion_common::{project_schema, JoinSide, Result}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::{JoinType, Operator}; -use datafusion_physical_expr::{ - equivalence::{join_equivalence_properties, ProjectionMapping}, - expressions::{BinaryExpr, Column}, - PhysicalExpr, -}; +use datafusion_expr::JoinType; +use datafusion_physical_expr::equivalence::{join_equivalence_properties, ProjectionMapping}; use datafusion_physical_plan::{ - execution_plan::EmissionType, + common::can_project, joins::utils::{build_join_schema, check_join_is_valid, ColumnIndex, JoinFilter}, + joins::utils::{reorder_output_after_swap, swap_join_projection}, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, - PlanProperties, + projection::{try_embed_projection, EmbeddedProjection, ProjectionExec}, + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, }; use parking_lot::Mutex; use sedona_common::{sedona_internal_err, SpatialJoinOptions}; use crate::{ prepare::{SpatialJoinComponents, SpatialJoinComponentsBuilder}, - spatial_predicate::{KNNPredicate, SpatialPredicate}, + spatial_predicate::{KNNPredicate, SpatialPredicate, SpatialPredicateTrait}, stream::{SpatialJoinProbeMetrics, SpatialJoinStream}, utils::{ - join_utils::{asymmetric_join_output_partitioning, boundedness_from_children}, + join_utils::{ + asymmetric_join_output_partitioning, boundedness_from_children, + compute_join_emission_type, try_pushdown_through_join, JoinPushdownData, + }, once_fut::OnceAsync, }, SedonaOptions, @@ -49,28 +49,6 @@ use crate::{ /// Type alias for build and probe execution plans type 
BuildProbePlans<'a> = (&'a Arc, &'a Arc); -/// Extract equality join conditions from a JoinFilter -/// Returns column pairs that represent equality conditions as PhysicalExprs -fn extract_equality_conditions( - filter: &JoinFilter, -) -> Vec<(Arc, Arc)> { - let mut equalities = Vec::new(); - - if let Some(binary_expr) = filter.expression().as_any().downcast_ref::() { - if binary_expr.op() == &Operator::Eq { - // Check if both sides are column references - if let (Some(_left_col), Some(_right_col)) = ( - binary_expr.left().as_any().downcast_ref::(), - binary_expr.right().as_any().downcast_ref::(), - ) { - equalities.push((binary_expr.left().clone(), binary_expr.right().clone())); - } - } - } - - equalities -} - /// Determine the correct build/probe execution plan assignment for KNN joins. /// /// For KNN joins, we need to determine which execution plan should be used as the build side @@ -87,7 +65,6 @@ fn determine_knn_build_probe_plans<'a>( knn_pred: &KNNPredicate, left_plan: &'a Arc, right_plan: &'a Arc, - _join_schema: &SchemaRef, ) -> Result> { // Use the probe_side information from the optimizer to determine build/probe assignment match knn_pred.probe_side { @@ -135,9 +112,6 @@ pub struct SpatialJoinExec { /// This future runs only once before probing starts, and can be disposed by the last finished /// stream so the provider does not outlive the execution plan unnecessarily. once_async_spatial_join_components: Arc>>>, - /// Indicates if this SpatialJoin was converted from a HashJoin - /// When true, we preserve HashJoin's equivalence properties and partitioning - converted_from_hash_join: bool, /// A random seed for making random procedures in spatial join deterministic seed: u64, } @@ -153,22 +127,23 @@ impl SpatialJoinExec { projection: Option>, options: &SpatialJoinOptions, ) -> Result { - Self::try_new_with_options( - left, right, on, filter, join_type, projection, options, false, - ) + let seed = options + .debug + .random_seed + .unwrap_or(fastrand::u64(0..0xFFFF)); + Self::try_new_internal(left, right, on, filter, join_type, projection, seed) } /// Create a new SpatialJoinExec with additional options #[allow(clippy::too_many_arguments)] - pub fn try_new_with_options( + pub fn try_new_internal( left: Arc, right: Arc, on: SpatialPredicate, filter: Option, join_type: &JoinType, projection: Option>, - options: &SpatialJoinOptions, - converted_from_hash_join: bool, + seed: u64, ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); @@ -183,15 +158,8 @@ impl SpatialJoinExec { Arc::clone(&join_schema), *join_type, projection.as_ref(), - filter.as_ref(), - converted_from_hash_join, )?; - let seed = options - .debug - .random_seed - .unwrap_or(fastrand::u64(0..0xFFFF)); - Ok(SpatialJoinExec { left, right, @@ -204,7 +172,6 @@ impl SpatialJoinExec { metrics: Default::default(), cache, once_async_spatial_join_components: Arc::new(Mutex::new(None)), - converted_from_hash_join, seed, }) } @@ -214,44 +181,81 @@ impl SpatialJoinExec { &self.join_type } - /// Returns a vector indicating whether the left and right inputs maintain their order. - /// The first element corresponds to the left input, and the second to the right. - /// - /// The left (build-side) input's order may change, but the right (probe-side) input's - /// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins. 
+    /// Does this join have a projection on the joined columns?
+    pub fn contains_projection(&self) -> bool {
+        self.projection.is_some()
+    }
+
+    /// Returns a new `ExecutionPlan` that runs this spatial join with the left
+    /// and right inputs swapped.
     ///
-    /// Maintaining the right input's order helps optimize the nodes down the pipeline
-    /// (See [`ExecutionPlan::maintains_input_order`]).
+    /// # Notes:
     ///
-    /// This is a separate method because it is also called when computing properties, before
-    /// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as
-    /// opposed to `Self`, for the same reason.
-    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
-        vec![
-            false,
-            matches!(
-                join_type,
-                JoinType::Inner | JoinType::Right | JoinType::RightAnti | JoinType::RightSemi
-            ),
-        ]
+    /// This function should be called BEFORE inserting any repartitioning
+    /// operators on the join's children. See `HashJoinExec::swap_inputs` in DataFusion
+    /// for more details.
+    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
+        let left_schema = self.left.schema();
+        let right_schema = self.right.schema();
+
+        let swapped_on = self.on.swap_for_swapped_children();
+
+        let swapped_projection = swap_join_projection(
+            left_schema.fields().len(),
+            right_schema.fields().len(),
+            self.projection.as_ref(),
+            &self.join_type,
+        );
+
+        let swapped_join = SpatialJoinExec::try_new_internal(
+            Arc::clone(&self.right),
+            Arc::clone(&self.left),
+            swapped_on,
+            self.filter.as_ref().map(|f| f.swap()),
+            &self.join_type.swap(),
+            swapped_projection,
+            self.seed,
+        )?;
+
+        let swapped_join: Arc<dyn ExecutionPlan> = Arc::new(swapped_join);
+
+        match self.join_type {
+            JoinType::LeftAnti
+            | JoinType::LeftSemi
+            | JoinType::RightAnti
+            | JoinType::RightSemi
+            | JoinType::LeftMark
+            | JoinType::RightMark => Ok(swapped_join),
+            _ if self.contains_projection() => Ok(swapped_join),
+            _ => {
+                reorder_output_after_swap(swapped_join, left_schema.as_ref(), right_schema.as_ref())
+            }
+        }
     }

-    /// Does this join has a projection on the joined columns
-    pub fn contains_projection(&self) -> bool {
-        self.projection.is_some()
+    pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
+        // check if the projection is valid
+        can_project(&self.schema(), projection.as_ref())?;
+        let projection = match projection {
+            Some(projection) => match &self.projection {
+                Some(p) => Some(projection.iter().map(|i| p[*i]).collect()),
+                None => Some(projection),
+            },
+            None => None,
+        };
+        SpatialJoinExec::try_new_internal(
+            Arc::clone(&self.left),
+            Arc::clone(&self.right),
+            self.on.clone(),
+            self.filter.clone(),
+            &self.join_type,
+            projection,
+            self.seed,
+        )
     }

     /// This function creates the cache object that stores the plan properties such as schema,
     /// equivalence properties, ordering, partitioning, etc.
-    ///
-    /// NOTICE: The implementation of this function should be identical to the one in
-    /// [`datafusion_physical_plan::physical_plan::join::NestedLoopJoinExec::compute_properties`].
-    /// This is because SpatialJoinExec is transformed from NestedLoopJoinExec in physical plan
-    /// optimization phase. If the properties are not the same, the plan will be incorrect.
-    ///
-    /// When converted from HashJoin, we preserve HashJoin's equivalence properties by extracting
-    /// equality conditions from the filter.
- #[allow(clippy::too_many_arguments)] fn compute_properties( left: &Arc, right: &Arc, @@ -259,16 +263,7 @@ impl SpatialJoinExec { schema: SchemaRef, join_type: JoinType, projection: Option<&Vec>, - filter: Option<&JoinFilter>, - converted_from_hash_join: bool, ) -> Result { - // Extract equality conditions from filter if this was converted from HashJoin - let on_columns = if converted_from_hash_join { - filter.map_or(vec![], extract_equality_conditions) - } else { - vec![] - }; - let mut eq_properties = join_equivalence_properties( left.equivalence_properties().clone(), right.equivalence_properties().clone(), @@ -277,73 +272,29 @@ impl SpatialJoinExec { &[false, false], None, // Pass extracted equality conditions to preserve equivalences - &on_columns, - ); + &[], + )?; - // Use symmetric partitioning (like HashJoin) when converted from HashJoin - // Otherwise use asymmetric partitioning (like NestedLoopJoin) - let mut output_partitioning = if let SpatialPredicate::KNearestNeighbors(knn) = on { - match knn.probe_side { - JoinSide::Left => left.output_partitioning().clone(), - JoinSide::Right => right.output_partitioning().clone(), - _ => asymmetric_join_output_partitioning(left, right, &join_type)?, - } - } else if converted_from_hash_join { - // Replicate HashJoin's symmetric partitioning logic - // HashJoin preserves partitioning from both sides for inner joins - // and from one side for outer joins - - match join_type { - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { - left.output_partitioning().clone() - } - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - right.output_partitioning().clone() - } - JoinType::Full => { - // For full outer join, we can't preserve partitioning - Partitioning::UnknownPartitioning(left.output_partitioning().partition_count()) - } - _ => asymmetric_join_output_partitioning(left, right, &join_type)?, - } + let probe_side = if let SpatialPredicate::KNearestNeighbors(knn) = on { + knn.probe_side } else { - asymmetric_join_output_partitioning(left, right, &join_type)? + JoinSide::Right }; + let mut output_partitioning = + asymmetric_join_output_partitioning(left, right, &join_type, probe_side)?; if let Some(projection) = projection { // construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::from_indices(projection, &schema)?; let out_schema = project_schema(&schema, Some(projection))?; - let eq_props = eq_properties?; - output_partitioning = output_partitioning.project(&projection_mapping, &eq_props); - eq_properties = Ok(eq_props.project(&projection_mapping, out_schema)); + output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); + eq_properties = eq_properties.project(&projection_mapping, out_schema); } - let emission_type = if left.boundedness().is_unbounded() { - EmissionType::Final - } else if right.pipeline_behavior() == EmissionType::Incremental { - match join_type { - // If we only need to generate matched rows from the probe side, - // we can emit rows incrementally. - JoinType::Inner - | JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::Right - | JoinType::RightAnti => EmissionType::Incremental, - // If we need to generate unmatched rows from the *build side*, - // we need to emit them at the end. 
- JoinType::Left - | JoinType::LeftAnti - | JoinType::LeftMark - | JoinType::RightMark - | JoinType::Full => EmissionType::Both, - } - } else { - right.pipeline_behavior() - }; + let emission_type = compute_join_emission_type(left, right, join_type, probe_side); Ok(PlanProperties::new( - eq_properties?, + eq_properties, output_partitioning, emission_type, boundedness_from_children([left, right]), @@ -409,32 +360,70 @@ impl ExecutionPlan for SpatialJoinExec { } fn maintains_input_order(&self) -> Vec { - Self::maintains_input_order(self.join_type) + vec![false, false] } fn children(&self) -> Vec<&Arc> { vec![&self.left, &self.right] } + /// Tries to push `projection` down through `SpatialJoinExec`. If possible, performs the + /// pushdown and returns a new [`SpatialJoinExec`] as the top plan which has projections + /// as its children. Otherwise, returns `None`. + fn try_swapping_with_projection( + &self, + projection: &ProjectionExec, + ) -> Result>> { + // TODO: currently if there is projection in SpatialJoinExec, we can't push down projection to + // left or right input. Maybe we can pushdown the mixed projection later. + // This restriction is inherited from NestedLoopJoinExec and HashJoinExec in DataFusion. + if self.contains_projection() { + return Ok(None); + } + + if let Some(JoinPushdownData { + projected_left_child, + projected_right_child, + join_filter, + join_on, + }) = try_pushdown_through_join( + projection, + &self.left, + &self.right, + &self.join_schema, + self.join_type, + self.filter.as_ref(), + &self.on, + )? { + let new_exec = SpatialJoinExec::try_new_internal( + Arc::new(projected_left_child), + Arc::new(projected_right_child), + join_on, + join_filter, + &self.join_type, + None, + self.seed, + )?; + Ok(Some(Arc::new(new_exec))) + } else { + try_embed_projection(projection, self) + } + } + fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(SpatialJoinExec { - left: children[0].clone(), - right: children[1].clone(), - on: self.on.clone(), - filter: self.filter.clone(), - join_type: self.join_type, - join_schema: self.join_schema.clone(), - column_indices: self.column_indices.clone(), - projection: self.projection.clone(), - metrics: Default::default(), - cache: self.cache.clone(), - once_async_spatial_join_components: Arc::new(Mutex::new(None)), - converted_from_hash_join: self.converted_from_hash_join, - seed: self.seed, - })) + let new_exec = SpatialJoinExec::try_new_internal( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.on.clone(), + self.filter.clone(), + &self.join_type, + self.projection.clone(), + self.seed, + )?; + Ok(Arc::new(new_exec)) } fn metrics(&self) -> Option { @@ -448,86 +437,97 @@ impl ExecutionPlan for SpatialJoinExec { ) -> Result { match &self.on { SpatialPredicate::KNearestNeighbors(_) => self.execute_knn(partition, context), - _ => { - // Regular spatial join logic - standard left=build, right=probe semantics - let session_config = context.session_config(); - let target_output_batch_size = session_config.options().execution.batch_size; - let sedona_options = session_config - .options() - .extensions - .get::() - .cloned() - .unwrap_or_default(); - - // Regular join semantics: left is build, right is probe - let (build_plan, probe_plan) = (&self.left, &self.right); - - // Build the spatial index using shared OnceAsync - let once_fut_spatial_join_components = { - let mut once_async = self.once_async_spatial_join_components.lock(); - once_async - .get_or_insert(OnceAsync::default()) - .try_once(|| { - let 
build_side = build_plan; - - let num_partitions = build_side.output_partitioning().partition_count(); - let mut build_streams = Vec::with_capacity(num_partitions); - for k in 0..num_partitions { - let stream = build_side.execute(k, Arc::clone(&context))?; - build_streams.push(stream); - } - - let probe_thread_count = - self.right.output_partitioning().partition_count(); - let spatial_join_components_builder = SpatialJoinComponentsBuilder::new( - Arc::clone(&context), - build_side.schema(), - self.on.clone(), - self.join_type, - probe_thread_count, - self.metrics.clone(), - self.seed, - ); - Ok(spatial_join_components_builder.build(build_streams)) - })? - }; - - let column_indices_after_projection = match &self.projection { - Some(projection) => projection - .iter() - .map(|i| self.column_indices[*i].clone()) - .collect(), - None => self.column_indices.clone(), - }; - - let join_metrics = SpatialJoinProbeMetrics::new(partition, &self.metrics); - let probe_stream = probe_plan.execute(partition, Arc::clone(&context))?; - - // For regular joins: probe is right side (index 1) - let probe_side_ordered = - self.maintains_input_order()[1] && self.right.output_ordering().is_some(); - - Ok(Box::pin(SpatialJoinStream::new( - partition, - self.schema(), - &self.on, - self.filter.clone(), - self.join_type, - probe_stream, - column_indices_after_projection, - probe_side_ordered, - join_metrics, - sedona_options.spatial_join, - target_output_batch_size, - once_fut_spatial_join_components, - Arc::clone(&self.once_async_spatial_join_components), - ))) - } + _ => self.execute(partition, context), } } } +impl EmbeddedProjection for SpatialJoinExec { + fn with_projection(&self, projection: Option>) -> Result { + self.with_projection(projection) + } +} + impl SpatialJoinExec { + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + // Regular spatial join logic - standard left=build, right=probe semantics + let session_config = context.session_config(); + let target_output_batch_size = session_config.options().execution.batch_size; + let sedona_options = session_config + .options() + .extensions + .get::() + .cloned() + .unwrap_or_default(); + + // Regular join semantics: left is build, right is probe + let (build_plan, probe_plan) = (&self.left, &self.right); + + // Build the spatial join components using shared OnceAsync + let once_fut_spatial_join_components = { + let mut once_async = self.once_async_spatial_join_components.lock(); + once_async + .get_or_insert(OnceAsync::default()) + .try_once(|| { + let build_side = build_plan; + + let num_partitions = build_side.output_partitioning().partition_count(); + let mut build_streams = Vec::with_capacity(num_partitions); + for k in 0..num_partitions { + let stream = build_side.execute(k, Arc::clone(&context))?; + build_streams.push(stream); + } + + let probe_thread_count = probe_plan.output_partitioning().partition_count(); + let spatial_join_components_builder = SpatialJoinComponentsBuilder::new( + Arc::clone(&context), + build_side.schema(), + self.on.clone(), + self.join_type, + probe_thread_count, + self.metrics.clone(), + self.seed, + ); + Ok(spatial_join_components_builder.build(build_streams)) + })? 
+ }; + + let column_indices_after_projection = match &self.projection { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + + let join_metrics = SpatialJoinProbeMetrics::new(partition, &self.metrics); + let probe_stream = probe_plan.execute(partition, Arc::clone(&context))?; + + // For regular joins: probe is right side (index 1) + let probe_side_ordered = + self.maintains_input_order()[1] && self.right.output_ordering().is_some(); + + Ok(Box::pin(SpatialJoinStream::new( + partition, + self.schema(), + &self.on, + self.filter.clone(), + self.join_type, + probe_stream, + column_indices_after_projection, + probe_side_ordered, + join_metrics, + sedona_options.spatial_join, + target_output_batch_size, + once_fut_spatial_join_components, + Arc::clone(&self.once_async_spatial_join_components), + ))) + } + /// Execute KNN (K-Nearest Neighbors) spatial join with specialized logic for asymmetric KNN semantics fn execute_knn( &self, @@ -551,7 +551,7 @@ impl SpatialJoinExec { // Determine which execution plan should be build vs probe using join schema analysis let (build_plan, probe_plan) = - determine_knn_build_probe_plans(knn_pred, &self.left, &self.right, &self.join_schema)?; + determine_knn_build_probe_plans(knn_pred, &self.left, &self.right)?; // Determine if probe plan is the left execution plan (for column index swapping logic) let actual_probe_plan_is_left = std::ptr::eq(probe_plan.as_ref(), self.left.as_ref()); @@ -644,7 +644,6 @@ mod tests { use sedona_testing::datagen::RandomPartitionedDataBuilder; use tokio::sync::OnceCell; - use crate::register_spatial_join_optimizer; use sedona_common::{ option::{add_sedona_option_extension, ExecutionMode, SpatialJoinOptions}, SpatialLibrary, @@ -744,13 +743,16 @@ mod tests { options: Option, batch_size: usize, ) -> Result { + let _ = env_logger::builder().is_test(true).try_init(); let mut session_config = SessionConfig::from_env()? .with_information_schema(true) .with_batch_size(batch_size); session_config = add_sedona_option_extension(session_config); let mut state_builder = SessionStateBuilder::new(); if let Some(options) = options { - state_builder = register_spatial_join_optimizer(state_builder); + // Logical rewrite (Filter(CrossJoin)->Join(filter)) + extension-based planning + // (Join(filter)->SpatialJoinExec). Intentionally avoid physical plan rewrites. 
+ state_builder = crate::register_planner(state_builder); let opts = session_config .options_mut() .extensions @@ -1538,3 +1540,279 @@ mod tests { Ok(()) } } + +#[cfg(test)] +mod exec_transform_tests { + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; + use datafusion_expr::JoinType; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_plan::empty::EmptyExec; + use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; + use datafusion_physical_plan::ExecutionPlan; + + use sedona_common::{sedona_internal_err, SpatialJoinOptions}; + + use super::*; + use crate::spatial_predicate::{RelationPredicate, SpatialRelationType}; + + fn make_schema(fields: &[(&str, DataType)]) -> SchemaRef { + Arc::new(Schema::new( + fields + .iter() + .map(|(name, dt)| Field::new(*name, dt.clone(), true)) + .collect::>(), + )) + } + + fn proj_expr( + schema: &SchemaRef, + index: usize, + ) -> (Arc, String) { + let name = schema.field(index).name().to_string(); + (Arc::new(Column::new(&name, index)), name) + } + + fn collect_spatial_join_exec(plan: &Arc) -> Result> { + let mut spatial_join_execs = Vec::new(); + plan.apply(|node| { + if let Some(spatial_join_exec) = node.as_any().downcast_ref::() { + spatial_join_execs.push(spatial_join_exec); + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(spatial_join_execs) + } + + #[test] + fn test_mark_join_projection_pushdown_is_graceful() -> Result<()> { + let left_schema = make_schema(&[("l", DataType::Int32)]); + let right_schema = make_schema(&[("r", DataType::Int32)]); + let left: Arc = Arc::new(EmptyExec::new(Arc::clone(&left_schema))); + let right: Arc = Arc::new(EmptyExec::new(Arc::clone(&right_schema))); + + let on = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("l", 0)), + Arc::new(Column::new("r", 0)), + SpatialRelationType::Intersects, + )); + + let join = SpatialJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftMark, + None, + &SpatialJoinOptions::default(), + )?; + + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(Column::new("mark", 1)), + alias: "mark".to_string(), + }], + Arc::new(join), + )?; + + let swapped = projection + .input() + .try_swapping_with_projection(&projection)?; + assert!(swapped.is_some()); + + Ok(()) + } + + #[test] + fn test_try_swapping_with_projection_pushes_down_and_rewrites_relation_predicate() -> Result<()> + { + // left: [l0, l1, l2], right: [r0, r1] + let left_schema = make_schema(&[ + ("l0", DataType::Int32), + ("l1", DataType::Int32), + ("l2", DataType::Int32), + ]); + let right_schema = make_schema(&[("r0", DataType::Int32), ("r1", DataType::Int32)]); + let left_len = left_schema.fields().len(); + + let left: Arc = Arc::new(EmptyExec::new(Arc::clone(&left_schema))); + let right: Arc = Arc::new(EmptyExec::new(Arc::clone(&right_schema))); + + // on: ST_Intersects(l2, r1) (types don't matter for rewrite-only test) + let on = SpatialPredicate::Relation(RelationPredicate { + left: Arc::new(Column::new("l2", 2)), + right: Arc::new(Column::new("r1", 1)), + relation_type: SpatialRelationType::Intersects, + }); + + let exec = Arc::new(SpatialJoinExec::try_new_internal( + left, + right, + on, + None, + &JoinType::Inner, + None, + 0, + )?); + + // Project only columns used by the predicate: l2 then r1. 
+ let join_schema = exec.schema(); + let exprs = vec![ + proj_expr(&join_schema, 2), + proj_expr(&join_schema, left_len + 1), + ]; + let proj = ProjectionExec::try_new(exprs, Arc::clone(&exec) as Arc)?; + + let Some(new_plan) = exec.try_swapping_with_projection(&proj)? else { + return sedona_internal_err!("expected try_swapping_with_projection to succeed"); + }; + + let new_exec = new_plan + .as_any() + .downcast_ref::() + .expect("expected SpatialJoinExec"); + + // Projection is pushed down into children; join has no embedded projection. + assert!(!new_exec.contains_projection()); + assert!(new_exec + .children() + .iter() + .all(|c| c.as_any().downcast_ref::().is_some())); + + // Predicate columns should be remapped to match the projected children (both become 0). + let SpatialPredicate::Relation(new_on) = &new_exec.on else { + return sedona_internal_err!("expected Relation predicate"); + }; + let new_left = new_on + .left + .as_any() + .downcast_ref::() + .expect("expected Column expr"); + let new_right = new_on + .right + .as_any() + .downcast_ref::() + .expect("expected Column expr"); + assert_eq!(new_left.index(), 0); + assert_eq!(new_right.index(), 0); + + Ok(()) + } + + #[test] + fn test_try_swapping_with_projection_pushes_down_and_rewrites_knn_predicate_by_probe_side( + ) -> Result<()> { + // left: [l0, lgeom], right: [r0, rgeom] + let left_schema = make_schema(&[("l0", DataType::Int32), ("lgeom", DataType::Binary)]); + let right_schema = make_schema(&[("r0", DataType::Int32), ("rgeom", DataType::Binary)]); + let left_len = left_schema.fields().len(); + + let left: Arc = Arc::new(EmptyExec::new(Arc::clone(&left_schema))); + let right: Arc = Arc::new(EmptyExec::new(Arc::clone(&right_schema))); + + // KNN where queries are on the RIGHT plan (probe_side=Right): ST_KNN(rgeom, lgeom, ...) + let on = SpatialPredicate::KNearestNeighbors(KNNPredicate { + left: Arc::new(Column::new("rgeom", 1)), + right: Arc::new(Column::new("lgeom", 1)), + k: 3, + use_spheroid: false, + probe_side: JoinSide::Right, + }); + + let exec = Arc::new(SpatialJoinExec::try_new_internal( + left, + right, + on, + None, + &JoinType::Inner, + None, + 0, + )?); + + // Project only geometry columns (left then right) so pushdown is allowed. + let join_schema = exec.schema(); + let exprs = vec![ + proj_expr(&join_schema, 1), + proj_expr(&join_schema, left_len + 1), + ]; + let proj = ProjectionExec::try_new(exprs, Arc::clone(&exec) as Arc)?; + + let Some(new_plan) = exec.try_swapping_with_projection(&proj)? else { + return sedona_internal_err!("expected try_swapping_with_projection to succeed"); + }; + let new_exec = new_plan + .as_any() + .downcast_ref::() + .expect("expected SpatialJoinExec"); + + let SpatialPredicate::KNearestNeighbors(new_on) = &new_exec.on else { + return sedona_internal_err!("expected KNN predicate"); + }; + + // Both sides should be remapped to 0 in their respective projected children. 
+ let new_probe = new_on + .left + .as_any() + .downcast_ref::() + .expect("expected Column expr"); + let new_build = new_on + .right + .as_any() + .downcast_ref::() + .expect("expected Column expr"); + assert_eq!(new_probe.index(), 0); + assert_eq!(new_build.index(), 0); + assert_eq!(new_on.probe_side, JoinSide::Right); + + Ok(()) + } + + #[test] + fn test_swap_inputs_flips_knn_probe_side_without_swapping_exprs() -> Result<()> { + let left_schema = make_schema(&[("l0", DataType::Int32), ("lgeom", DataType::Binary)]); + let right_schema = make_schema(&[("r0", DataType::Int32), ("rgeom", DataType::Binary)]); + + let left: Arc = Arc::new(EmptyExec::new(Arc::clone(&left_schema))); + let right: Arc = Arc::new(EmptyExec::new(Arc::clone(&right_schema))); + + let on = SpatialPredicate::KNearestNeighbors(KNNPredicate { + left: Arc::new(Column::new("rgeom", 1)), + right: Arc::new(Column::new("lgeom", 1)), + k: 3, + use_spheroid: false, + probe_side: JoinSide::Right, + }); + let exec = + SpatialJoinExec::try_new_internal(left, right, on, None, &JoinType::Inner, None, 0)?; + + let swapped = exec.swap_inputs()?; + let spatial_execs = collect_spatial_join_exec(&swapped)?; + assert_eq!(spatial_execs.len(), 1); + + let swapped_exec = spatial_execs[0]; + let SpatialPredicate::KNearestNeighbors(knn) = &swapped_exec.on else { + return sedona_internal_err!("expected KNN predicate"); + }; + + // Children swapped, so probe_side flips. + assert_eq!(knn.probe_side, JoinSide::Left); + + // Expressions are not swapped (remain pointing at original table schemas). + let probe_expr = knn + .left + .as_any() + .downcast_ref::() + .expect("expected Column expr"); + let build_expr = knn + .right + .as_any() + .downcast_ref::() + .expect("expected Column expr"); + assert_eq!(probe_expr.name(), "rgeom"); + assert_eq!(probe_expr.index(), 1); + assert_eq!(build_expr.name(), "lgeom"); + assert_eq!(build_expr.index(), 1); + + Ok(()) + } +} diff --git a/rust/sedona-spatial-join/src/lib.rs b/rust/sedona-spatial-join/src/lib.rs index 2abaf3c43..36ba9bd30 100644 --- a/rust/sedona-spatial-join/src/lib.rs +++ b/rust/sedona-spatial-join/src/lib.rs @@ -19,8 +19,8 @@ pub mod evaluated_batch; pub mod exec; mod index; pub mod operand_evaluator; -pub mod optimizer; pub mod partitioning; +pub mod planner; mod prepare; pub mod refine; pub mod spatial_predicate; @@ -28,7 +28,9 @@ mod stream; pub mod utils; pub use exec::SpatialJoinExec; -pub use optimizer::register_spatial_join_optimizer; + +// Re-export function for register the spatial join planner +pub use planner::register_planner; // Re-export types needed for external usage (e.g., in Comet) pub use index::{SpatialIndex, SpatialJoinBuildMetrics}; diff --git a/rust/sedona-spatial-join/src/planner.rs b/rust/sedona-spatial-join/src/planner.rs new file mode 100644 index 000000000..d9308e746 --- /dev/null +++ b/rust/sedona-spatial-join/src/planner.rs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! DataFusion planner integration for Sedona spatial joins. +//! +//! This module wires Sedona's logical optimizer rules and physical planning extensions that +//! can produce `SpatialJoinExec`. + +use datafusion::execution::SessionStateBuilder; + +mod logical_plan_node; +mod optimizer; +mod physical_planner; +mod spatial_expr_utils; + +/// Register Sedona spatial join planning hooks. +/// +/// Enables logical rewrites (to surface join filters) and a query planner extension that can +/// plan `SpatialJoinExec`. +pub fn register_planner(state_builder: SessionStateBuilder) -> SessionStateBuilder { + // Enable the logical rewrite that turns Filter(CrossJoin) into Join(filter=...) + let state_builder = optimizer::register_spatial_join_logical_optimizer(state_builder); + + // Enable planning SpatialJoinExec via an extension node during logical->physical planning. + physical_planner::register_spatial_join_planner(state_builder) +} diff --git a/rust/sedona-spatial-join/src/planner/logical_plan_node.rs b/rust/sedona-spatial-join/src/planner/logical_plan_node.rs new file mode 100644 index 000000000..3c0cb7ced --- /dev/null +++ b/rust/sedona-spatial-join/src/planner/logical_plan_node.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cmp::Ordering; +use std::fmt; +use std::sync::Arc; + +use datafusion_common::{plan_err, DFSchemaRef, NullEquality, Result}; +use datafusion_expr::logical_plan::UserDefinedLogicalNodeCore; +use datafusion_expr::{Expr, JoinConstraint, JoinType, LogicalPlan}; + +/// Logical extension node used as a planning hook for spatial joins. +/// +/// Carries a join's inputs and filter expression so the physical planner can recognize and plan +/// a `SpatialJoinExec`. +#[derive(PartialEq, Eq, Hash)] +pub(crate) struct SpatialJoinPlanNode { + pub left: LogicalPlan, + pub right: LogicalPlan, + pub join_type: JoinType, + pub filter: Expr, + pub schema: DFSchemaRef, + pub join_constraint: JoinConstraint, + pub null_equality: NullEquality, +} + +// Manual implementation needed because of `schema` field. Comparison excludes this field. 
+// See https://github.com/apache/datafusion/blob/52.1.0/datafusion/expr/src/logical_plan/plan.rs#L3886
+impl PartialOrd for SpatialJoinPlanNode {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        #[derive(PartialEq, PartialOrd)]
+        struct ComparableJoin<'a> {
+            pub left: &'a LogicalPlan,
+            pub right: &'a LogicalPlan,
+            pub filter: &'a Expr,
+            pub join_type: &'a JoinType,
+            pub join_constraint: &'a JoinConstraint,
+            pub null_equality: &'a NullEquality,
+        }
+        let comparable_self = ComparableJoin {
+            left: &self.left,
+            right: &self.right,
+            filter: &self.filter,
+            join_type: &self.join_type,
+            join_constraint: &self.join_constraint,
+            null_equality: &self.null_equality,
+        };
+        let comparable_other = ComparableJoin {
+            left: &other.left,
+            right: &other.right,
+            filter: &other.filter,
+            join_type: &other.join_type,
+            join_constraint: &other.join_constraint,
+            null_equality: &other.null_equality,
+        };
+        comparable_self
+            .partial_cmp(&comparable_other)
+            // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
+            .filter(|cmp| *cmp != Ordering::Equal || self == other)
+    }
+}
+
+impl fmt::Debug for SpatialJoinPlanNode {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        UserDefinedLogicalNodeCore::fmt_for_explain(self, f)
+    }
+}
+
+impl UserDefinedLogicalNodeCore for SpatialJoinPlanNode {
+    fn name(&self) -> &str {
+        "SpatialJoin"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.left, &self.right]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        &self.schema
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        vec![self.filter.clone()]
+    }
+
+    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "SpatialJoin: join_type={:?}, filter={}",
+            self.join_type, self.filter
+        )
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        mut exprs: Vec<Expr>,
+        mut inputs: Vec<LogicalPlan>,
+    ) -> Result<Self> {
+        if exprs.len() != 1 {
+            return plan_err!("SpatialJoinPlanNode expects 1 expr");
+        }
+        if inputs.len() != 2 {
+            return plan_err!("SpatialJoinPlanNode expects 2 inputs");
+        }
+        Ok(Self {
+            left: inputs.swap_remove(0),
+            right: inputs.swap_remove(0),
+            join_type: self.join_type,
+            filter: exprs.swap_remove(0),
+            schema: Arc::clone(&self.schema),
+            join_constraint: self.join_constraint,
+            null_equality: self.null_equality,
+        })
+    }
+}
diff --git a/rust/sedona-spatial-join/src/planner/optimizer.rs b/rust/sedona-spatial-join/src/planner/optimizer.rs
new file mode 100644
index 000000000..8cbc33f62
--- /dev/null
+++ b/rust/sedona-spatial-join/src/planner/optimizer.rs
@@ -0,0 +1,231 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+use std::sync::Arc;
+
+use crate::planner::logical_plan_node::SpatialJoinPlanNode;
+use crate::planner::spatial_expr_utils::collect_spatial_predicate_names;
+use crate::planner::spatial_expr_utils::is_spatial_predicate;
+use datafusion::execution::session_state::SessionStateBuilder;
+use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::NullEquality;
+use datafusion_common::Result;
+use datafusion_expr::logical_plan::Extension;
+use datafusion_expr::{BinaryExpr, Expr, Operator};
+use datafusion_expr::{Filter, Join, JoinType, LogicalPlan};
+use sedona_common::option::SedonaOptions;
+
+/// Register only the logical spatial join optimizer rule.
+///
+/// This enables building `Join(filter=...)` from patterns like `Filter(CrossJoin)`.
+/// It intentionally does not register any physical plan rewrite rules.
+pub fn register_spatial_join_logical_optimizer(
+    session_state_builder: SessionStateBuilder,
+) -> SessionStateBuilder {
+    session_state_builder
+        .with_optimizer_rule(Arc::new(MergeSpatialProjectionIntoJoin))
+        .with_optimizer_rule(Arc::new(SpatialJoinLogicalRewrite))
+}
+/// Logical optimizer rule that enables spatial join planning.
+///
+/// This rule turns eligible `Join(filter=...)` nodes into a `SpatialJoinPlanNode` extension.
+#[derive(Default, Debug)]
+struct SpatialJoinLogicalRewrite;
+
+impl OptimizerRule for SpatialJoinLogicalRewrite {
+    fn name(&self) -> &str {
+        "spatial_join_logical_rewrite"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let options = config.options();
+        let Some(ext) = options.extensions.get::<SedonaOptions>() else {
+            return Ok(Transformed::no(plan));
+        };
+        if !ext.spatial_join.enable {
+            return Ok(Transformed::no(plan));
+        }
+
+        let LogicalPlan::Join(join) = &plan else {
+            return Ok(Transformed::no(plan));
+        };
+
+        // v1: only rewrite joins that already have a spatial predicate in `filter`.
+        let Some(filter) = join.filter.as_ref() else {
+            return Ok(Transformed::no(plan));
+        };
+
+        let spatial_predicate_names = collect_spatial_predicate_names(filter);
+        if spatial_predicate_names.is_empty() {
+            return Ok(Transformed::no(plan));
+        }
+
+        // A join with both an equi-join condition and a spatial join condition is only handled
+        // when the join condition contains ST_KNN. A KNN join is not a regular join and
+        // ST_KNN is also not a regular predicate, so it must be handled by our spatial join exec.
+ if !join.on.is_empty() && !spatial_predicate_names.contains("st_knn") { + return Ok(Transformed::no(plan)); + } + + // Build new filter expression including equi-join conditions + let filter = filter.clone(); + let eq_op = if join.null_equality == NullEquality::NullEqualsNothing { + Operator::Eq + } else { + Operator::IsNotDistinctFrom + }; + let filter = join.on.iter().fold(filter, |acc, (l, r)| { + let eq_expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(l.clone()), + eq_op, + Box::new(r.clone()), + )); + Expr::and(acc, eq_expr) + }); + + let schema = Arc::clone(&join.schema); + let node = SpatialJoinPlanNode { + left: join.left.as_ref().clone(), + right: join.right.as_ref().clone(), + join_type: join.join_type, + filter, + schema, + join_constraint: join.join_constraint, + null_equality: join.null_equality, + }; + + Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(node), + }))) + } +} + +/// Logical optimizer rule that enables spatial join planning. +/// +/// This rule turns eligible `Filter(Join(filter=...))` nodes into a `Join(filter=...)` node, +/// so that the spatial join can be rewritten later by [SpatialJoinLogicalRewrite]. +#[derive(Debug, Default)] +struct MergeSpatialProjectionIntoJoin; + +impl OptimizerRule for MergeSpatialProjectionIntoJoin { + fn name(&self) -> &str { + "spatial_join_optimizer" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } + + /// Try to rewrite the plan containing a spatial Filter on top of a cross join without on or filter + /// to a theta-join with filter. For instance, the following query plan: + /// + /// ```text + /// Filter: st_intersects(l.geom, _scalar_sq_1.geom) + /// Left Join (no on, no filter): + /// TableScan: l projection=[id, geom] + /// SubqueryAlias: __scalar_sq_1 + /// Projection: r.geom + /// Filter: r.id = Int32(1) + /// TableScan: r projection=[id, geom] + /// ``` + /// + /// will be rewritten to + /// + /// ```text + /// Inner Join: Filter: st_intersects(l.geom, _scalar_sq_1.geom) + /// TableScan: l projection=[id, geom] + /// SubqueryAlias: __scalar_sq_1 + /// Projection: r.geom + /// Filter: r.id = Int32(1) + /// TableScan: r projection=[id, geom] + /// ``` + /// + /// This is for enabling this logical join operator to be converted to a [SpatialJoinPlanNode] + /// by [SpatialJoinLogicalRewrite], so that it could subsequently be optimized to a SpatialJoin + /// physical node. + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let options = config.options(); + let Some(extension) = options.extensions.get::() else { + return Ok(Transformed::no(plan)); + }; + if !extension.spatial_join.enable { + return Ok(Transformed::no(plan)); + } + + let LogicalPlan::Filter(Filter { + predicate, input, .. + }) = &plan + else { + return Ok(Transformed::no(plan)); + }; + if !is_spatial_predicate(predicate) { + return Ok(Transformed::no(plan)); + } + + let LogicalPlan::Join(Join { + ref left, + ref right, + ref on, + ref filter, + join_type, + ref join_constraint, + ref null_equality, + .. 
+ }) = input.as_ref() + else { + return Ok(Transformed::no(plan)); + }; + + // Check if this is a suitable join for rewriting + if !matches!( + join_type, + JoinType::Inner | JoinType::Left | JoinType::Right + ) || !on.is_empty() + || filter.is_some() + { + return Ok(Transformed::no(plan)); + } + + let rewritten_plan = Join::try_new( + Arc::clone(left), + Arc::clone(right), + on.clone(), + Some(predicate.clone()), + JoinType::Inner, + *join_constraint, + *null_equality, + )?; + + Ok(Transformed::yes(LogicalPlan::Join(rewritten_plan))) + } +} diff --git a/rust/sedona-spatial-join/src/planner/physical_planner.rs b/rust/sedona-spatial-join/src/planner/physical_planner.rs new file mode 100644 index 000000000..2f0bdddf4 --- /dev/null +++ b/rust/sedona-spatial-join/src/planner/physical_planner.rs @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::fmt; +use std::sync::Arc; + +use async_trait::async_trait; + +use arrow_schema::Schema; + +use datafusion::execution::context::QueryPlanner; +use datafusion::execution::session_state::{SessionState, SessionStateBuilder}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}; +use datafusion_common::{plan_err, DFSchema, Result}; +use datafusion_expr::logical_plan::UserDefinedLogicalNode; +use datafusion_expr::LogicalPlan; +use datafusion_physical_expr::create_physical_expr; +use datafusion_physical_plan::joins::utils::JoinFilter; +use datafusion_physical_plan::joins::NestedLoopJoinExec; +use sedona_common::sedona_internal_err; + +use crate::exec::SpatialJoinExec; +use crate::planner::logical_plan_node::SpatialJoinPlanNode; +use crate::planner::spatial_expr_utils::{is_spatial_predicate_supported, transform_join_filter}; +use crate::spatial_predicate::SpatialPredicate; +use sedona_common::option::SedonaOptions; + +/// Registers a query planner that can produce [`SpatialJoinExec`] from a logical extension node. +pub fn register_spatial_join_planner(builder: SessionStateBuilder) -> SessionStateBuilder { + builder.with_query_planner(Arc::new(SedonaSpatialQueryPlanner)) +} + +/// Query planner that enables Sedona's spatial join planning. +/// +/// Installs an [`ExtensionPlanner`] that recognizes `SpatialJoinPlanNode` and produces +/// `SpatialJoinExec` when supported and enabled. 
+pub struct SedonaSpatialQueryPlanner; + +impl fmt::Debug for SedonaSpatialQueryPlanner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SedonaSpatialQueryPlanner").finish() + } +} + +#[async_trait] +impl QueryPlanner for SedonaSpatialQueryPlanner { + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> Result> { + let physical_planner = DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( + SpatialJoinExtensionPlanner {}, + )]); + physical_planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + +/// Physical planner hook for `SpatialJoinPlanNode`. +struct SpatialJoinExtensionPlanner; + +#[async_trait] +impl ExtensionPlanner for SpatialJoinExtensionPlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + session_state: &SessionState, + ) -> Result>> { + let Some(spatial_node) = node.as_any().downcast_ref::() else { + return Ok(None); + }; + + let Some(ext) = session_state + .config_options() + .extensions + .get::() + else { + return sedona_internal_err!("SedonaOptions not found in session state extensions"); + }; + + if !ext.spatial_join.enable { + return sedona_internal_err!("Spatial join is disabled in SedonaOptions"); + } + + if logical_inputs.len() != 2 || physical_inputs.len() != 2 { + return plan_err!("SpatialJoinPlanNode expects 2 inputs"); + } + + let join_type = &spatial_node.join_type; + + let (physical_left, physical_right) = + (physical_inputs[0].clone(), physical_inputs[1].clone()); + + let join_filter = logical_join_filter_to_physical( + spatial_node, + session_state, + &physical_left, + &physical_right, + )?; + + let Some((spatial_predicate, remainder)) = transform_join_filter(&join_filter) else { + let nlj = NestedLoopJoinExec::try_new( + physical_left, + physical_right, + Some(join_filter), + join_type, + None, + )?; + return Ok(Some(Arc::new(nlj))); + }; + + if !is_spatial_predicate_supported( + &spatial_predicate, + &physical_left.schema(), + &physical_right.schema(), + )? 
{ + let nlj = NestedLoopJoinExec::try_new( + physical_left, + physical_right, + Some(join_filter), + join_type, + None, + )?; + return Ok(Some(Arc::new(nlj))); + } + + let should_swap = !matches!(spatial_predicate, SpatialPredicate::KNearestNeighbors(_)) + && join_type.supports_swap() + && should_swap_join_order(physical_left.as_ref(), physical_right.as_ref())?; + + let exec = SpatialJoinExec::try_new( + physical_left, + physical_right, + spatial_predicate, + remainder, + join_type, + None, + &ext.spatial_join, + )?; + + if should_swap { + exec.swap_inputs().map(Some) + } else { + Ok(Some(Arc::new(exec) as Arc)) + } + } +} + +fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan) -> Result { + let left_stats = left.partition_statistics(None)?; + let right_stats = right.partition_statistics(None)?; + + match ( + left_stats.total_byte_size.get_value(), + right_stats.total_byte_size.get_value(), + ) { + (Some(l), Some(r)) => Ok(l > r), + _ => match ( + left_stats.num_rows.get_value(), + right_stats.num_rows.get_value(), + ) { + (Some(l), Some(r)) => Ok(l > r), + _ => Ok(false), + }, + } +} + +/// This function is mostly taken from the match arm for handling LogicalPlan::Join in +/// https://github.com/apache/datafusion/blob/51.0.0/datafusion/core/src/physical_planner.rs#L1144-L1245 +fn logical_join_filter_to_physical( + plan_node: &SpatialJoinPlanNode, + session_state: &SessionState, + physical_left: &Arc, + physical_right: &Arc, +) -> Result { + let SpatialJoinPlanNode { + left, + right, + filter, + .. + } = plan_node; + + let left_df_schema = left.schema(); + let right_df_schema = right.schema(); + + // Extract columns from filter expression and saved in a HashSet + let cols = filter.column_refs(); + + // Collect left & right field indices, the field indices are sorted in ascending order + let mut left_field_indices = cols + .iter() + .filter_map(|c| left_df_schema.index_of_column(c).ok()) + .collect::>(); + left_field_indices.sort_unstable(); + + let mut right_field_indices = cols + .iter() + .filter_map(|c| right_df_schema.index_of_column(c).ok()) + .collect::>(); + right_field_indices.sort_unstable(); + + // Collect DFFields and Fields required for intermediate schemas + let (filter_df_fields, filter_fields): (Vec<_>, Vec<_>) = left_field_indices + .clone() + .into_iter() + .map(|i| { + ( + left_df_schema.qualified_field(i), + physical_left.schema().field(i).clone(), + ) + }) + .chain(right_field_indices.clone().into_iter().map(|i| { + ( + right_df_schema.qualified_field(i), + physical_right.schema().field(i).clone(), + ) + })) + .unzip(); + let filter_df_fields = filter_df_fields + .into_iter() + .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) + .collect::>(); + + let metadata: HashMap<_, _> = left_df_schema + .metadata() + .clone() + .into_iter() + .chain(right_df_schema.metadata().clone()) + .collect(); + + // Construct intermediate schemas used for filtering data and + // convert logical expression to physical according to filter schema + let filter_df_schema = DFSchema::new_with_metadata(filter_df_fields, metadata.clone())?; + let filter_schema = Schema::new_with_metadata(filter_fields, metadata); + + let filter_expr = + create_physical_expr(filter, &filter_df_schema, session_state.execution_props())?; + let column_indices = JoinFilter::build_column_indices(left_field_indices, right_field_indices); + + let join_filter = JoinFilter::new(filter_expr, column_indices, Arc::new(filter_schema)); + Ok(join_filter) +} diff --git 
a/rust/sedona-spatial-join/src/optimizer.rs b/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs similarity index 80% rename from rust/sedona-spatial-join/src/optimizer.rs rename to rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs index a5a8baefe..5fd85a50e 100644 --- a/rust/sedona-spatial-join/src/optimizer.rs +++ b/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs @@ -14,566 +14,95 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + +use std::collections::HashSet; use std::sync::Arc; -use crate::exec::SpatialJoinExec; use crate::spatial_predicate::{ DistancePredicate, KNNPredicate, RelationPredicate, SpatialPredicate, SpatialRelationType, }; -use arrow_schema::{Schema, SchemaRef}; -use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule}; -use datafusion::physical_optimizer::sanity_checker::SanityCheckPlan; -use datafusion::{ - config::ConfigOptions, execution::session_state::SessionStateBuilder, - physical_optimizer::PhysicalOptimizerRule, -}; +use arrow_schema::Schema; use datafusion_common::ScalarValue; use datafusion_common::{ tree_node::{Transformed, TreeNode}, JoinSide, }; use datafusion_common::{HashMap, Result}; -use datafusion_expr::{Expr, Filter, Join, JoinType, LogicalPlan, Operator}; +use datafusion_expr::{Expr, Operator}; use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr}; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::joins::utils::ColumnIndex; -use datafusion_physical_plan::joins::{HashJoinExec, NestedLoopJoinExec}; -use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::{joins::utils::JoinFilter, ExecutionPlan}; -use sedona_common::{option::SedonaOptions, sedona_internal_err}; +use datafusion_physical_plan::joins::utils::JoinFilter; +use sedona_common::sedona_internal_err; use sedona_expr::utils::{parse_distance_predicate, ParsedDistancePredicate}; use sedona_schema::datatypes::SedonaType; use sedona_schema::matchers::ArgMatcher; -/// Physical planner extension for spatial joins -/// -/// This extension recognizes nested loop join operations with spatial predicates -/// and converts them to SpatialJoinExec, which is specially optimized for spatial joins. -#[derive(Debug, Default)] -pub struct SpatialJoinOptimizer; - -impl SpatialJoinOptimizer { - pub fn new() -> Self { - Self - } -} - -impl PhysicalOptimizerRule for SpatialJoinOptimizer { - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result> { - let Some(extension) = config.extensions.get::() else { - return Ok(plan); - }; - - if extension.spatial_join.enable { - let transformed = plan.transform_up(|plan| self.try_optimize_join(plan, config))?; - Ok(transformed.data) - } else { - Ok(plan) - } - } - - /// A human readable name for this optimizer rule - fn name(&self) -> &str { - "spatial_join_optimizer" - } - - /// A flag to indicate whether the physical planner should valid the rule will not - /// change the schema of the plan after the rewriting. - /// Some of the optimization rules might change the nullable properties of the schema - /// and should disable the schema check. 
- fn schema_check(&self) -> bool { - true - } -} - -impl OptimizerRule for SpatialJoinOptimizer { - fn name(&self) -> &str { - "spatial_join_optimizer" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - - /// Try to rewrite the plan containing a spatial Filter on top of a cross join without on or filter - /// to a theta-join with filter. For instance, the following query plan: - /// - /// ```text - /// Filter: st_intersects(l.geom, _scalar_sq_1.geom) - /// Left Join (no on, no filter): - /// TableScan: l projection=[id, geom] - /// SubqueryAlias: __scalar_sq_1 - /// Projection: r.geom - /// Filter: r.id = Int32(1) - /// TableScan: r projection=[id, geom] - /// ``` - /// - /// will be rewritten to - /// - /// ```text - /// Inner Join: Filter: st_intersects(l.geom, _scalar_sq_1.geom) - /// TableScan: l projection=[id, geom] - /// SubqueryAlias: __scalar_sq_1 - /// Projection: r.geom - /// Filter: r.id = Int32(1) - /// TableScan: r projection=[id, geom] - /// ``` - /// - /// This is for enabling this logical join operator to be converted to a NestedLoopJoin physical - /// node with a spatial predicate, so that it could subsequently be optimized to a SpatialJoin - /// physical node. Please refer to the `PhysicalOptimizerRule` implementation of this struct - /// and [SpatialJoinOptimizer::try_optimize_join] for details. - fn rewrite( - &self, - plan: LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - let options = config.options(); - let Some(extension) = options.extensions.get::() else { - return Ok(Transformed::no(plan)); - }; - if !extension.spatial_join.enable { - return Ok(Transformed::no(plan)); - } - - let LogicalPlan::Filter(Filter { - predicate, input, .. - }) = &plan - else { - return Ok(Transformed::no(plan)); - }; - if !is_spatial_predicate(predicate) { - return Ok(Transformed::no(plan)); - } - - let LogicalPlan::Join(Join { - ref left, - ref right, - ref on, - ref filter, - join_type, - ref join_constraint, - ref null_equality, - .. - }) = input.as_ref() - else { - return Ok(Transformed::no(plan)); - }; - - // Check if this is a suitable join for rewriting - if !matches!( - join_type, - JoinType::Inner | JoinType::Left | JoinType::Right - ) || !on.is_empty() - || filter.is_some() - { - return Ok(Transformed::no(plan)); - } - - let rewritten_plan = Join::try_new( - Arc::clone(left), - Arc::clone(right), - on.clone(), - Some(predicate.clone()), - JoinType::Inner, - *join_constraint, - *null_equality, - )?; - - Ok(Transformed::yes(LogicalPlan::Join(rewritten_plan))) - } -} - -/// Check if a given logical expression contains a spatial predicate component or not. We assume that the given +/// Collect the names of spatial predicates appeared in expr. We assume that the given /// `expr` evaluates to a boolean value and originates from a filter logical node. -fn is_spatial_predicate(expr: &Expr) -> bool { - fn is_distance_expr(expr: &Expr) -> bool { - let Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { func, .. }) = expr else { - return false; - }; - func.name().to_lowercase() == "st_distance" - } - - match expr { - Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr { - left, right, op, .. - }) => match op { - Operator::And => is_spatial_predicate(left) || is_spatial_predicate(right), - Operator::Lt | Operator::LtEq => is_distance_expr(left), - Operator::Gt | Operator::GtEq => is_distance_expr(right), - _ => false, - }, - Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { func, .. 
}) => { - let func_name = func.name().to_lowercase(); - matches!( - func_name.as_str(), - "st_intersects" - | "st_contains" - | "st_within" - | "st_covers" - | "st_covered_by" - | "st_coveredby" - | "st_touches" - | "st_crosses" - | "st_overlaps" - | "st_equals" - | "st_dwithin" - | "st_knn" - ) - } - _ => false, - } -} - -impl SpatialJoinOptimizer { - /// Rewrite `plan` containing NestedLoopJoinExec or HashJoinExec with spatial predicates to SpatialJoinExec. - fn try_optimize_join( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result>> { - // Check if this is a NestedLoopJoinExec that we can convert to spatial join - if let Some(nested_loop_join) = plan.as_any().downcast_ref::() { - if let Some(spatial_join) = - self.try_convert_to_spatial_join(nested_loop_join, config)? - { - return Ok(Transformed::yes(spatial_join)); - } - } - - // Check if this is a HashJoinExec with spatial filter that we can convert to spatial join - if let Some(hash_join) = plan.as_any().downcast_ref::() { - if let Some(spatial_join) = self.try_convert_hash_join_to_spatial(hash_join, config)? { - return Ok(Transformed::yes(spatial_join)); - } - } - - // No optimization applied, return the original plan - Ok(Transformed::no(plan)) - } - - /// Try to convert a NestedLoopJoinExec with spatial predicates as join condition to a SpatialJoinExec. - /// SpatialJoinExec executes the query using an optimized algorithm, which is more efficient than - /// NestedLoopJoinExec. - fn try_convert_to_spatial_join( - &self, - nested_loop_join: &NestedLoopJoinExec, - config: &ConfigOptions, - ) -> Result>> { - let Some(options) = config.extensions.get::() else { - return Ok(None); - }; - - if let Some(join_filter) = nested_loop_join.filter() { - if let Some((spatial_predicate, remainder)) = transform_join_filter(join_filter) { - // The left side of the nested loop join is required to have only one partition, while SpatialJoinExec - // does not have that requirement. SpatialJoinExec can consume the streams on the build side in parallel - // when the build side has multiple partitions. - // If the left side is a CoalescePartitionsExec, we can drop the CoalescePartitionsExec and directly use - // the input. - let left = nested_loop_join.left(); - let left = if let Some(coalesce_partitions) = - left.as_any().downcast_ref::() - { - // Remove unnecessary CoalescePartitionsExec for spatial joins - coalesce_partitions.input() - } else { - left - }; - - let left = left.clone(); - let right = nested_loop_join.right().clone(); - let join_type = nested_loop_join.join_type(); - - // Check if the geospatial types involved in spatial_predicate are supported - if !is_spatial_predicate_supported( - &spatial_predicate, - &left.schema(), - &right.schema(), - )? { - return Ok(None); +pub(crate) fn collect_spatial_predicate_names(expr: &Expr) -> HashSet { + fn collect(expr: &Expr, acc: &mut HashSet) { + match expr { + Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr { + left, right, op, .. + }) => match op { + Operator::And => { + collect(left, acc); + collect(right, acc); } - - // Create the spatial join - let spatial_join = SpatialJoinExec::try_new( - left, - right, - spatial_predicate, - remainder, - join_type, - nested_loop_join.projection().cloned(), - &options.spatial_join, - )?; - - return Ok(Some(Arc::new(spatial_join))); - } - } - - Ok(None) - } - - /// Try to convert a HashJoinExec with spatial predicates in the filter to a SpatialJoinExec. 
- /// This handles cases where there's an equi-join condition (like c.id = r.id) along with - /// the ST_KNN predicate. We flip them so the spatial predicate drives the join - /// and the equi-conditions become filters. - fn try_convert_hash_join_to_spatial( - &self, - hash_join: &HashJoinExec, - config: &ConfigOptions, - ) -> Result>> { - let Some(options) = config.extensions.get::() else { - return Ok(None); - }; - - // Check if the filter contains spatial predicates - if let Some(join_filter) = hash_join.filter() { - if let Some((spatial_predicate, mut remainder)) = transform_join_filter(join_filter) { - // The transform_join_filter now prioritizes ST_KNN predicates - // Only proceed if we found an ST_KNN (other spatial predicates are left in hash join) - if !matches!(spatial_predicate, SpatialPredicate::KNearestNeighbors(_)) { - return Ok(None); + Operator::Lt | Operator::LtEq => { + if is_distance_expr(left) { + acc.insert("st_dwithin".to_string()); + } } - - // Check if the geospatial types involved in spatial_predicate are supported (planar geometries only) - if !is_spatial_predicate_supported( - &spatial_predicate, - &hash_join.left().schema(), - &hash_join.right().schema(), - )? { - return Ok(None); + Operator::Gt | Operator::GtEq => { + if is_distance_expr(right) { + acc.insert("st_dwithin".to_string()); + } + } + _ => (), + }, + Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { func, .. }) => { + let func_name = func.name().to_lowercase(); + if matches!( + func_name.as_str(), + "st_intersects" + | "st_contains" + | "st_within" + | "st_covers" + | "st_covered_by" + | "st_coveredby" + | "st_touches" + | "st_crosses" + | "st_overlaps" + | "st_equals" + | "st_dwithin" + | "st_knn" + ) { + acc.insert(func_name); } - - // Extract the equi-join conditions and convert them to a filter - let equi_filter = self.create_equi_filter_from_hash_join(hash_join)?; - - // Combine the equi-filter with any existing remainder - remainder = self.combine_filters(remainder, equi_filter)?; - - // Create spatial join where: - // - Spatial predicate (ST_KNN) drives the join - // - Equi-conditions (c.id = r.id) become filters - - // Create SpatialJoinExec without projection first - // Use try_new_with_options to mark this as converted from HashJoin - let spatial_join = Arc::new(SpatialJoinExec::try_new_with_options( - hash_join.left().clone(), - hash_join.right().clone(), - spatial_predicate, - remainder, - hash_join.join_type(), - None, // No projection in SpatialJoinExec - &options.spatial_join, - true, // converted_from_hash_join = true - )?); - - // Now wrap it with ProjectionExec to match HashJoinExec's output schema exactly - let expected_schema = hash_join.schema(); - let spatial_schema = spatial_join.schema(); - - // Create a projection that selects the exact columns HashJoinExec would output - let projection_exec = self.create_schema_matching_projection( - spatial_join, - &expected_schema, - &spatial_schema, - )?; - - return Ok(Some(projection_exec)); } + _ => (), } - - Ok(None) } - /// Create a filter expression from the hash join's equi-join conditions - fn create_equi_filter_from_hash_join( - &self, - hash_join: &HashJoinExec, - ) -> Result> { - let join_keys = hash_join.on(); - - if join_keys.is_empty() { - return Ok(None); - } - - // Build filter expressions from the equi-join conditions - let mut expressions = vec![]; - - // Get the left schema size to calculate right column offsets - let left_schema_size = hash_join.left().schema().fields().len(); - - for (left_key, right_key) in 
join_keys.iter() { - // Create equality expression: left_key = right_key - // But we need to adjust the column indices for SpatialJoinExec schema - if let (Some(left_col), Some(right_col)) = ( - left_key.as_any().downcast_ref::(), - right_key.as_any().downcast_ref::(), - ) { - // In SpatialJoinExec schema: [left_fields..., right_fields...] - // Left columns keep their indices, right columns get offset by left_schema_size - let left_idx = left_col.index(); - let right_idx = left_schema_size + right_col.index(); - - let left_expr = - Arc::new(Column::new(left_col.name(), left_idx)) as Arc; - let right_expr = - Arc::new(Column::new(right_col.name(), right_idx)) as Arc; - - let eq_expr = Arc::new(BinaryExpr::new(left_expr, Operator::Eq, right_expr)) - as Arc; - - expressions.push(eq_expr); - } - } - - // IMPORTANT: Create column indices for ALL columns in the spatial join schema - // not just the filter columns. This is required by build_batch_from_indices. - let left_schema = hash_join.left().schema(); - let right_schema = hash_join.right().schema(); - let mut column_indices = vec![]; - - // Add all left side columns - for (i, _field) in left_schema.fields().iter().enumerate() { - column_indices.push(ColumnIndex { - index: i, - side: JoinSide::Left, - }); - } - - // Add all right side columns - for (i, _field) in right_schema.fields().iter().enumerate() { - column_indices.push(ColumnIndex { - index: i, - side: JoinSide::Right, - }); - } - - // Combine all conditions with AND - let filter_expr = if expressions.len() == 1 { - expressions.into_iter().next().unwrap() - } else { - expressions - .into_iter() - .reduce(|acc, expr| { - Arc::new(BinaryExpr::new(acc, Operator::And, expr)) as Arc - }) - .unwrap() + fn is_distance_expr(expr: &Expr) -> bool { + let Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { func, .. }) = expr else { + return false; }; - - // Create JoinFilter - // IMPORTANT: The filter expression uses spatial join indices (id@0 = id@3) - // So we need to create the filter schema that matches the spatial join schema, - // not the hash join schema - let left_schema = hash_join.left().schema(); - let right_schema = hash_join.right().schema(); - let mut spatial_filter_fields = left_schema.fields().to_vec(); - spatial_filter_fields.extend_from_slice(right_schema.fields()); - let spatial_filter_schema = Arc::new(arrow_schema::Schema::new(spatial_filter_fields)); - - // Filter expression uses spatial join indices (e.g. 
id@0 = id@3) - // Schema should match the spatial join schema (left + right) - - Ok(Some(JoinFilter::new( - filter_expr, - column_indices, - spatial_filter_schema, - ))) - } - - /// Combine two optional filters with AND - fn combine_filters( - &self, - filter1: Option, - filter2: Option, - ) -> Result> { - match (filter1, filter2) { - (None, None) => Ok(None), - (Some(f), None) | (None, Some(f)) => Ok(Some(f)), - (Some(f1), Some(f2)) => { - // Combine f1 AND f2 - let combined_expr = Arc::new(BinaryExpr::new( - f1.expression().clone(), - Operator::And, - f2.expression().clone(), - )) as Arc; - - // Combine column indices - let mut combined_indices = f1.column_indices().to_vec(); - combined_indices.extend_from_slice(f2.column_indices()); - - Ok(Some(JoinFilter::new( - combined_expr, - combined_indices, - f1.schema().clone(), - ))) - } - } + func.name().to_lowercase() == "st_distance" } - /// Create a ProjectionExec that makes SpatialJoinExec output match HashJoinExec's schema - fn create_schema_matching_projection( - &self, - spatial_join: Arc, - expected_schema: &SchemaRef, - spatial_schema: &SchemaRef, - ) -> Result> { - // The challenge is to map from the expected HashJoinExec schema to SpatialJoinExec schema - // - // Expected schema has fields like: [id, name, name] (with duplicates) - // Spatial schema has fields like: [id, location, name, id, location, name] (left + right) - - // Map the expected schema to spatial schema by matching field names and types - // For fields with duplicate names (like "name"), we need to be careful about ordering - let mut projection_exprs = Vec::new(); - let mut used_spatial_indices = std::collections::HashSet::new(); - - for (expected_idx, expected_field) in expected_schema.fields().iter().enumerate() { - let mut found = false; - - // Try to find the corresponding field in spatial schema - for (spatial_idx, spatial_field) in spatial_schema.fields().iter().enumerate() { - if spatial_field.name() == expected_field.name() - && spatial_field.data_type() == expected_field.data_type() - && !used_spatial_indices.contains(&spatial_idx) - { - let col_expr = Arc::new(Column::new(spatial_field.name(), spatial_idx)) - as Arc; - projection_exprs.push((col_expr, expected_field.name().clone())); - used_spatial_indices.insert(spatial_idx); - found = true; - break; - } - } - - if !found { - return sedona_internal_err!( - "Cannot find matching field for '{}' ({:?}) at position {} in spatial join output. \ - Please check column name mappings and schema compatibility between HashJoinExec and SpatialJoinExec.", - expected_field.name(), - expected_field.data_type(), - expected_idx - ); - } - } - - let projection = ProjectionExec::try_new(projection_exprs, spatial_join)?; - - Ok(Arc::new(projection)) - } + let mut acc = HashSet::new(); + collect(expr, &mut acc); + acc } -/// Helper function to register the spatial join optimizer with a session state -pub fn register_spatial_join_optimizer( - session_state_builder: SessionStateBuilder, -) -> SessionStateBuilder { - session_state_builder - .with_optimizer_rule(Arc::new(SpatialJoinOptimizer::new())) - .with_physical_optimizer_rule(Arc::new(SpatialJoinOptimizer::new())) - .with_physical_optimizer_rule(Arc::new(SanityCheckPlan::new())) +/// Check if a given logical expression contains a spatial predicate component or not. We assume that the given +/// `expr` evaluates to a boolean value and originates from a filter logical node. 
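// Editor's sketch (not part of the patch): the shapes of filter expressions these helpers are
// expected to classify, following the `collect_spatial_predicate_names` match arms above
// (column names are hypothetical):
//
//   ST_Intersects(l.geom, r.geom)                -> {"st_intersects"}  (is_spatial_predicate: true)
//   ST_Distance(l.geom, r.geom) < 1000           -> {"st_dwithin"}     (distance call left of `<`)
//   1000 > ST_Distance(l.geom, r.geom)           -> {"st_dwithin"}     (distance call right of `>`)
//   ST_Contains(l.geom, r.geom) AND l.id = r.id  -> {"st_contains"}    (AND recurses into both sides)
//   ST_Distance(l.geom, r.geom) > 1000           -> {}                 (not a bounded-distance form)
//   l.id = r.id                                  -> {}                 (is_spatial_predicate: false)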
+pub(crate) fn is_spatial_predicate(expr: &Expr) -> bool { + let pred_names = collect_spatial_predicate_names(expr); + !pred_names.is_empty() } /// Transform the join filter to a spatial predicate and a remainder. @@ -583,7 +112,7 @@ pub fn register_spatial_join_optimizer( /// /// The remainder may reference fewer columns than the original join filter. If that's the case, /// the columns that are not referenced by the remainder will be pruned. -fn transform_join_filter( +pub(crate) fn transform_join_filter( join_filter: &JoinFilter, ) -> Option<(SpatialPredicate, Option)> { let (spatial_predicate, remainder) = @@ -1026,7 +555,7 @@ fn replace_join_filter_expr(expr: &Arc, join_filter: &JoinFilt ) } -fn is_spatial_predicate_supported( +pub(crate) fn is_spatial_predicate_supported( spatial_predicate: &SpatialPredicate, left_schema: &Schema, right_schema: &Schema, @@ -1073,6 +602,7 @@ mod tests { use super::*; use crate::spatial_predicate::{SpatialPredicate, SpatialRelationType}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::config::ConfigOptions; use datafusion_common::{JoinSide, ScalarValue}; use datafusion_expr::Operator; use datafusion_expr::{col, lit, ColumnarValue, Expr, ScalarUDF, SimpleScalarUDF}; @@ -2647,7 +2177,7 @@ mod tests { SpatialRelationType::Intersects, ); let spatial_pred = SpatialPredicate::Relation(rel_pred); - assert!(super::is_spatial_predicate_supported(&spatial_pred, &schema, &schema).unwrap()); + assert!(is_spatial_predicate_supported(&spatial_pred, &schema, &schema).unwrap()); // Geography field (should NOT be supported) let geog_field = WKB_GEOGRAPHY.to_storage_field("geog", false).unwrap(); @@ -2659,12 +2189,10 @@ mod tests { SpatialRelationType::Intersects, ); let spatial_pred_geog = SpatialPredicate::Relation(rel_pred_geog); - assert!(!super::is_spatial_predicate_supported( - &spatial_pred_geog, - &geog_schema, - &geog_schema - ) - .unwrap()); + assert!( + !is_spatial_predicate_supported(&spatial_pred_geog, &geog_schema, &geog_schema) + .unwrap() + ); } #[test] @@ -2686,9 +2214,7 @@ mod tests { false, JoinSide::Left, )); - assert!( - super::is_spatial_predicate_supported(&knn_pred, &left_schema, &right_schema).unwrap() - ); + assert!(is_spatial_predicate_supported(&knn_pred, &left_schema, &right_schema).unwrap()); // ST_KNN(right, left) let knn_pred = SpatialPredicate::KNearestNeighbors(KNNPredicate::new( @@ -2698,31 +2224,23 @@ mod tests { false, JoinSide::Right, )); - assert!( - super::is_spatial_predicate_supported(&knn_pred, &left_schema, &right_schema).unwrap() - ); + assert!(is_spatial_predicate_supported(&knn_pred, &left_schema, &right_schema).unwrap()); // ST_KNN with geography (should NOT be supported) let left_geog_schema = Arc::new(Schema::new(vec![WKB_GEOGRAPHY .to_storage_field("geog", false) .unwrap()])); - assert!(!super::is_spatial_predicate_supported( - &knn_pred, - &left_geog_schema, - &right_schema - ) - .unwrap()); + assert!( + !is_spatial_predicate_supported(&knn_pred, &left_geog_schema, &right_schema).unwrap() + ); let right_geog_schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), WKB_GEOGRAPHY.to_storage_field("geog", false).unwrap(), ])); - assert!(!super::is_spatial_predicate_supported( - &knn_pred, - &left_schema, - &right_geog_schema - ) - .unwrap()); + assert!( + !is_spatial_predicate_supported(&knn_pred, &left_schema, &right_geog_schema).unwrap() + ); } #[test] @@ -2733,7 +2251,7 @@ mod tests { func: st_intersects_udf, args: vec![col("geom1"), col("geom2")], }); - 
assert!(super::is_spatial_predicate(&st_intersects_expr)); + assert!(is_spatial_predicate(&st_intersects_expr)); // ST_Distance(geom1, geom2) < 100 should return true let st_distance_udf = create_dummy_st_distance_udf(); @@ -2746,7 +2264,7 @@ mod tests { op: Operator::Lt, right: Box::new(lit(100.0)), }); - assert!(super::is_spatial_predicate(&distance_lt_expr)); + assert!(is_spatial_predicate(&distance_lt_expr)); // ST_Distance(geom1, geom2) > 100 should return false let distance_gt_expr = Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr { @@ -2754,7 +2272,7 @@ mod tests { op: Operator::Gt, right: Box::new(lit(100.0)), }); - assert!(!super::is_spatial_predicate(&distance_gt_expr)); + assert!(!is_spatial_predicate(&distance_gt_expr)); // AND expressions with spatial predicates should return true let and_expr = Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr { @@ -2762,13 +2280,13 @@ mod tests { op: Operator::And, right: Box::new(col("id").eq(lit(1))), }); - assert!(super::is_spatial_predicate(&and_expr)); + assert!(is_spatial_predicate(&and_expr)); // Non-spatial expressions should return false // Simple column comparison let non_spatial_expr = col("id").eq(lit(1)); - assert!(!super::is_spatial_predicate(&non_spatial_expr)); + assert!(!is_spatial_predicate(&non_spatial_expr)); // Not a spatial relationship function let non_st_func = Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { @@ -2781,7 +2299,7 @@ mod tests { ))), args: vec![col("id")], }); - assert!(!super::is_spatial_predicate(&non_st_func)); + assert!(!is_spatial_predicate(&non_st_func)); // AND expression with no spatial predicates let non_spatial_and = Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr { @@ -2789,6 +2307,6 @@ mod tests { op: Operator::And, right: Box::new(col("name").eq(lit("test"))), }); - assert!(!super::is_spatial_predicate(&non_spatial_and)); + assert!(!is_spatial_predicate(&non_spatial_and)); } } diff --git a/rust/sedona-spatial-join/src/spatial_predicate.rs b/rust/sedona-spatial-join/src/spatial_predicate.rs index a6cb24c34..63ccc4718 100644 --- a/rust/sedona-spatial-join/src/spatial_predicate.rs +++ b/rust/sedona-spatial-join/src/spatial_predicate.rs @@ -16,8 +16,11 @@ // under the License. use std::sync::Arc; -use datafusion_common::JoinSide; +use datafusion_common::{JoinSide, Result}; +use datafusion_physical_expr::projection::update_expr; use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_plan::projection::ProjectionExpr; +use sedona_common::sedona_internal_err; /// Spatial predicate is the join condition of a spatial join. It can be a distance predicate, /// a relation predicate, or a KNN predicate. @@ -298,7 +301,7 @@ impl KNNPredicate { /// * `right` - Expression for the right side (object) geometry /// * `k` - Number of nearest neighbors to find (literal value) /// * `use_spheroid` - Whether to use spheroid distance (literal value, currently must be false) - /// * `probe_side` - Which execution plan side the probe expression belongs to + /// * `probe_side` - Which execution plan side the probe expression belongs to, cannot be None pub fn new( left: Arc, right: Arc, @@ -306,6 +309,7 @@ impl KNNPredicate { use_spheroid: bool, probe_side: JoinSide, ) -> Self { + assert!(matches!(probe_side, JoinSide::Left | JoinSide::Right)); Self { left, right, @@ -325,3 +329,394 @@ impl std::fmt::Display for KNNPredicate { ) } } + +/// Common operations needed by the planner/executor to keep spatial predicates valid +/// when join inputs are swapped or projected. 
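// Editor's sketch (not part of the patch): intended behavior of the trait methods below for a
// relation predicate such as ST_Intersects(a@1, x@2), based on the implementations and tests in
// this file (column names are hypothetical):
//
//   swap_for_swapped_children:
//     exchanges left/right and inverts the relation type, so the predicate stays equivalent
//     after SpatialJoinExec swaps its children.
//
//   update_for_child_projections:
//     with left projection [a@1 AS a_new] and right projection [x@2 AS x_new], the predicate is
//     rewritten to reference (a_new@0, x_new@0); if either side cannot be expressed through the
//     projection, Ok(None) is returned and projection pushdown is skipped.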
+pub trait SpatialPredicateTrait: Sized { + /// Returns a semantically equivalent predicate after the join children are swapped. + /// + /// Used by `SpatialJoinExec::swap_inputs` to keep the predicate aligned with the new + /// left/right inputs. + fn swap_for_swapped_children(&self) -> Self; + + /// Rewrites the predicate to reference projected child expressions. + /// + /// Returns `Ok(None)` when the predicate cannot be expressed using the projected inputs + /// (so projection pushdown must be skipped). + fn update_for_child_projections( + &self, + projected_left_exprs: &[ProjectionExpr], + projected_right_exprs: &[ProjectionExpr], + ) -> Result>; +} + +impl SpatialPredicateTrait for SpatialPredicate { + fn swap_for_swapped_children(&self) -> Self { + match self { + SpatialPredicate::Relation(pred) => { + SpatialPredicate::Relation(pred.swap_for_swapped_children()) + } + SpatialPredicate::Distance(pred) => { + SpatialPredicate::Distance(pred.swap_for_swapped_children()) + } + SpatialPredicate::KNearestNeighbors(pred) => { + SpatialPredicate::KNearestNeighbors(pred.swap_for_swapped_children()) + } + } + } + + fn update_for_child_projections( + &self, + projected_left_exprs: &[ProjectionExpr], + projected_right_exprs: &[ProjectionExpr], + ) -> Result> { + match self { + SpatialPredicate::Relation(pred) => Ok(pred + .update_for_child_projections(projected_left_exprs, projected_right_exprs)? + .map(SpatialPredicate::Relation)), + SpatialPredicate::Distance(pred) => Ok(pred + .update_for_child_projections(projected_left_exprs, projected_right_exprs)? + .map(SpatialPredicate::Distance)), + SpatialPredicate::KNearestNeighbors(pred) => Ok(pred + .update_for_child_projections(projected_left_exprs, projected_right_exprs)? + .map(SpatialPredicate::KNearestNeighbors)), + } + } +} + +impl SpatialPredicateTrait for RelationPredicate { + fn swap_for_swapped_children(&self) -> Self { + Self { + left: Arc::clone(&self.right), + right: Arc::clone(&self.left), + relation_type: self.relation_type.invert(), + } + } + + fn update_for_child_projections( + &self, + projected_left_exprs: &[ProjectionExpr], + projected_right_exprs: &[ProjectionExpr], + ) -> Result> { + let Some(left) = update_expr(&self.left, projected_left_exprs, false)? else { + return Ok(None); + }; + let Some(right) = update_expr(&self.right, projected_right_exprs, false)? else { + return Ok(None); + }; + + Ok(Some(Self { + left, + right, + relation_type: self.relation_type, + })) + } +} + +impl SpatialPredicateTrait for DistancePredicate { + fn swap_for_swapped_children(&self) -> Self { + Self { + left: Arc::clone(&self.right), + right: Arc::clone(&self.left), + distance: Arc::clone(&self.distance), + distance_side: self.distance_side.negate(), + } + } + + fn update_for_child_projections( + &self, + projected_left_exprs: &[ProjectionExpr], + projected_right_exprs: &[ProjectionExpr], + ) -> Result> { + let Some(left) = update_expr(&self.left, projected_left_exprs, false)? else { + return Ok(None); + }; + let Some(right) = update_expr(&self.right, projected_right_exprs, false)? else { + return Ok(None); + }; + + let distance = match self.distance_side { + JoinSide::Left => { + let Some(distance) = update_expr(&self.distance, projected_left_exprs, false)? + else { + return Ok(None); + }; + distance + } + JoinSide::Right => { + let Some(distance) = update_expr(&self.distance, projected_right_exprs, false)? 
+ else { + return Ok(None); + }; + distance + } + JoinSide::None => Arc::clone(&self.distance), + }; + + Ok(Some(Self { + left, + right, + distance, + distance_side: self.distance_side, + })) + } +} + +impl SpatialPredicateTrait for KNNPredicate { + fn swap_for_swapped_children(&self) -> Self { + Self { + // Keep query/object expressions stable; only flip which child is considered probe. + left: Arc::clone(&self.left), + right: Arc::clone(&self.right), + k: self.k, + use_spheroid: self.use_spheroid, + probe_side: self.probe_side.negate(), + } + } + + fn update_for_child_projections( + &self, + projected_left_exprs: &[ProjectionExpr], + projected_right_exprs: &[ProjectionExpr], + ) -> Result> { + let (query_exprs, object_exprs) = match self.probe_side { + JoinSide::Left => (projected_left_exprs, projected_right_exprs), + JoinSide::Right => (projected_right_exprs, projected_left_exprs), + JoinSide::None => { + return sedona_internal_err!("KNN join requires explicit probe_side designation") + } + }; + + let Some(left) = update_expr(&self.left, query_exprs, false)? else { + return Ok(None); + }; + let Some(right) = update_expr(&self.right, object_exprs, false)? else { + return Ok(None); + }; + + Ok(Some(Self { + left, + right, + k: self.k, + use_spheroid: self.use_spheroid, + probe_side: self.probe_side, + })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use datafusion_common::ScalarValue; + use datafusion_physical_expr::expressions::{Column, Literal}; + + fn proj_col(name: &str, index: usize) -> Arc { + Arc::new(Column::new(name, index)) + } + + fn proj_expr(expr: Arc, alias: &str) -> ProjectionExpr { + ProjectionExpr { + expr, + alias: alias.to_string(), + } + } + + fn assert_is_column(expr: &Arc, name: &str, index: usize) { + let col = expr + .as_any() + .downcast_ref::() + .expect("expected Column"); + assert_eq!(col.name(), name); + assert_eq!(col.index(), index); + } + + #[test] + fn relation_rewrite_success() -> Result<()> { + let on = SpatialPredicate::Relation(RelationPredicate { + left: proj_col("a", 1), + right: proj_col("x", 2), + relation_type: SpatialRelationType::Intersects, + }); + + let projected_left_exprs = vec![proj_expr(proj_col("a", 1), "a_new")]; + let projected_right_exprs = vec![proj_expr(proj_col("x", 2), "x_new")]; + + let updated = on + .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)? + .unwrap(); + + let SpatialPredicate::Relation(updated) = updated else { + unreachable!("expected relation") + }; + assert_is_column(&updated.left, "a_new", 0); + assert_is_column(&updated.right, "x_new", 0); + Ok(()) + } + + #[test] + fn relation_rewrite_column_index_unchanged() -> Result<()> { + let on = SpatialPredicate::Relation(RelationPredicate { + left: proj_col("a", 0), + right: proj_col("x", 0), + relation_type: SpatialRelationType::Intersects, + }); + + let projected_left_exprs = vec![proj_expr(proj_col("a", 0), "a_new")]; + let projected_right_exprs = vec![proj_expr(proj_col("x", 0), "x_new")]; + + let updated = on + .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)? 
+ .unwrap(); + + let SpatialPredicate::Relation(updated) = updated else { + unreachable!("expected relation") + }; + assert_is_column(&updated.left, "a_new", 0); + assert_is_column(&updated.right, "x_new", 0); + Ok(()) + } + + #[test] + fn relation_rewrite_none_when_missing() -> Result<()> { + let on = SpatialPredicate::Relation(RelationPredicate { + left: proj_col("a", 1), + right: proj_col("x", 0), + relation_type: SpatialRelationType::Intersects, + }); + + let projected_left_exprs = vec![proj_expr(proj_col("a", 0), "a0")]; + let projected_right_exprs = vec![proj_expr(proj_col("x", 0), "x0")]; + + assert!(on + .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)? + .is_none()); + Ok(()) + } + + #[test] + fn distance_rewrite_distance_side_left() -> Result<()> { + let on = SpatialPredicate::Distance(DistancePredicate { + left: proj_col("geom", 0), + right: proj_col("geom", 0), + distance: proj_col("dist", 1), + distance_side: JoinSide::Left, + }); + + let projected_left_exprs = vec![ + proj_expr(proj_col("geom", 0), "geom_out"), + proj_expr(proj_col("dist", 1), "dist_out"), + ]; + let projected_right_exprs = vec![proj_expr(proj_col("geom", 0), "geom_r")]; + + let updated = on + .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)? + .unwrap(); + + let SpatialPredicate::Distance(updated) = updated else { + unreachable!("expected distance") + }; + assert_is_column(&updated.left, "geom_out", 0); + assert_is_column(&updated.right, "geom_r", 0); + assert_is_column(&updated.distance, "dist_out", 1); + assert_eq!(updated.distance_side, JoinSide::Left); + Ok(()) + } + + #[test] + fn distance_rewrite_distance_side_none_keeps_literal() -> Result<()> { + let distance_lit: Arc = + Arc::new(Literal::new(ScalarValue::Float64(Some(1.0)))); + + let on = SpatialPredicate::Distance(DistancePredicate { + left: proj_col("geom", 2), + right: proj_col("geom", 1), + distance: Arc::clone(&distance_lit), + distance_side: JoinSide::None, + }); + + let projected_left_exprs = vec![proj_expr(proj_col("geom", 2), "geom_out")]; + let projected_right_exprs = vec![proj_expr(proj_col("geom", 1), "geom_r")]; + + let updated = on + .update_for_child_projections(&projected_left_exprs, &projected_right_exprs)? + .unwrap(); + + let SpatialPredicate::Distance(updated) = updated else { + unreachable!("expected distance") + }; + assert_is_column(&updated.left, "geom_out", 0); + assert_is_column(&updated.right, "geom_r", 0); + assert!(Arc::ptr_eq(&updated.distance, &distance_lit)); + assert_eq!(updated.distance_side, JoinSide::None); + Ok(()) + } + + #[test] + fn knn_rewrite_success_probe_left_and_right() -> Result<()> { + let base = SpatialPredicate::KNearestNeighbors(KNNPredicate { + left: proj_col("probe", 1), + right: proj_col("build", 2), + k: 10, + use_spheroid: false, + probe_side: JoinSide::Left, + }); + + let left_exprs = vec![proj_expr(proj_col("probe", 1), "probe_out")]; + let right_exprs = vec![proj_expr(proj_col("build", 2), "build_out")]; + + let updated = base + .update_for_child_projections(&left_exprs, &right_exprs)? 
+ .unwrap(); + let SpatialPredicate::KNearestNeighbors(updated) = updated else { + unreachable!("expected knn") + }; + assert_is_column(&updated.left, "probe_out", 0); + assert_is_column(&updated.right, "build_out", 0); + assert_eq!(updated.probe_side, JoinSide::Left); + + let base = SpatialPredicate::KNearestNeighbors(KNNPredicate { + left: proj_col("probe", 1), + right: proj_col("build", 2), + k: 10, + use_spheroid: false, + probe_side: JoinSide::Right, + }); + + // For probe_side=Right: predicate.left (probe) is rewritten using right projections, + // and predicate.right (build) is rewritten using left projections. + let left_exprs = vec![proj_expr(proj_col("build", 2), "build_out_l")]; + let right_exprs = vec![proj_expr(proj_col("probe", 1), "probe_out_r")]; + let updated = base + .update_for_child_projections(&left_exprs, &right_exprs)? + .unwrap(); + let SpatialPredicate::KNearestNeighbors(updated) = updated else { + unreachable!("expected knn") + }; + assert_is_column(&updated.left, "probe_out_r", 0); + assert_is_column(&updated.right, "build_out_l", 0); + assert_eq!(updated.probe_side, JoinSide::Right); + + Ok(()) + } + + #[test] + fn knn_rewrite_errors_on_none_probe_side() { + let on = SpatialPredicate::KNearestNeighbors(KNNPredicate { + left: proj_col("probe", 0), + right: proj_col("build", 0), + k: 10, + use_spheroid: false, + probe_side: JoinSide::None, + }); + + let left_exprs = vec![proj_expr(proj_col("probe", 0), "probe_out")]; + let right_exprs = vec![proj_expr(proj_col("build", 0), "build_out")]; + + let err = on + .update_for_child_projections(&left_exprs, &right_exprs) + .expect_err("expected error"); + let msg = err.to_string(); + assert!(msg.contains("KNN join requires explicit probe_side designation")); + } +} diff --git a/rust/sedona-spatial-join/src/utils/join_utils.rs b/rust/sedona-spatial-join/src/utils/join_utils.rs index 87aaa9ae8..55682ca23 100644 --- a/rust/sedona-spatial-join/src/utils/join_utils.rs +++ b/rust/sedona-spatial-join/src/utils/join_utils.rs @@ -27,16 +27,24 @@ use arrow::buffer::NullBuffer; use arrow::compute::{self, take}; use arrow::datatypes::{ArrowNativeType, Schema, UInt32Type, UInt64Type}; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, UInt32Array, UInt64Array}; +use arrow_schema::SchemaRef; use datafusion_common::cast::as_boolean_array; use datafusion_common::{JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::Partitioning; -use datafusion_physical_plan::execution_plan::Boundedness; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::joins::utils::{ adjust_right_output_partitioning, ColumnIndex, JoinFilter, }; +use datafusion_physical_plan::projection::{ + join_allows_pushdown, join_table_borders, new_join_children, physical_to_column_exprs, + update_join_filter, ProjectionExec, +}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use crate::spatial_predicate::SpatialPredicateTrait; +use crate::SpatialPredicate; + /// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and /// use the bit map to generate the part of result of the join. /// @@ -101,6 +109,12 @@ pub(crate) fn apply_join_filter_to_indices( filter: &JoinFilter, build_side: JoinSide, ) -> Result<(UInt64Array, UInt32Array)> { + // Forked from DataFusion 50.2.0 `apply_join_filter_to_indices`. 
+ // https://github.com/apache/datafusion/blob/50.2.0/datafusion/physical-plan/src/joins/utils.rs + // + // Changes vs upstream: + // - Removes the `max_intermediate_size` parameter and its chunking logic. + // - Calls our forked `build_batch_from_indices(..., join_type)` (needed for mark-join semantics). if build_indices.is_empty() && probe_indices.is_empty() { return Ok((build_indices, probe_indices)); }; @@ -142,6 +156,14 @@ pub(crate) fn build_batch_from_indices( build_side: JoinSide, join_type: JoinType, ) -> Result { + // Forked from DataFusion 50.2.0 `build_batch_from_indices`. + // https://github.com/apache/datafusion/blob/50.2.0/datafusion/physical-plan/src/joins/utils.rs + // + // Changes vs upstream: + // - Adds the `join_type` parameter so we can special-case mark joins. + // - Fixes `RightMark` mark-column construction: for right-mark joins, the mark column must + // reflect match status for the *right* rows, so we build it from `build_indices` (the + // build-side indices) rather than `probe_indices`. if schema.fields().is_empty() { let options = RecordBatchOptions::new() .with_match_field_names(true) @@ -202,6 +224,13 @@ pub(crate) fn adjust_indices_by_join_type( join_type: JoinType, preserve_order_for_right: bool, ) -> Result<(UInt64Array, UInt32Array)> { + // Forked from DataFusion 50.2.0 `adjust_indices_by_join_type`. + // https://github.com/apache/datafusion/blob/50.2.0/datafusion/physical-plan/src/joins/utils.rs + // + // Changes vs upstream: + // - Fixes `RightMark` handling to match our `SpatialJoinStream` contract: + // `right_indices` becomes the probe row indices (`adjust_range`), and `left_indices` is a + // mark array (null/non-null) indicating match status. match join_type { JoinType::Inner => { // matched @@ -381,6 +410,12 @@ pub(crate) fn get_mark_indices( where NativeAdapter: From<::Native>, { + // Forked from DataFusion 50.2.0 `get_mark_indices`. + // https://github.com/apache/datafusion/blob/50.2.0/datafusion/physical-plan/src/joins/utils.rs + // + // Changes vs upstream: + // - Generalizes the output array element type (generic `R`) so we can build mark arrays of + // different physical types while still using the null buffer to encode match status. let mut bitmap = build_range_bitmap(range, input_indices); PrimitiveArray::new( vec![R::Native::default(); range.len()].into(), @@ -463,21 +498,49 @@ pub(crate) fn asymmetric_join_output_partitioning( left: &Arc, right: &Arc, join_type: &JoinType, + probe_side: JoinSide, ) -> Result { let result = match join_type { - JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( - right.output_partitioning(), - left.schema().fields().len(), - )?, + JoinType::Inner => { + if probe_side == JoinSide::Right { + adjust_right_output_partitioning( + right.output_partitioning(), + left.schema().fields().len(), + )? + } else { + left.output_partitioning().clone() + } + } + JoinType::Right => { + if probe_side == JoinSide::Right { + adjust_right_output_partitioning( + right.output_partitioning(), + left.schema().fields().len(), + )? 
+ } else { + Partitioning::UnknownPartitioning(left.output_partitioning().partition_count()) + } + } JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { - right.output_partitioning().clone() + if probe_side == JoinSide::Right { + right.output_partitioning().clone() + } else { + Partitioning::UnknownPartitioning(left.output_partitioning().partition_count()) + } } - JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::Full - | JoinType::LeftMark => { - Partitioning::UnknownPartitioning(right.output_partitioning().partition_count()) + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + if probe_side == JoinSide::Left { + left.output_partitioning().clone() + } else { + Partitioning::UnknownPartitioning(right.output_partitioning().partition_count()) + } + } + JoinType::Full => { + if probe_side == JoinSide::Right { + Partitioning::UnknownPartitioning(right.output_partitioning().partition_count()) + } else { + Partitioning::UnknownPartitioning(left.output_partitioning().partition_count()) + } } }; Ok(result) @@ -517,3 +580,826 @@ pub(crate) fn boundedness_from_children<'a>( Boundedness::Bounded } } + +pub(crate) fn compute_join_emission_type( + left: &Arc, + right: &Arc, + join_type: JoinType, + probe_side: JoinSide, +) -> EmissionType { + let (build, probe) = if probe_side == JoinSide::Left { + (right, left) + } else { + (left, right) + }; + + if build.boundedness().is_unbounded() { + return EmissionType::Final; + } + + if probe.pipeline_behavior() == EmissionType::Incremental { + match join_type { + // If we only need to generate matched rows from the probe side, + // we can emit rows incrementally. + JoinType::Inner => EmissionType::Incremental, + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + if probe_side == JoinSide::Right { + EmissionType::Incremental + } else { + EmissionType::Both + } + } + // If we need to generate unmatched rows from the *build side*, + // we need to emit them at the end. + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + if probe_side == JoinSide::Left { + EmissionType::Incremental + } else { + EmissionType::Both + } + } + JoinType::Full => EmissionType::Both, + } + } else { + probe.pipeline_behavior() + } +} + +/// Data required to push down a projection through a spatial join. +/// This is mostly taken from https://github.com/apache/datafusion/blob/51.0.0/datafusion/physical-plan/src/projection.rs +pub(crate) struct JoinPushdownData { + pub projected_left_child: ProjectionExec, + pub projected_right_child: ProjectionExec, + pub join_filter: Option, + pub join_on: SpatialPredicate, +} + +/// Push down the given `projection` through the spatial join. +/// This code is adapted from https://github.com/apache/datafusion/blob/51.0.0/datafusion/physical-plan/src/projection.rs +pub(crate) fn try_pushdown_through_join( + projection: &ProjectionExec, + join_left: &Arc, + join_right: &Arc, + join_schema: &SchemaRef, + join_type: JoinType, + join_filter: Option<&JoinFilter>, + join_on: &SpatialPredicate, +) -> Result> { + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + // Mark joins produce a synthetic column that does not belong to either child. This synthetic + // `mark` column will make `new_join_children` fail, so we skip pushdown for such joins. + // This limitation if inherited from DataFusion's builtin `try_pushdown_through_join`. 
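// Editor's sketch (not part of the patch): the rewrite this function performs, mirroring the
// `try_pushdown_through_join_updates_children_filter_and_predicate` test below (schemas are
// hypothetical):
//
//   join children:  left = [l0, l1], right = [r0, r1]
//   projection:     [l1@1 AS l1_out, r0@2 AS r0_out] over the joined schema [l0, l1, r0, r1]
//   result:         left child projected to [l1], right child projected to [r0]; the join
//                   filter and the spatial predicate are rewritten to reference column 0 of
//                   each projected child. If the projection is not a narrowing set of plain
//                   column references, or a referenced column is dropped, Ok(None) is returned.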
+ if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) { + return Ok(None); + } + + let (far_right_left_col_ind, far_left_right_col_ind) = + join_table_borders(join_left.schema().fields().len(), &projection_as_columns); + + if !join_allows_pushdown( + &projection_as_columns, + join_schema, + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let (projected_left_child, projected_right_child) = new_join_children( + &projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + join_left, + join_right, + )?; + + let new_filter = if let Some(filter) = join_filter { + let left_cols = &projection_as_columns[0..=far_right_left_col_ind as usize]; + let right_cols = &projection_as_columns[far_left_right_col_ind as usize..]; + match update_join_filter( + left_cols, + right_cols, + filter, + join_left.schema().fields().len(), + ) { + Some(updated) => Some(updated), + None => return Ok(None), + } + } else { + None + }; + + let projected_left_exprs = projected_left_child.expr(); + let projected_right_exprs = projected_right_child.expr(); + let Some(new_on) = + join_on.update_for_child_projections(projected_left_exprs, projected_right_exprs)? + else { + return Ok(None); + }; + + Ok(Some(JoinPushdownData { + projected_left_child, + projected_right_child, + join_filter: new_filter, + join_on: new_on, + })) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::DataType; + use arrow_schema::Field; + use arrow_schema::SchemaRef; + use datafusion_common::ScalarValue; + use datafusion_expr::JoinType; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; + use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_expr::Partitioning; + use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_plan::empty::EmptyExec; + use datafusion_physical_plan::projection::ProjectionExpr; + use datafusion_physical_plan::repartition::RepartitionExec; + use datafusion_physical_plan::DisplayAs; + use datafusion_physical_plan::DisplayFormatType; + use datafusion_physical_plan::PlanProperties; + + use crate::spatial_predicate::{RelationPredicate, SpatialRelationType}; + + fn make_schema(prefix: &str, num_fields: usize) -> SchemaRef { + Arc::new(Schema::new( + (0..num_fields) + .map(|i| Field::new(format!("{prefix}{i}"), DataType::Int32, true)) + .collect::>(), + )) + } + + fn assert_hash_partitioning_column_indices( + partitioning: &Partitioning, + expected_indices: &[usize], + expected_partition_count: usize, + ) { + match partitioning { + Partitioning::Hash(exprs, size) => { + assert_eq!(*size, expected_partition_count); + assert_eq!(exprs.len(), expected_indices.len()); + for (expr, expected_idx) in exprs.iter().zip(expected_indices.iter()) { + let col = expr + .as_any() + .downcast_ref::() + .expect("expected Column physical expr"); + assert_eq!(col.index(), *expected_idx); + } + } + other => panic!("expected Hash partitioning, got {other:?}"), + } + } + + fn make_join_schema(left: &SchemaRef, right: &SchemaRef) -> SchemaRef { + let mut fields = Vec::with_capacity(left.fields().len() + right.fields().len()); + fields.extend(left.fields().iter().cloned()); + fields.extend(right.fields().iter().cloned()); + Arc::new(Schema::new(fields)) + } + + fn make_join_projection( + join_schema: &SchemaRef, + indices: &[usize], + aliases: &[&str], + ) -> Result { + assert_eq!(indices.len(), aliases.len()); + let exprs = indices + .iter() + .zip(aliases.iter()) + .map(|(index, 
alias)| { + let field = join_schema.field(*index); + ProjectionExpr { + expr: Arc::new(Column::new(field.name(), *index)), + alias: (*alias).to_string(), + } + }) + .collect::>(); + ProjectionExec::try_new(exprs, Arc::new(EmptyExec::new(Arc::clone(join_schema)))) + } + + fn make_join_filter( + left_indices: Vec, + right_indices: Vec, + schema: SchemaRef, + ) -> JoinFilter { + let expression: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new(schema.field(0).name(), 0)), + Operator::Eq, + Arc::new(Column::new(schema.field(1).name(), 1)), + )); + JoinFilter::new( + expression, + JoinFilter::build_column_indices(left_indices, right_indices), + schema, + ) + } + + fn assert_is_column_expr(expr: &Arc, name: &str, index: usize) { + let col = expr + .as_any() + .downcast_ref::() + .expect("expected Column"); + assert_eq!(col.name(), name); + assert_eq!(col.index(), index); + } + + #[derive(Debug, Clone)] + struct PropertiesOnlyExec { + schema: SchemaRef, + properties: PlanProperties, + } + + impl PropertiesOnlyExec { + fn new(schema: SchemaRef, boundedness: Boundedness, emission_type: EmissionType) -> Self { + let schema_ref = Arc::clone(&schema); + let properties = PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(1), + emission_type, + boundedness, + ); + Self { + schema: schema_ref, + properties, + } + } + } + + impl DisplayAs for PropertiesOnlyExec { + fn fmt_as(&self, _t: DisplayFormatType, _f: &mut std::fmt::Formatter) -> std::fmt::Result { + Ok(()) + } + } + + impl ExecutionPlan for PropertiesOnlyExec { + fn name(&self) -> &'static str { + "PropertiesOnlyExec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!("PropertiesOnlyExec is for properties tests only") + } + + fn statistics(&self) -> Result { + Ok(datafusion_common::Statistics::new_unknown( + self.schema().as_ref(), + )) + } + + fn partition_statistics( + &self, + _partition: Option, + ) -> Result { + Ok(datafusion_common::Statistics::new_unknown( + self.schema().as_ref(), + )) + } + } + + #[test] + fn adjust_right_output_partitioning_offsets_hash_columns() -> Result<()> { + let right_part = Partitioning::Hash(vec![Arc::new(Column::new("r0", 0))], 8); + let adjusted = adjust_right_output_partitioning(&right_part, 3)?; + assert_hash_partitioning_column_indices(&adjusted, &[3], 8); + + let right_part_multi = Partitioning::Hash( + vec![ + Arc::new(Column::new("r0", 0)), + Arc::new(Column::new("r2", 2)), + ], + 16, + ); + let adjusted_multi = adjust_right_output_partitioning(&right_part_multi, 5)?; + assert_hash_partitioning_column_indices(&adjusted_multi, &[5, 7], 16); + Ok(()) + } + + #[test] + fn adjust_right_output_partitioning_passthrough_non_hash() -> Result<()> { + let right_part = Partitioning::UnknownPartitioning(4); + let adjusted = adjust_right_output_partitioning(&right_part, 10)?; + assert!(matches!(adjusted, Partitioning::UnknownPartitioning(4))); + Ok(()) + } + + #[test] + fn asymmetric_join_output_partitioning_all_combinations_hash_keys() -> Result<()> { + // Left is partitioned by l1, right is partitioned by r0. + // We validate output partitioning for all (probe_side, join_type) combinations. 
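// Editor's sketch (not part of the patch): for example, with (JoinType::Inner, probe = Right)
// the right child's Hash([r0@0], 5) partitioning is re-expressed against the joined
// [l0, l1, r0] schema as Hash([r0@2], 5), shifting the hash key index by the left schema
// width, while (JoinType::Inner, probe = Left) keeps the left child's Hash([l1@1], 3)
// unchanged; combinations that cannot preserve the probe side's partitioning fall back to
// UnknownPartitioning.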
+ let left_partitions = 3; + let right_partitions = 5; + + let left_schema = make_schema("l", 2); + let left_len = left_schema.fields().len(); + let left_input: Arc = Arc::new(EmptyExec::new(left_schema)); + let left: Arc = Arc::new(RepartitionExec::try_new( + left_input, + Partitioning::Hash(vec![Arc::new(Column::new("l1", 1))], left_partitions), + )?); + + let right_input: Arc = Arc::new(EmptyExec::new(make_schema("r", 1))); + let right: Arc = Arc::new(RepartitionExec::try_new( + right_input, + Partitioning::Hash(vec![Arc::new(Column::new("r0", 0))], right_partitions), + )?); + + let join_types = [ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::LeftMark, + JoinType::RightSemi, + JoinType::RightAnti, + JoinType::RightMark, + ]; + let probe_sides = [JoinSide::Left, JoinSide::Right]; + + for join_type in join_types { + for probe_side in probe_sides { + let out = + asymmetric_join_output_partitioning(&left, &right, &join_type, probe_side)?; + + match (join_type, probe_side) { + (JoinType::Inner, JoinSide::Right) => { + // join output schema is left + right, so offset right partition key + assert_hash_partitioning_column_indices( + &out, + &[left_len], + right_partitions, + ); + } + (JoinType::Inner, JoinSide::Left) => { + assert_hash_partitioning_column_indices(&out, &[1], left_partitions); + } + + (JoinType::Right, JoinSide::Right) => { + assert_hash_partitioning_column_indices( + &out, + &[left_len], + right_partitions, + ); + } + (JoinType::Right, JoinSide::Left) => { + assert!(matches!( + out, + Partitioning::UnknownPartitioning(n) if n == left_partitions + )); + } + + ( + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark, + JoinSide::Right, + ) => { + // right-only output schema (plus mark column for RightMark), so no offset + assert_hash_partitioning_column_indices(&out, &[0], right_partitions); + } + ( + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark, + JoinSide::Left, + ) => { + assert!(matches!( + out, + Partitioning::UnknownPartitioning(n) if n == left_partitions + )); + } + + ( + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark, + JoinSide::Left, + ) => { + assert_hash_partitioning_column_indices(&out, &[1], left_partitions); + } + ( + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark, + JoinSide::Right, + ) => { + assert!(matches!( + out, + Partitioning::UnknownPartitioning(n) if n == right_partitions + )); + } + + (JoinType::Full, JoinSide::Left) => { + assert!(matches!( + out, + Partitioning::UnknownPartitioning(n) if n == left_partitions + )); + } + (JoinType::Full, JoinSide::Right) => { + assert!(matches!( + out, + Partitioning::UnknownPartitioning(n) if n == right_partitions + )); + } + + _ => unreachable!("unexpected probe_side: {probe_side:?}"), + } + } + } + + Ok(()) + } + + #[test] + fn compute_join_emission_type_prefers_final_for_unbounded_build() { + let schema = make_schema("x", 1); + let build: Arc = Arc::new(PropertiesOnlyExec::new( + Arc::clone(&schema), + datafusion_physical_plan::execution_plan::Boundedness::Unbounded { + requires_infinite_memory: false, + }, + EmissionType::Incremental, + )); + let probe: Arc = Arc::new(PropertiesOnlyExec::new( + schema, + datafusion_physical_plan::execution_plan::Boundedness::Bounded, + EmissionType::Incremental, + )); + + assert_eq!( + compute_join_emission_type(&build, &probe, JoinType::Inner, JoinSide::Right), + EmissionType::Final + ); + 
assert_eq!( + compute_join_emission_type(&probe, &build, JoinType::Inner, JoinSide::Left), + EmissionType::Final + ); + } + + #[test] + fn compute_join_emission_type_uses_probe_behavior_for_inner_join() { + let schema = make_schema("x", 1); + let build: Arc = Arc::new(PropertiesOnlyExec::new( + Arc::clone(&schema), + datafusion_physical_plan::execution_plan::Boundedness::Bounded, + EmissionType::Incremental, + )); + for probe_emission_type in [EmissionType::Incremental, EmissionType::Both] { + let probe: Arc = Arc::new(PropertiesOnlyExec::new( + Arc::clone(&schema), + datafusion_physical_plan::execution_plan::Boundedness::Bounded, + probe_emission_type, + )); + + assert_eq!( + compute_join_emission_type(&build, &probe, JoinType::Inner, JoinSide::Right), + probe_emission_type + ); + assert_eq!( + compute_join_emission_type(&probe, &build, JoinType::Inner, JoinSide::Left), + probe_emission_type + ); + } + } + + #[test] + fn compute_join_emission_type_incremental_when_join_type_and_probe_side_matches() { + let schema = make_schema("x", 1); + let left: Arc = Arc::new(PropertiesOnlyExec::new( + Arc::clone(&schema), + datafusion_physical_plan::execution_plan::Boundedness::Bounded, + EmissionType::Incremental, + )); + let right: Arc = Arc::new(PropertiesOnlyExec::new( + schema, + datafusion_physical_plan::execution_plan::Boundedness::Bounded, + EmissionType::Incremental, + )); + + for join_type in [ + JoinType::Right, + JoinType::RightSemi, + JoinType::RightAnti, + JoinType::RightMark, + ] { + assert_eq!( + compute_join_emission_type(&left, &right, join_type, JoinSide::Right), + EmissionType::Incremental + ); + assert_eq!( + compute_join_emission_type(&left, &right, join_type, JoinSide::Left), + EmissionType::Both + ); + } + + for join_type in [ + JoinType::Left, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::LeftMark, + ] { + assert_eq!( + compute_join_emission_type(&left, &right, join_type, JoinSide::Left), + EmissionType::Incremental + ); + assert_eq!( + compute_join_emission_type(&left, &right, join_type, JoinSide::Right), + EmissionType::Both + ); + } + } + + #[test] + fn compute_join_emission_type_always_both_for_full_outer_join() { + let schema = make_schema("x", 1); + let left: Arc = Arc::new(PropertiesOnlyExec::new( + Arc::clone(&schema), + datafusion_physical_plan::execution_plan::Boundedness::Bounded, + EmissionType::Incremental, + )); + let right: Arc = Arc::new(PropertiesOnlyExec::new( + schema, + datafusion_physical_plan::execution_plan::Boundedness::Bounded, + EmissionType::Incremental, + )); + + assert_eq!( + compute_join_emission_type(&left, &right, JoinType::Full, JoinSide::Left), + EmissionType::Both + ); + assert_eq!( + compute_join_emission_type(&left, &right, JoinType::Full, JoinSide::Right), + EmissionType::Both + ); + } + + #[test] + fn try_pushdown_through_join_updates_children_filter_and_predicate() -> Result<()> { + let left_schema = make_schema("l", 2); + let right_schema = make_schema("r", 2); + let join_schema = make_join_schema(&left_schema, &right_schema); + let join_left: Arc = Arc::new(EmptyExec::new(Arc::clone(&left_schema))); + let join_right: Arc = + Arc::new(EmptyExec::new(Arc::clone(&right_schema))); + + let projection = make_join_projection(&join_schema, &[1, 2], &["l1_out", "r0_out"])?; + + let join_on = SpatialPredicate::Relation(RelationPredicate::new( + Arc::new(Column::new("l1", 1)), + Arc::new(Column::new("r0", 0)), + SpatialRelationType::Intersects, + )); + + let filter_schema = Arc::new(Schema::new(vec![ + Field::new("l1", DataType::Int32, 
+            Field::new("r0", DataType::Int32, true),
+        ]));
+        let join_filter = make_join_filter(vec![1], vec![0], filter_schema);
+
+        let pushdown = try_pushdown_through_join(
+            &projection,
+            &join_left,
+            &join_right,
+            &join_schema,
+            JoinType::Inner,
+            Some(&join_filter),
+            &join_on,
+        )?
+        .expect("expected pushdown");
+
+        assert_eq!(pushdown.projected_left_child.expr().len(), 1);
+        let left_proj = &pushdown.projected_left_child.expr()[0];
+        assert_eq!(left_proj.alias, "l1_out");
+        let left_col = left_proj
+            .expr
+            .as_any()
+            .downcast_ref::<Column>()
+            .expect("expected Column");
+        assert_eq!(left_col.name(), "l1");
+        assert_eq!(left_col.index(), 1);
+
+        assert_eq!(pushdown.projected_right_child.expr().len(), 1);
+        let right_proj = &pushdown.projected_right_child.expr()[0];
+        assert_eq!(right_proj.alias, "r0_out");
+        let right_col = right_proj
+            .expr
+            .as_any()
+            .downcast_ref::<Column>()
+            .expect("expected Column");
+        assert_eq!(right_col.name(), "r0");
+        assert_eq!(right_col.index(), 0);
+
+        let updated_filter = pushdown.join_filter.expect("expected updated filter");
+        let indices = updated_filter.column_indices();
+        assert_eq!(indices.len(), 2);
+        assert_eq!(indices[0].side, JoinSide::Left);
+        assert_eq!(indices[0].index, 0);
+        assert_eq!(indices[1].side, JoinSide::Right);
+        assert_eq!(indices[1].index, 0);
+
+        let SpatialPredicate::Relation(updated_on) = pushdown.join_on else {
+            unreachable!("expected relation predicate")
+        };
+        assert_is_column_expr(&updated_on.left, "l1_out", 0);
+        assert_is_column_expr(&updated_on.right, "r0_out", 0);
+
+        Ok(())
+    }
+
+    #[test]
+    fn try_pushdown_through_join_skips_mark_join() -> Result<()> {
+        let left_schema = make_schema("l", 1);
+        let right_schema = make_schema("r", 1);
+        let join_schema = make_join_schema(&left_schema, &right_schema);
+        let join_left: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&left_schema)));
+        let join_right: Arc<dyn ExecutionPlan> =
+            Arc::new(EmptyExec::new(Arc::clone(&right_schema)));
+        let projection = make_join_projection(&join_schema, &[0, 1], &["l0", "r0"])?;
+
+        let join_on = SpatialPredicate::Relation(RelationPredicate::new(
+            Arc::new(Column::new("l0", 0)),
+            Arc::new(Column::new("r0", 0)),
+            SpatialRelationType::Intersects,
+        ));
+
+        let result = try_pushdown_through_join(
+            &projection,
+            &join_left,
+            &join_right,
+            &join_schema,
+            JoinType::LeftMark,
+            None,
+            &join_on,
+        )?;
+        assert!(result.is_none());
+        Ok(())
+    }
+
+    #[test]
+    fn try_pushdown_through_join_requires_column_projection() -> Result<()> {
+        let left_schema = make_schema("l", 1);
+        let right_schema = make_schema("r", 1);
+        let join_schema = make_join_schema(&left_schema, &right_schema);
+        let join_left: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&left_schema)));
+        let join_right: Arc<dyn ExecutionPlan> =
+            Arc::new(EmptyExec::new(Arc::clone(&right_schema)));
+
+        let projection = ProjectionExec::try_new(
+            vec![ProjectionExpr {
+                expr: Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+                alias: "lit".to_string(),
+            }],
+            Arc::new(EmptyExec::new(Arc::clone(&join_schema))),
+        )?;
+
+        let join_on = SpatialPredicate::Relation(RelationPredicate::new(
+            Arc::new(Column::new("l0", 0)),
+            Arc::new(Column::new("r0", 0)),
+            SpatialRelationType::Intersects,
+        ));
+
+        let result = try_pushdown_through_join(
+            &projection,
+            &join_left,
+            &join_right,
+            &join_schema,
+            JoinType::Inner,
+            None,
+            &join_on,
+        )?;
+        assert!(result.is_none());
+        Ok(())
+    }
+
+    #[test]
+    fn try_pushdown_through_join_requires_projection_narrowing() -> Result<()> {
+        let left_schema = make_schema("l", 2);
+        let right_schema = make_schema("r", 2);
+        let join_schema = make_join_schema(&left_schema, &right_schema);
+        let join_left: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&left_schema)));
+        let join_right: Arc<dyn ExecutionPlan> =
+            Arc::new(EmptyExec::new(Arc::clone(&right_schema)));
+
+        let projection =
+            make_join_projection(&join_schema, &[0, 1, 2, 3], &["l0", "l1", "r0", "r1"])?;
+
+        let join_on = SpatialPredicate::Relation(RelationPredicate::new(
+            Arc::new(Column::new("l0", 0)),
+            Arc::new(Column::new("r0", 0)),
+            SpatialRelationType::Intersects,
+        ));
+
+        let result = try_pushdown_through_join(
+            &projection,
+            &join_left,
+            &join_right,
+            &join_schema,
+            JoinType::Inner,
+            None,
+            &join_on,
+        )?;
+        assert!(result.is_none());
+        Ok(())
+    }
+
+    #[test]
+    fn try_pushdown_through_join_fails_when_filter_columns_missing() -> Result<()> {
+        let left_schema = make_schema("l", 2);
+        let right_schema = make_schema("r", 2);
+        let join_schema = make_join_schema(&left_schema, &right_schema);
+        let join_left: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&left_schema)));
+        let join_right: Arc<dyn ExecutionPlan> =
+            Arc::new(EmptyExec::new(Arc::clone(&right_schema)));
+
+        let projection = make_join_projection(&join_schema, &[1, 3], &["l1_out", "r1_out"])?;
+
+        let join_on = SpatialPredicate::Relation(RelationPredicate::new(
+            Arc::new(Column::new("l1", 1)),
+            Arc::new(Column::new("r1", 1)),
+            SpatialRelationType::Intersects,
+        ));
+
+        let filter_schema = Arc::new(Schema::new(vec![
+            Field::new("l1", DataType::Int32, true),
+            Field::new("r0", DataType::Int32, true),
+        ]));
+        let join_filter = make_join_filter(vec![1], vec![0], filter_schema);
+
+        let result = try_pushdown_through_join(
+            &projection,
+            &join_left,
+            &join_right,
+            &join_schema,
+            JoinType::Inner,
+            Some(&join_filter),
+            &join_on,
+        )?;
+        assert!(result.is_none());
+        Ok(())
+    }
+
+    #[test]
+    fn try_pushdown_through_join_fails_when_predicate_columns_missing() -> Result<()> {
+        let left_schema = make_schema("l", 2);
+        let right_schema = make_schema("r", 2);
+        let join_schema = make_join_schema(&left_schema, &right_schema);
+        let join_left: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&left_schema)));
+        let join_right: Arc<dyn ExecutionPlan> =
+            Arc::new(EmptyExec::new(Arc::clone(&right_schema)));
+
+        let projection = make_join_projection(&join_schema, &[1, 3], &["l1_out", "r1_out"])?;
+
+        let join_on = SpatialPredicate::Relation(RelationPredicate::new(
+            Arc::new(Column::new("l1", 1)),
+            Arc::new(Column::new("r0", 0)),
+            SpatialRelationType::Intersects,
+        ));
+
+        let result = try_pushdown_through_join(
+            &projection,
+            &join_left,
+            &join_right,
+            &join_schema,
+            JoinType::Inner,
+            None,
+            &join_on,
+        )?;
+        assert!(result.is_none());
+        Ok(())
+    }
+}
diff --git a/rust/sedona/src/context.rs b/rust/sedona/src/context.rs
index 6efd3ba4f..3757ea82e 100644
--- a/rust/sedona/src/context.rs
+++ b/rust/sedona/src/context.rs
@@ -95,7 +95,7 @@ impl SedonaContext {
         // Register the spatial join planner extension
         #[cfg(feature = "spatial-join")]
         {
-            state_builder = sedona_spatial_join::register_spatial_join_optimizer(state_builder);
+            state_builder = sedona_spatial_join::register_planner(state_builder);
         }

         let mut state = state_builder.build();