Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 183 additions & 20 deletions rust/sedona-spatial-join/src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ impl SpatialJoinExec {
let cache = Self::compute_properties(
&left,
&right,
&on,
Arc::clone(&join_schema),
*join_type,
projection.as_ref(),
Expand Down Expand Up @@ -236,9 +237,11 @@ impl SpatialJoinExec {
///
/// When converted from HashJoin, we preserve HashJoin's equivalence properties by extracting
/// equality conditions from the filter.
#[allow(clippy::too_many_arguments)]
fn compute_properties(
left: &Arc<dyn ExecutionPlan>,
right: &Arc<dyn ExecutionPlan>,
on: &SpatialPredicate,
schema: SchemaRef,
join_type: JoinType,
projection: Option<&Vec<usize>>,
Expand All @@ -265,7 +268,13 @@ impl SpatialJoinExec {

// Use symmetric partitioning (like HashJoin) when converted from HashJoin
// Otherwise use asymmetric partitioning (like NestedLoopJoin)
let mut output_partitioning = if converted_from_hash_join {
let mut output_partitioning = if let SpatialPredicate::KNearestNeighbors(knn) = on {
match knn.probe_side {
JoinSide::Left => left.output_partitioning().clone(),
JoinSide::Right => right.output_partitioning().clone(),
_ => asymmetric_join_output_partitioning(left, right, &join_type),
}
} else if converted_from_hash_join {
// Replicate HashJoin's symmetric partitioning logic
// HashJoin preserves partitioning from both sides for inner joins
// and from one side for outer joins
Expand Down Expand Up @@ -467,7 +476,6 @@ impl ExecutionPlan for SpatialJoinExec {
})?
};

// Column indices for regular joins - no swapping needed
let column_indices_after_projection = match &self.projection {
Some(projection) => projection
.iter()
Expand Down Expand Up @@ -559,30 +567,14 @@ impl SpatialJoinExec {
})?
};

// Handle column indices for KNN - need to swap if we swapped execution plans
let mut column_indices_after_projection = match &self.projection {
let column_indices_after_projection = match &self.projection {
Some(projection) => projection
.iter()
.map(|i| self.column_indices[*i].clone())
.collect(),
None => self.column_indices.clone(),
};

// If we swapped execution plans for KNN, we need to swap the column indices too
if !actual_probe_plan_is_left {
for col_idx in &mut column_indices_after_projection {
match col_idx.side {
datafusion_common::JoinSide::Left => {
col_idx.side = datafusion_common::JoinSide::Right
}
datafusion_common::JoinSide::Right => {
col_idx.side = datafusion_common::JoinSide::Left
}
datafusion_common::JoinSide::None => {} // No change needed
}
}
}

let join_metrics = SpatialJoinProbeMetrics::new(partition, &self.metrics);
let probe_stream = probe_plan.execute(partition, Arc::clone(&context))?;

Expand Down Expand Up @@ -614,16 +606,19 @@ impl SpatialJoinExec {

#[cfg(test)]
mod tests {
use arrow_array::RecordBatch;
use arrow_array::{Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::{
catalog::{MemTable, TableProvider},
execution::SessionStateBuilder,
prelude::{SessionConfig, SessionContext},
};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_expr::ColumnarValue;
use geo::{Distance, Euclidean};
use geo_types::{Coord, Rect};
use rstest::rstest;
use sedona_geo::to_geo::GeoTypesExecutor;
use sedona_geometry::types::GeometryTypeId;
use sedona_schema::datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY};
use sedona_testing::datagen::RandomPartitionedDataBuilder;
Expand Down Expand Up @@ -691,6 +686,40 @@ mod tests {
Ok((left_data, right_data))
}

/// Creates test data for KNN join (Point-Point)
fn create_knn_test_data(
size_range: (f64, f64),
sedona_type: SedonaType,
) -> Result<(TestPartitions, TestPartitions)> {
let bounds = Rect::new(Coord { x: 0.0, y: 0.0 }, Coord { x: 100.0, y: 100.0 });

let left_data = RandomPartitionedDataBuilder::new()
.seed(1)
.num_partitions(2)
.batches_per_partition(2)
.rows_per_batch(30)
.geometry_type(GeometryTypeId::Point)
.sedona_type(sedona_type.clone())
.bounds(bounds)
.size_range(size_range)
.null_rate(0.1)
.build()?;

let right_data = RandomPartitionedDataBuilder::new()
.seed(2)
.num_partitions(4)
.batches_per_partition(4)
.rows_per_batch(30)
.geometry_type(GeometryTypeId::Point)
.sedona_type(sedona_type)
.bounds(bounds)
.size_range(size_range)
.null_rate(0.1)
.build()?;

Ok((left_data, right_data))
}

fn setup_context(
options: Option<SpatialJoinOptions>,
batch_size: usize,
Expand Down Expand Up @@ -1173,4 +1202,138 @@ mod tests {
})?;
Ok(spatial_join_execs)
}

fn extract_geoms_and_ids(partitions: &[Vec<RecordBatch>]) -> Vec<(i32, geo::Geometry<f64>)> {
let mut result = Vec::new();
for partition in partitions {
for batch in partition {
let id_idx = batch.schema().index_of("id").expect("Id column not found");
let ids = batch
.column(id_idx)
.as_any()
.downcast_ref::<arrow_array::Int32Array>()
.expect("Column 'id' should be Int32");

let geom_idx = batch
.schema()
.index_of("geometry")
.expect("Geometry column not found");

let geoms_col = batch.column(geom_idx);
let geom_type = SedonaType::from_storage_field(batch.schema().field(geom_idx))
.expect("Failed to get SedonaType from geometry field");
let arg_types = [geom_type];
let arg_values = [ColumnarValue::Array(Arc::clone(geoms_col))];

let executor = GeoTypesExecutor::new(&arg_types, &arg_values);
let mut id_iter = ids.iter();
executor
.execute_wkb_void(|maybe_geom| {
if let Some(id_opt) = id_iter.next() {
if let (Some(id), Some(geom)) = (id_opt, maybe_geom) {
result.push((id, geom))
}
}
Ok(())
})
.expect("Failed to extract geoms and ids from RecordBatch");
}
}
result
}

fn compute_knn_ground_truth(
left_partitions: &[Vec<RecordBatch>],
right_partitions: &[Vec<RecordBatch>],
k: usize,
) -> Vec<(i32, i32, f64)> {
let left_data = extract_geoms_and_ids(left_partitions);
let right_data = extract_geoms_and_ids(right_partitions);

let mut results = Vec::new();

for (l_id, l_geom) in left_data {
let mut distances: Vec<(i32, f64)> = right_data
.iter()
.map(|(r_id, r_geom)| (*r_id, Euclidean.distance(&l_geom, r_geom)))
.collect();

// Sort by distance, then by ID for stability
distances.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));

for (r_id, dist) in distances.iter().take(k.min(distances.len())) {
results.push((l_id, *r_id, *dist));
}
}

// Sort results by L.id, R.id
results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
results
}

#[tokio::test]
async fn test_knn_join_correctness() -> Result<()> {
// Generate slightly larger data
let ((left_schema, left_partitions), (right_schema, right_partitions)) =
create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)?;

let options = SpatialJoinOptions::default();
let k = 3;

let sql1 = format!(
"SELECT L.id, R.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) ORDER BY L.id, R.id",
k
);
let expected1 = compute_knn_ground_truth(&left_partitions, &right_partitions, k)
.into_iter()
.map(|(l, r, _)| (l, r))
.collect::<Vec<_>>();

let sql2 = format!(
"SELECT R.id, L.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(R.geometry, L.geometry, {}, false) ORDER BY R.id, L.id",
k
);
let expected2 = compute_knn_ground_truth(&right_partitions, &left_partitions, k)
.into_iter()
.map(|(l, r, _)| (l, r))
.collect::<Vec<_>>();

let sqls = [(&sql1, &expected1), (&sql2, &expected2)];

for (sql, expected_results) in sqls {
let batches = run_spatial_join_query(
&left_schema,
&right_schema,
left_partitions.clone(),
right_partitions.clone(),
Some(options.clone()),
10,
sql,
)
.await?;

// Collect actual results
let mut actual_results = Vec::new();
let combined_batch = arrow::compute::concat_batches(&batches.schema(), &[batches])?;
let l_ids = combined_batch
.column(0)
.as_any()
.downcast_ref::<arrow_array::Int32Array>()
.unwrap();
let r_ids = combined_batch
.column(1)
.as_any()
.downcast_ref::<arrow_array::Int32Array>()
.unwrap();

for i in 0..combined_batch.num_rows() {
actual_results.push((l_ids.value(i), r_ids.value(i)));
}
actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));

assert_eq!(actual_results, *expected_results);
}

Ok(())
}
}
76 changes: 54 additions & 22 deletions rust/sedona-spatial-join/src/index/spatial_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@ use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion_common::Result;
use datafusion_execution::memory_pool::{MemoryPool, MemoryReservation};
use geo_index::rtree::distance::{DistanceMetric, GeometryAccessor};
use float_next_after::NextAfter;
use geo::BoundingRect;
use geo_index::rtree::{
distance::{DistanceMetric, GeometryAccessor},
util::f64_box_to_f32,
};
use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder, RTreeIndex};
use geo_index::IndexableNum;
use geo_types::{Point, Rect};
use geo_types::Rect;
use parking_lot::Mutex;
use sedona_expr::statistics::GeoStatistics;
use sedona_geo::to_geo::item_to_geometry;
use sedona_geo_generic_alg::algorithm::Centroid;
use wkb::reader::Wkb;

use crate::{
Expand Down Expand Up @@ -318,25 +322,28 @@ impl SpatialIndex {

// For tie-breakers, create spatial envelope around probe centroid and use rtree.search()

let probe_centroid = probe_geom.centroid().unwrap_or(Point::new(0.0, 0.0));
let probe_x = probe_centroid.x() as f32;
let probe_y = probe_centroid.y() as f32;
let max_distance_f32 = match f32::from_f64(max_distance) {
Some(val) => val,
None => {
// If conversion fails, return empty results for this probe
return Ok(QueryResultMetrics {
count: 0,
candidate_count: 0,
});
}
// Create envelope bounds by expanding the probe bounding box by max_distance
let Some(rect) = probe_geom.bounding_rect() else {
// If bounding rectangle cannot be computed, return empty results
return Ok(QueryResultMetrics {
count: 0,
candidate_count: 0,
});
};

// Create envelope bounds around probe centroid
let min_x = probe_x - max_distance_f32;
let min_y = probe_y - max_distance_f32;
let max_x = probe_x + max_distance_f32;
let max_y = probe_y + max_distance_f32;
let min = rect.min();
let max = rect.max();
let (min_x, min_y, max_x, max_y) = f64_box_to_f32(min.x, min.y, max.x, max.y);
let mut distance_f32 = max_distance as f32;
if (distance_f32 as f64) < max_distance {
distance_f32 = distance_f32.next_after(f32::INFINITY);
}
let (min_x, min_y, max_x, max_y) = (
min_x - distance_f32,
min_y - distance_f32,
max_x + distance_f32,
max_y + distance_f32,
);

// Use rtree.search() with envelope bounds (like the old code)
let expanded_results = self.rtree.search(min_x, min_y, max_x, max_y);
Expand Down Expand Up @@ -1407,8 +1414,33 @@ mod tests {
)
.unwrap();

// Should return more than 2 results because of ties (all 4 points at distance sqrt(2))
assert!(result_with_ties.count >= 2);
// Should return 4 results because of ties (all 4 points at distance sqrt(2))
assert!(result_with_ties.count == 4);

// Query using a box centered at the origin
let query_geom = create_array(
&[Some(
"POLYGON ((-0.5 -0.5, -0.5 0.5, 0.5 0.5, 0.5 -0.5, -0.5 -0.5))",
)],
&WKB_GEOMETRY,
);
let query_array = EvaluatedGeometryArray::try_new(query_geom, &WKB_GEOMETRY).unwrap();
let query_wkb = &query_array.wkbs()[0].as_ref().unwrap();

// This query should return 4 points
let mut build_positions_with_ties = Vec::new();
let result_with_ties = index
.query_knn(
query_wkb,
2, // k=2
false, // use_spheroid
true, // include_tie_breakers=true
&mut build_positions_with_ties,
)
.unwrap();

// Should return 4 results because of ties (all 4 points at distance sqrt(2))
assert!(result_with_ties.count == 4);
}

#[test]
Expand Down
6 changes: 3 additions & 3 deletions rust/sedona-spatial-join/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,10 @@ impl SpatialJoinStream {
// Extract the necessary data first to avoid borrowing conflicts
let (batch_opt, is_complete) = match &mut self.state {
SpatialJoinStreamState::ProcessProbeBatch(iterator) => {
// For KNN joins, we swapped build/probe sides, so build_side should be Right
// For regular joins, build_side is Left
// For KNN joins, we may have swapped build/probe sides, so build_side might be Right;
// For regular joins, build_side is always Left.
let build_side = match &self.spatial_predicate {
SpatialPredicate::KNearestNeighbors(_) => JoinSide::Right,
SpatialPredicate::KNearestNeighbors(knn) => knn.probe_side.negate(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When debugging my partition out of range issue with sd_random_geometry(), copilot sent me to this line and asked me to debug print the output of build_side (I should have listened!)

_ => JoinSide::Left,
};

Expand Down