diff --git a/python/sedonadb/tests/test_sjoin.py b/python/sedonadb/tests/test_sjoin.py index beb412ced..9169fbfb1 100644 --- a/python/sedonadb/tests/test_sjoin.py +++ b/python/sedonadb/tests/test_sjoin.py @@ -79,6 +79,173 @@ def test_spatial_join(join_type, on): eng_postgis.assert_query_result(sql, sedonadb_results) +@pytest.mark.parametrize( + "join_type", + [ + "LEFT SEMI JOIN", + "LEFT ANTI JOIN", + "RIGHT SEMI JOIN", + "RIGHT ANTI JOIN", + ], +) +@pytest.mark.parametrize( + "on", + [ + "ST_Intersects(sjoin_point.geometry, sjoin_polygon.geometry)", + "ST_Within(sjoin_point.geometry, sjoin_polygon.geometry)", + "ST_Contains(sjoin_polygon.geometry, sjoin_point.geometry)", + "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry, 1.0)", + "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry, sjoin_point.dist / 100)", + "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry, sjoin_polygon.dist / 100)", + ], +) +def test_spatial_join_semi_anti(join_type, on): + with ( + SedonaDB.create_or_skip() as eng_sedonadb, + PostGIS.create_or_skip() as eng_postgis, + ): + options = json.dumps( + { + "geom_type": "Point", + "polygon_hole_rate": 0.5, + "num_parts_range": [2, 10], + "vertices_per_linestring_range": [2, 10], + "seed": 42, + } + ) + df_point = eng_sedonadb.execute_and_collect( + f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100" + ) + options = json.dumps( + { + "geom_type": "Polygon", + "polygon_hole_rate": 0.5, + "num_parts_range": [2, 10], + "vertices_per_linestring_range": [2, 10], + "seed": 43, + } + ) + df_polygon = eng_sedonadb.execute_and_collect( + f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100" + ) + eng_sedonadb.create_table_arrow("sjoin_point", df_point) + eng_sedonadb.create_table_arrow("sjoin_polygon", df_polygon) + eng_postgis.create_table_arrow("sjoin_point", df_point) + eng_postgis.create_table_arrow("sjoin_polygon", df_polygon) + + is_left = join_type.startswith("LEFT") + is_semi = "SEMI" in join_type + + if is_left: + sedona_sql = f""" + SELECT sjoin_point.id id0 + FROM sjoin_point {join_type} sjoin_polygon + ON {on} + ORDER BY id0 + """ + exists = f"EXISTS (SELECT 1 FROM sjoin_polygon WHERE {on})" + where = exists if is_semi else f"NOT {exists}" + postgis_sql = f""" + SELECT sjoin_point.id id0 + FROM sjoin_point + WHERE {where} + ORDER BY id0 + """ + else: + sedona_sql = f""" + SELECT sjoin_polygon.id id1 + FROM sjoin_point {join_type} sjoin_polygon + ON {on} + ORDER BY id1 + """ + exists = f"EXISTS (SELECT 1 FROM sjoin_point WHERE {on})" + where = exists if is_semi else f"NOT {exists}" + postgis_sql = f""" + SELECT sjoin_polygon.id id1 + FROM sjoin_polygon + WHERE {where} + ORDER BY id1 + """ + + sedonadb_results = eng_sedonadb.execute_and_collect(sedona_sql).to_pandas() + assert len(sedonadb_results) > 0 + eng_postgis.assert_query_result(postgis_sql, sedonadb_results) + + +@pytest.mark.parametrize( + "outer", + ["point", "polygon"], +) +@pytest.mark.parametrize( + "on", + [ + "ST_Intersects(sjoin_point.geometry, sjoin_polygon.geometry)", + "ST_Within(sjoin_point.geometry, sjoin_polygon.geometry)", + "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry, 1.0)", + ], +) +def test_spatial_mark_join_via_correlated_exists(outer, on): + with ( + SedonaDB.create_or_skip() as eng_sedonadb, + PostGIS.create_or_skip() as eng_postgis, + ): + options = json.dumps( + { + "geom_type": "Point", + "polygon_hole_rate": 0.5, + "num_parts_range": [2, 10], + "vertices_per_linestring_range": [2, 10], + "seed": 42, + } + ) + df_point = eng_sedonadb.execute_and_collect( + f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100" + ) + options = json.dumps( + { + "geom_type": "Polygon", + "polygon_hole_rate": 0.5, + "num_parts_range": [2, 10], + "vertices_per_linestring_range": [2, 10], + "seed": 43, + } + ) + df_polygon = eng_sedonadb.execute_and_collect( + f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100" + ) + eng_sedonadb.create_table_arrow("sjoin_point", df_point) + eng_sedonadb.create_table_arrow("sjoin_polygon", df_polygon) + eng_postgis.create_table_arrow("sjoin_point", df_point) + eng_postgis.create_table_arrow("sjoin_polygon", df_polygon) + + if outer == "point": + sql = f""" + SELECT sjoin_point.id id0 + FROM sjoin_point + WHERE sjoin_point.id = 1 OR EXISTS (SELECT 1 FROM sjoin_polygon WHERE {on}) + ORDER BY id0 + """ + else: + sql = f""" + SELECT sjoin_polygon.id id1, ST_AsBinary(sjoin_polygon.geometry) geom + FROM sjoin_polygon + WHERE sjoin_polygon.id = 1 OR EXISTS (SELECT 1 FROM sjoin_point WHERE {on}) + ORDER BY id1 + """ + + # Verify the physical query plan contains a Mark join + query_plan = eng_sedonadb.execute_and_collect(f"EXPLAIN {sql}").to_pandas() + plan_text = "\n".join(query_plan.iloc[:, 1].astype(str).tolist()) + assert any( + "SpatialJoinExec" in line and ("LeftMark" in line or "RightMark" in line) + for line in plan_text.splitlines() + ), plan_text + + sedonadb_results = eng_sedonadb.execute_and_collect(sql).to_pandas() + assert len(sedonadb_results) > 0 + eng_postgis.assert_query_result(sql, sedonadb_results) + + @pytest.mark.parametrize( "join_type", ["INNER JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN"] ) diff --git a/rust/sedona-spatial-join/src/exec.rs b/rust/sedona-spatial-join/src/exec.rs index e3440cd1f..5cdea16de 100644 --- a/rust/sedona-spatial-join/src/exec.rs +++ b/rust/sedona-spatial-join/src/exec.rs @@ -272,7 +272,7 @@ impl SpatialJoinExec { match knn.probe_side { JoinSide::Left => left.output_partitioning().clone(), JoinSide::Right => right.output_partitioning().clone(), - _ => asymmetric_join_output_partitioning(left, right, &join_type), + _ => asymmetric_join_output_partitioning(left, right, &join_type)?, } } else if converted_from_hash_join { // Replicate HashJoin's symmetric partitioning logic @@ -290,10 +290,10 @@ impl SpatialJoinExec { // For full outer join, we can't preserve partitioning Partitioning::UnknownPartitioning(left.output_partitioning().partition_count()) } - _ => asymmetric_join_output_partitioning(left, right, &join_type), + _ => asymmetric_join_output_partitioning(left, right, &join_type)?, } } else { - asymmetric_join_output_partitioning(left, right, &join_type) + asymmetric_join_output_partitioning(left, right, &join_type)? }; if let Some(projection) = projection { @@ -615,6 +615,7 @@ mod tests { }; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_expr::ColumnarValue; + use datafusion_physical_plan::joins::NestedLoopJoinExec; use geo::{Distance, Euclidean}; use geo_types::{Coord, Rect}; use rstest::rstest; @@ -996,7 +997,7 @@ mod tests { #[rstest] #[tokio::test] async fn test_left_joins( - #[values(JoinType::Left, /* JoinType::LeftSemi, JoinType::LeftAnti */)] join_type: JoinType, + #[values(JoinType::Left, JoinType::LeftSemi, JoinType::LeftAnti)] join_type: JoinType, ) -> Result<()> { test_with_join_types(join_type).await?; Ok(()) @@ -1005,8 +1006,7 @@ mod tests { #[rstest] #[tokio::test] async fn test_right_joins( - #[values(JoinType::Right, /* JoinType::RightSemi, JoinType::RightAnti */)] - join_type: JoinType, + #[values(JoinType::Right, JoinType::RightSemi, JoinType::RightAnti)] join_type: JoinType, ) -> Result<()> { test_with_join_types(join_type).await?; Ok(()) @@ -1018,6 +1018,82 @@ mod tests { Ok(()) } + #[rstest] + #[tokio::test] + async fn test_mark_joins( + #[values(JoinType::LeftMark, JoinType::RightMark)] join_type: JoinType, + ) -> Result<()> { + let options = SpatialJoinOptions::default(); + test_mark_join(join_type, options, 10).await?; + Ok(()) + } + + #[tokio::test] + async fn test_mark_join_via_correlated_exists_sql() -> Result<()> { + let ((left_schema, left_partitions), (right_schema, right_partitions)) = + create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?; + + let mem_table_left: Arc = Arc::new(MemTable::try_new( + left_schema.clone(), + left_partitions.clone(), + )?); + let mem_table_right: Arc = Arc::new(MemTable::try_new( + right_schema.clone(), + right_partitions.clone(), + )?); + + // DataFusion doesn't have explicit SQL syntax for MARK joins. Predicate subqueries embedded + // in a more complex boolean expression (e.g. OR) are planned using a MARK join. + // + // Using EXISTS here (rather than IN) keeps the join filter as the pulled-up correlated + // predicate (ST_Intersects), which is what SpatialJoinExec can optimize. + let sql = "SELECT L.id FROM L WHERE L.id = 1 OR EXISTS (SELECT 1 FROM R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY L.id"; + + let batch_size = 10; + let options = SpatialJoinOptions::default(); + + // Optimized plan should include a SpatialJoinExec with Mark join type. + let ctx = setup_context(Some(options), batch_size)?; + ctx.register_table("L", Arc::clone(&mem_table_left))?; + ctx.register_table("R", Arc::clone(&mem_table_right))?; + let df = ctx.sql(sql).await?; + let plan = df.clone().create_physical_plan().await?; + let spatial_join_execs = collect_spatial_join_exec(&plan)?; + assert!( + spatial_join_execs + .iter() + .any(|exec| matches!(*exec.join_type(), JoinType::LeftMark | JoinType::RightMark)), + "expected correlated IN-subquery to plan using a MARK join when optimized" + ); + let actual_schema = df.schema().as_arrow().clone(); + let actual_batches = df.collect().await?; + let actual_batch = + arrow::compute::concat_batches(&Arc::new(actual_schema), &actual_batches)?; + + // Unoptimized plan should still contain a Mark join, but implemented as NestedLoopJoinExec. + let ctx_no_opt = setup_context(None, batch_size)?; + ctx_no_opt.register_table("L", mem_table_left)?; + ctx_no_opt.register_table("R", mem_table_right)?; + let df_no_opt = ctx_no_opt.sql(sql).await?; + let plan_no_opt = df_no_opt.clone().create_physical_plan().await?; + let nlj_execs = collect_nested_loop_join_exec(&plan_no_opt)?; + assert!( + nlj_execs + .iter() + .any(|exec| matches!(*exec.join_type(), JoinType::LeftMark | JoinType::RightMark)), + "expected correlated IN-subquery to plan using a MARK join when not optimized" + ); + let expected_schema = df_no_opt.schema().as_arrow().clone(); + let expected_batches = df_no_opt.collect().await?; + let expected_batch = + arrow::compute::concat_batches(&Arc::new(expected_schema), &expected_batches)?; + + assert!(expected_batch.num_rows() > 0); + assert_eq!(expected_batch, actual_batch); + + Ok(()) + } + #[tokio::test] async fn test_geography_join_is_not_optimized() -> Result<()> { let options = SpatialJoinOptions::default(); @@ -1075,10 +1151,10 @@ mod tests { JoinType::Left => "SELECT L.id l_id, R.id r_id FROM L LEFT JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id", JoinType::Right => "SELECT L.id l_id, R.id r_id FROM L RIGHT JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id", JoinType::Full => "SELECT L.id l_id, R.id r_id FROM L FULL OUTER JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id", - JoinType::LeftSemi => "SELECT L.id l_id FROM L WHERE EXISTS (SELECT 1 FROM R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY l_id", - JoinType::RightSemi => "SELECT R.id r_id FROM R WHERE EXISTS (SELECT 1 FROM L WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY r_id", - JoinType::LeftAnti => "SELECT L.id l_id FROM L WHERE NOT EXISTS (SELECT 1 FROM R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY l_id", - JoinType::RightAnti => "SELECT R.id r_id FROM R WHERE NOT EXISTS (SELECT 1 FROM L WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY r_id", + JoinType::LeftSemi => "SELECT L.id l_id FROM L LEFT SEMI JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id", + JoinType::RightSemi => "SELECT R.id r_id FROM L RIGHT SEMI JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY r_id", + JoinType::LeftAnti => "SELECT L.id l_id FROM L LEFT ANTI JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id", + JoinType::RightAnti => "SELECT R.id r_id FROM L RIGHT ANTI JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY r_id", JoinType::LeftMark => { unreachable!("LeftMark is not directly supported in SQL, will be tested in other tests"); } @@ -1203,6 +1279,93 @@ mod tests { Ok(spatial_join_execs) } + fn collect_nested_loop_join_exec( + plan: &Arc, + ) -> Result> { + let mut execs = Vec::new(); + plan.apply(|node| { + if let Some(exec) = node.as_any().downcast_ref::() { + execs.push(exec); + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(execs) + } + + async fn test_mark_join( + join_type: JoinType, + options: SpatialJoinOptions, + batch_size: usize, + ) -> Result<()> { + let ((left_schema, left_partitions), (right_schema, right_partitions)) = + create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?; + let mem_table_left: Arc = Arc::new(MemTable::try_new( + left_schema.clone(), + left_partitions.clone(), + )?); + let mem_table_right: Arc = Arc::new(MemTable::try_new( + right_schema.clone(), + right_partitions.clone(), + )?); + + // We use a Left Join as a template to create the plan, then modify it to Mark Join + let sql = "SELECT * FROM L LEFT JOIN R ON ST_Intersects(L.geometry, R.geometry)"; + + // Create SpatialJoinExec plan + let ctx = setup_context(Some(options), batch_size)?; + ctx.register_table("L", mem_table_left.clone())?; + ctx.register_table("R", mem_table_right.clone())?; + let df = ctx.sql(sql).await?; + let plan = df.create_physical_plan().await?; + let spatial_join_execs = collect_spatial_join_exec(&plan)?; + assert_eq!(spatial_join_execs.len(), 1); + let original_exec = spatial_join_execs[0]; + let mark_exec = SpatialJoinExec::try_new( + original_exec.left.clone(), + original_exec.right.clone(), + original_exec.on.clone(), + original_exec.filter.clone(), + &join_type, + None, + )?; + + // Create NestedLoopJoinExec plan for comparison + let ctx_no_opt = setup_context(None, batch_size)?; + ctx_no_opt.register_table("L", mem_table_left)?; + ctx_no_opt.register_table("R", mem_table_right)?; + let df_no_opt = ctx_no_opt.sql(sql).await?; + let plan_no_opt = df_no_opt.create_physical_plan().await?; + let nlj_execs = collect_nested_loop_join_exec(&plan_no_opt)?; + assert_eq!(nlj_execs.len(), 1); + let original_nlj = nlj_execs[0]; + let mark_nlj = NestedLoopJoinExec::try_new( + original_nlj.children()[0].clone(), + original_nlj.children()[1].clone(), + original_nlj.filter().cloned(), + &join_type, + None, + )?; + + async fn run_and_sort( + plan: Arc, + ctx: &SessionContext, + ) -> Result { + let results = datafusion_physical_plan::collect(plan, ctx.task_ctx()).await?; + let batch = arrow::compute::concat_batches(&results[0].schema(), &results)?; + let sort_col = batch.column(0); + let indices = arrow::compute::sort_to_indices(sort_col, None, None)?; + let sorted_batch = arrow::compute::take_record_batch(&batch, &indices)?; + Ok(sorted_batch) + } + + // Run both Mark Join plans and compare results + let mark_batch = run_and_sort(Arc::new(mark_exec), &ctx).await?; + let mark_nlj_batch = run_and_sort(Arc::new(mark_nlj), &ctx_no_opt).await?; + assert_eq!(mark_batch, mark_nlj_batch); + + Ok(()) + } + fn extract_geoms_and_ids(partitions: &[Vec]) -> Vec<(i32, geo::Geometry)> { let mut result = Vec::new(); for partition in partitions { diff --git a/rust/sedona-spatial-join/src/stream.rs b/rust/sedona-spatial-join/src/stream.rs index 7fa422312..37a84523d 100644 --- a/rust/sedona-spatial-join/src/stream.rs +++ b/rust/sedona-spatial-join/src/stream.rs @@ -698,6 +698,7 @@ impl SpatialJoinBatchIterator { &probe_indices, column_indices, build_side, + join_type, )?; // Update metrics with actual output @@ -896,6 +897,7 @@ impl UnmatchedBuildBatchIterator { &right_side, column_indices, build_side, + join_type, )? }; diff --git a/rust/sedona-spatial-join/src/utils/join_utils.rs b/rust/sedona-spatial-join/src/utils/join_utils.rs index 83ec18f49..87aaa9ae8 100644 --- a/rust/sedona-spatial-join/src/utils/join_utils.rs +++ b/rust/sedona-spatial-join/src/utils/join_utils.rs @@ -16,14 +16,15 @@ // under the License. /// Most of the code in this module are copied from the `datafusion_physical_plan::joins::utils` module. -/// https://github.com/apache/datafusion/blob/48.0.0/datafusion/physical-plan/src/joins/utils.rs +/// https://github.com/apache/datafusion/blob/50.2.0/datafusion/physical-plan/src/joins/utils.rs use std::{ops::Range, sync::Arc}; use arrow::array::{ downcast_array, new_null_array, Array, BooleanBufferBuilder, RecordBatch, RecordBatchOptions, UInt32Builder, UInt64Builder, }; -use arrow::compute; +use arrow::buffer::NullBuffer; +use arrow::compute::{self, take}; use arrow::datatypes::{ArrowNativeType, Schema, UInt32Type, UInt64Type}; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, UInt32Array, UInt64Array}; use datafusion_common::cast::as_boolean_array; @@ -112,6 +113,7 @@ pub(crate) fn apply_join_filter_to_indices( &probe_indices, filter.column_indices(), build_side, + JoinType::Inner, )?; let filter_result = filter .expression() @@ -129,6 +131,7 @@ pub(crate) fn apply_join_filter_to_indices( /// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. /// The resulting batch has [Schema] `schema`. +#[allow(clippy::too_many_arguments)] pub(crate) fn build_batch_from_indices( schema: &Schema, build_input_buffer: &RecordBatch, @@ -137,6 +140,7 @@ pub(crate) fn build_batch_from_indices( probe_indices: &UInt32Array, column_indices: &[ColumnIndex], build_side: JoinSide, + join_type: JoinType, ) -> Result { if schema.fields().is_empty() { let options = RecordBatchOptions::new() @@ -157,8 +161,12 @@ pub(crate) fn build_batch_from_indices( for column_index in column_indices { let array = if column_index.side == JoinSide::None { - // LeftMark join, the mark column is a true if the indices is not null, otherwise it will be false - Arc::new(compute::is_not_null(probe_indices)?) + // For mark joins, the mark column is a true if the indices is not null, otherwise it will be false + if join_type == JoinType::RightMark { + Arc::new(compute::is_not_null(build_indices)?) + } else { + Arc::new(compute::is_not_null(probe_indices)?) + } } else if column_index.side == build_side { let array = build_input_buffer.column(column_index.index); if array.is_empty() || build_indices.null_count() == build_indices.len() { @@ -168,7 +176,7 @@ pub(crate) fn build_batch_from_indices( assert_eq!(build_indices.null_count(), build_indices.len()); new_null_array(array.data_type(), build_indices.len()) } else { - compute::take(array.as_ref(), build_indices, None)? + take(array.as_ref(), build_indices, None)? } } else { let array = probe_batch.column(column_index.index); @@ -176,9 +184,10 @@ pub(crate) fn build_batch_from_indices( assert_eq!(probe_indices.null_count(), probe_indices.len()); new_null_array(array.data_type(), probe_indices.len()) } else { - compute::take(array.as_ref(), probe_indices, None)? + take(array.as_ref(), probe_indices, None)? } }; + columns.push(array); } Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?) @@ -226,7 +235,12 @@ pub(crate) fn adjust_indices_by_join_type( // the left_indices will not be used later for the `right anti` join Ok((left_indices, right_indices)) } - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark | JoinType::RightMark => { + JoinType::RightMark => { + let new_left_indices = get_mark_indices(&adjust_range, &right_indices); + let new_right_indices = adjust_range.map(|i| i as u32).collect(); + Ok((new_left_indices, new_right_indices)) + } + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop Ok(( @@ -328,17 +342,7 @@ pub(crate) fn get_anti_indices( where NativeAdapter: From<::Native>, { - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v.as_usize()) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - + let bitmap = build_range_bitmap(&range, input_indices); let offset = range.start; // get the anti index @@ -355,25 +359,52 @@ pub(crate) fn get_semi_indices( where NativeAdapter: From<::Native>, { - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v.as_usize()) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - + let bitmap = build_range_bitmap(&range, input_indices); let offset = range.start; - // get the semi index (range) .filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))) .collect() } +/// Returns an array for mark joins consisting of default values (zeros) with null/non-null markers. +/// +/// For each index in `range`: +/// - If the index appears in `input_indices`, the value is non-null (0) +/// - If the index does not appear in `input_indices`, the value is null +/// +/// This is used in mark joins to indicate which rows had matches. +pub(crate) fn get_mark_indices( + range: &Range, + input_indices: &PrimitiveArray, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = build_range_bitmap(range, input_indices); + PrimitiveArray::new( + vec![R::Native::default(); range.len()].into(), + Some(NullBuffer::new(bitmap.finish())), + ) +} + +fn build_range_bitmap( + range: &Range, + input: &PrimitiveArray, +) -> BooleanBufferBuilder { + let mut builder = BooleanBufferBuilder::new(range.len()); + builder.append_n(range.len(), false); + + input.iter().flatten().for_each(|v| { + let idx = v.as_usize(); + if range.contains(&idx) { + builder.set_bit(idx - range.start, true); + } + }); + + builder +} + /// Appends probe indices in order by considering the given build indices. /// /// This function constructs new build and probe indices by iterating through @@ -432,23 +463,24 @@ pub(crate) fn asymmetric_join_output_partitioning( left: &Arc, right: &Arc, join_type: &JoinType, -) -> Partitioning { - match join_type { +) -> Result { + let result = match join_type { JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( right.output_partitioning(), left.schema().fields().len(), - ) - .unwrap_or_else(|_| Partitioning::UnknownPartitioning(1)), - JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), + )?, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right.output_partitioning().clone() + } JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Full - | JoinType::LeftMark - | JoinType::RightMark => { + | JoinType::LeftMark => { Partitioning::UnknownPartitioning(right.output_partitioning().partition_count()) } - } + }; + Ok(result) } /// This function is copied from