Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions datafusion/physical-plan/src/joins/sort_merge_join/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,15 +524,20 @@ impl ExecutionPlan for SortMergeJoinExec {
}

fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
if partition.is_some() {
return Ok(Statistics::new_unknown(&self.schema()));
}
// SortMergeJoinExec uses symmetric hash partitioning where both left and right
// inputs are hash-partitioned on the join keys. This means partition `i` of the
// left input is joined with partition `i` of the right input.
//
// Therefore, partition-specific statistics can be computed by getting the
// partition-specific statistics from both children and combining them via
// `estimate_join_statistics`.
//
// TODO stats: it is not possible in general to know the output size of joins
// There are some special cases though, for example:
// - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
estimate_join_statistics(
self.left.partition_statistics(None)?,
self.right.partition_statistics(None)?,
self.left.partition_statistics(partition)?,
self.right.partition_statistics(partition)?,
&self.on,
&self.join_type,
&self.schema,
Expand Down
74 changes: 74 additions & 0 deletions datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3030,6 +3030,80 @@ async fn test_anti_join_filtered_mask() -> Result<()> {
Ok(())
}

/// Checks `partition_statistics` on the join built by `join(...)` for every
/// join type, both for the whole plan (`None`) and for one partition (`Some(0)`).
///
/// For each join type the test verifies the number of column statistics in the
/// output, that the aggregate row count is not `Precision::Absent`, and that
/// the per-partition row count is `Precision::Absent`.
#[test]
fn test_partition_statistics() -> Result<()> {
    use crate::ExecutionPlan;
    use datafusion_common::stats::Precision;

    // Two three-column inputs that share the join key column name `b1`.
    let left = build_table(
        ("a1", &vec![1, 2, 3]),
        ("b1", &vec![4, 5, 5]),
        ("c1", &vec![7, 8, 9]),
    );
    let right = build_table(
        ("a2", &vec![10, 20, 30]),
        ("b1", &vec![4, 5, 6]),
        ("c2", &vec![70, 80, 90]),
    );

    // Single equi-join condition: left.b1 = right.b1.
    let left_key = Arc::new(Column::new_with_schema("b1", &left.schema())?) as _;
    let right_key = Arc::new(Column::new_with_schema("b1", &right.schema())?) as _;
    let on = vec![(left_key, right_key)];

    // (join type, expected number of output column statistics).
    // Both inputs have three columns, so joins that emit both sides expect six
    // and semi/anti joins that emit a single side expect three.
    let cases = [
        (Inner, 6),
        (Left, 6),
        (Right, 6),
        (Full, 6),
        (LeftSemi, 3),
        (LeftAnti, 3),
        (RightSemi, 3),
        (RightAnti, 3),
    ];

    for (join_type, expected_cols) in cases {
        let join_exec =
            join(Arc::clone(&left), Arc::clone(&right), on.clone(), join_type)?;

        // Whole-plan statistics: one entry per output column and a row count
        // that is not `Absent`.
        let aggregate = join_exec.partition_statistics(None)?;
        assert_eq!(
            aggregate.column_statistics.len(),
            expected_cols,
            "Aggregate stats column count failed for {join_type:?}"
        );
        assert!(
            !matches!(aggregate.num_rows, Precision::Absent),
            "Aggregate stats should have meaningful num_rows for {join_type:?}, got {:?}",
            aggregate.num_rows
        );

        // Statistics for partition 0 only. The column count must still match,
        // but the row count is expected to be `Absent`.
        // NOTE(review): this relies on the test inputs reporting unknown
        // statistics when asked about a specific partition — confirm against
        // the input exec used by `build_table`.
        let per_partition = join_exec.partition_statistics(Some(0))?;
        assert_eq!(
            per_partition.column_statistics.len(),
            expected_cols,
            "Partition stats column count failed for {join_type:?}"
        );
        assert!(
            matches!(per_partition.num_rows, Precision::Absent),
            "Partition stats should have Absent num_rows when children return unknown for {join_type:?}, got {:?}",
            per_partition.num_rows
        );
    }

    Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down