diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs
index 5560c29d546b3..ae7a5fa764bcc 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs
@@ -524,15 +524,20 @@ impl ExecutionPlan for SortMergeJoinExec {
     }
 
     fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
-        if partition.is_some() {
-            return Ok(Statistics::new_unknown(&self.schema()));
-        }
+        // SortMergeJoinExec uses symmetric hash partitioning where both left and right
+        // inputs are hash-partitioned on the join keys. This means partition `i` of the
+        // left input is joined with partition `i` of the right input.
+        //
+        // Therefore, partition-specific statistics can be computed by getting the
+        // partition-specific statistics from both children and combining them via
+        // `estimate_join_statistics`.
+        //
         // TODO stats: it is not possible in general to know the output size of joins
         // There are some special cases though, for example:
         // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
         estimate_join_statistics(
-            self.left.partition_statistics(None)?,
-            self.right.partition_statistics(None)?,
+            self.left.partition_statistics(partition)?,
+            self.right.partition_statistics(partition)?,
             &self.on,
             &self.join_type,
             &self.schema,
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
index 171b6e5d682ad..d0bcc79636f75 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
@@ -3030,6 +3030,80 @@ async fn test_anti_join_filtered_mask() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn test_partition_statistics() -> Result<()> {
+    use crate::ExecutionPlan;
+    use datafusion_common::stats::Precision;
+
+    let left = build_table(
+        ("a1", &vec![1, 2, 3]),
+        ("b1", &vec![4, 5, 5]),
+        ("c1", &vec![7, 8, 9]),
+    );
+    let right = build_table(
+        ("a2", &vec![10, 20, 30]),
+        ("b1", &vec![4, 5, 6]),
+        ("c2", &vec![70, 80, 90]),
+    );
+
+    let on = vec![(
+        Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+        Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+    )];
+
+    // Test different join types to ensure partition_statistics works correctly for all
+    let join_types = vec![
+        (Inner, 6), // left cols + right cols
+        (Left, 6), // left cols + right cols
+        (Right, 6), // left cols + right cols
+        (Full, 6), // left cols + right cols
+        (LeftSemi, 3), // only left cols
+        (LeftAnti, 3), // only left cols
+        (RightSemi, 3), // only right cols
+        (RightAnti, 3), // only right cols
+    ];
+
+    for (join_type, expected_cols) in join_types {
+        let join_exec =
+            join(Arc::clone(&left), Arc::clone(&right), on.clone(), join_type)?;
+
+        // Test aggregate statistics (partition = None)
+        // Should return meaningful statistics computed from both inputs
+        let stats = join_exec.partition_statistics(None)?;
+        assert_eq!(
+            stats.column_statistics.len(),
+            expected_cols,
+            "Aggregate stats column count failed for {join_type:?}"
+        );
+        // Verify that aggregate statistics have a meaningful num_rows (not Absent)
+        assert!(
+            !matches!(stats.num_rows, Precision::Absent),
+            "Aggregate stats should have meaningful num_rows for {join_type:?}, got {:?}",
+            stats.num_rows
+        );
+
+        // Test partition-specific statistics (partition = Some(0))
+        // The implementation correctly passes `partition` to children.
+        // Since the child TestMemoryExec returns unknown stats for specific partitions,
+        // the join output will also have Absent num_rows. This is expected behavior
+        // as the statistics depend on what the children can provide.
+        let partition_stats = join_exec.partition_statistics(Some(0))?;
+        assert_eq!(
+            partition_stats.column_statistics.len(),
+            expected_cols,
+            "Partition stats column count failed for {join_type:?}"
+        );
+        // When children return unknown stats, the join's partition stats will be Absent
+        assert!(
+            matches!(partition_stats.num_rows, Precision::Absent),
+            "Partition stats should have Absent num_rows when children return unknown for {join_type:?}, got {:?}",
+            partition_stats.num_rows
+        );
+    }
+
+    Ok(())
+}
+
 /// Returns the column names on the schema
 fn columns(schema: &Schema) -> Vec<String> {
     schema.fields().iter().map(|f| f.name().clone()).collect()