diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs
index c7f0b5a4f4881..f97276e3c3761 100644
--- a/datafusion/common/src/column.rs
+++ b/datafusion/common/src/column.rs
@@ -30,7 +30,7 @@ use std::fmt;
 pub struct Column {
     /// relation/table reference.
     pub relation: Option<TableReference>,
-    /// field/column name.
+    /// Field/column name.
     pub name: String,
     /// Original source code location, if known
     pub spans: Spans,
diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs
index d6357fdf6bc7d..2f49ac19d01ab 100644
--- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs
+++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs
@@ -64,15 +64,13 @@ use datafusion_physical_plan::{
     sorts::sort::SortExec,
 };
 
+use super::pushdown_utils::{
+    OptimizationTest, TestNode, TestScanBuilder, TestSource, format_plan_for_test,
+};
 use datafusion_physical_plan::union::UnionExec;
 use futures::StreamExt;
 use object_store::{ObjectStore, memory::InMemory};
 use regex::Regex;
-use util::{OptimizationTest, TestNode, TestScanBuilder, format_plan_for_test};
-
-use crate::physical_optimizer::filter_pushdown::util::TestSource;
-
-mod util;
 
 #[test]
 fn test_pushdown_into_scan() {
diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs
index d11322cd26be9..cf179cb727cf1 100644
--- a/datafusion/core/tests/physical_optimizer/mod.rs
+++ b/datafusion/core/tests/physical_optimizer/mod.rs
@@ -24,7 +24,6 @@ mod combine_partial_final_agg;
 mod enforce_distribution;
 mod enforce_sorting;
 mod enforce_sorting_monotonicity;
-#[expect(clippy::needless_pass_by_value)]
 mod filter_pushdown;
 mod join_selection;
 #[expect(clippy::needless_pass_by_value)]
@@ -38,3 +37,5 @@ mod sanity_checker;
 #[expect(clippy::needless_pass_by_value)]
 mod test_utils;
 mod window_optimize;
+
+mod pushdown_utils;
diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs
index 480f5c8cc97b1..5cd0c356ee391 100644
--- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs
@@ -18,12 +18,15 @@
 use std::any::Any;
 use std::sync::Arc;
 
+use arrow::array::{Int32Array, RecordBatch, StructArray};
 use arrow::compute::SortOptions;
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow_schema::Fields;
 use datafusion::datasource::listing::PartitionedFile;
 use datafusion::datasource::memory::MemorySourceConfig;
 use datafusion::datasource::physical_plan::CsvSource;
 use datafusion::datasource::source::DataSourceExec;
+use datafusion::prelude::get_field;
 use datafusion_common::config::{ConfigOptions, CsvOptions};
 use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue};
 use datafusion_datasource::TableSchema;
@@ -31,12 +34,13 @@ use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
 use datafusion_execution::object_store::ObjectStoreUrl;
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
 use datafusion_expr::{
-    Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
+    Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, lit,
 };
 use datafusion_expr_common::columnar_value::ColumnarValue;
 use datafusion_physical_expr::expressions::{
     BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, binary, cast, col,
 };
+use datafusion_physical_expr::planner::logical2physical;
 use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr};
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_expr_common::sort_expr::{
@@ -65,6 +69,8 @@ use datafusion_physical_plan::{ExecutionPlan, displayable};
 use insta::assert_snapshot;
 use itertools::Itertools;
 
+use crate::physical_optimizer::pushdown_utils::TestScanBuilder;
+
 /// Mocked UDF
 #[derive(Debug, PartialEq, Eq, Hash)]
 struct DummyUDF {
@@ -1778,3 +1784,87 @@ fn test_cooperative_exec_after_projection() -> Result<()> {
 
     Ok(())
 }
+
+#[test]
+fn test_pushdown_projection_through_repartition_filter() {
+    let struct_fields = Fields::from(vec![Field::new("a", DataType::Int32, false)]);
+    let array = StructArray::new(
+        struct_fields.clone(),
+        vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
+        None,
+    );
+    let batches = vec![
+        RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new(
+                "struct",
+                DataType::Struct(struct_fields.clone()),
+                true,
+            )])),
+            vec![Arc::new(array)],
+        )
+        .unwrap(),
+    ];
+    let build_side_schema = Arc::new(Schema::new(vec![Field::new(
+        "struct",
+        DataType::Struct(struct_fields),
+        true,
+    )]));
+
+    let scan = TestScanBuilder::new(Arc::clone(&build_side_schema))
+        .with_support(true)
+        .with_batches(batches)
+        .build();
+    let scan_schema = scan.schema();
+    let struct_access = get_field(datafusion_expr::col("struct"), "a");
+    let filter = struct_access.clone().gt(lit(2));
+    let repartition =
+        RepartitionExec::try_new(scan, Partitioning::RoundRobinBatch(32)).unwrap();
+    let filter_exec = FilterExec::try_new(
+        logical2physical(&filter, &scan_schema),
+        Arc::new(repartition),
+    )
+    .unwrap();
+    let projection: Arc<dyn ExecutionPlan> = Arc::new(
+        ProjectionExec::try_new(
+            vec![ProjectionExpr::new(
+                logical2physical(&struct_access, &scan_schema),
+                "a",
+            )],
+            Arc::new(filter_exec),
+        )
+        .unwrap(),
+    ) as _;
+
+    let initial = displayable(projection.as_ref()).indent(true).to_string();
+    let actual = initial.trim();
+
+    assert_snapshot!(
+        actual,
+        @r"
+    ProjectionExec: expr=[get_field(struct@0, a) as a]
+      FilterExec: get_field(struct@0, a) > 2
+        RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1
+          DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[struct], file_type=test, pushdown_supported=true
+    "
+    );
+
+    let after_optimize = ProjectionPushdown::new()
+        .optimize(projection, &ConfigOptions::new())
+        .unwrap();
+
+    let after_optimize_string = displayable(after_optimize.as_ref())
+        .indent(true)
+        .to_string();
+    let actual = after_optimize_string.trim();
+
+    // The projection should be pushed all the way down to the DataSource, and the
+    // filter predicate should be rewritten to reference the projection's output column.
+    assert_snapshot!(
+        actual,
+        @r"
+    FilterExec: a@0 > 2
+      RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1
+        DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[get_field(struct@0, a) as a], file_type=test, pushdown_supported=true
+    "
+    );
+}
diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs
similarity index 92%
rename from datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs
rename to datafusion/core/tests/physical_optimizer/pushdown_utils.rs
index 1afdc4823f0a4..3708e5b696a89 100644
--- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs
+++ b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs
@@ -24,6 +24,7 @@ use datafusion_datasource::{
     file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture,
     file_stream::FileOpener, source::DataSourceExec,
 };
+use datafusion_physical_expr::projection::ProjectionExprs;
 use datafusion_physical_expr_common::physical_expr::fmt_sql;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
 use datafusion_physical_plan::filter::batch_filter;
@@ -50,7 +51,7 @@ use std::{
 pub struct TestOpener {
     batches: Vec<RecordBatch>,
     batch_size: Option<usize>,
-    projection: Option<Vec<usize>>,
+    projection: Option<ProjectionExprs>,
     predicate: Option<Arc<dyn PhysicalExpr>>,
 }
 
@@ -60,6 +61,7 @@ impl FileOpener for TestOpener {
         if self.batches.is_empty() {
            return Ok((async { Ok(TestStream::new(vec![]).boxed()) }).boxed());
         }
+        let schema = self.batches[0].schema();
         if let Some(batch_size) = self.batch_size {
             let batch = concat_batches(&batches[0].schema(), &batches)?;
             let mut new_batches = Vec::new();
@@ -83,9 +85,10 @@ impl FileOpener for TestOpener {
         batches = new_batches;
 
         if let Some(projection) = &self.projection {
+            let projector = projection.make_projector(&schema)?;
             batches = batches
                 .into_iter()
-                .map(|batch| batch.project(projection).unwrap())
+                .map(|batch| projector.project_batch(&batch).unwrap())
                 .collect();
         }
 
@@ -103,14 +106,14 @@ pub struct TestSource {
     batch_size: Option<usize>,
     batches: Vec<RecordBatch>,
     metrics: ExecutionPlanMetricsSet,
-    projection: Option<Vec<usize>>,
+    projection: Option<ProjectionExprs>,
     table_schema: datafusion_datasource::TableSchema,
 }
 
 impl TestSource {
     pub fn new(schema: SchemaRef, support: bool, batches: Vec<RecordBatch>) -> Self {
         let table_schema =
-            datafusion_datasource::TableSchema::new(Arc::clone(&schema), vec![]);
+            datafusion_datasource::TableSchema::new(schema, vec![]);
         Self {
             support,
             metrics: ExecutionPlanMetricsSet::new(),
@@ -210,6 +213,30 @@ impl FileSource for TestSource {
         }
     }
 
+    fn try_pushdown_projection(
+        &self,
+        projection: &ProjectionExprs,
+    ) -> Result<Option<Arc<dyn FileSource>>> {
+        if let Some(existing_projection) = &self.projection {
+            // Combine the existing projection with the new one
+            let combined_projection = existing_projection.try_merge(projection)?;
+            Ok(Some(Arc::new(TestSource {
+                projection: Some(combined_projection),
+                table_schema: self.table_schema.clone(),
+                ..self.clone()
+            })))
+        } else {
+            Ok(Some(Arc::new(TestSource {
+                projection: Some(projection.clone()),
+                ..self.clone()
+            })))
+        }
+    }
+
+    fn projection(&self) -> Option<&ProjectionExprs> {
+        self.projection.as_ref()
+    }
+
     fn table_schema(&self) -> &datafusion_datasource::TableSchema {
         &self.table_schema
     }
@@ -332,6 +359,7 @@ pub struct OptimizationTest {
 }
 
 impl OptimizationTest {
+    #[expect(clippy::needless_pass_by_value)]
     pub fn new<O>(
         input_plan: Arc<dyn ExecutionPlan>,
         opt: O,
diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs
index 75cd78e47aff5..bdee931972c4b 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -997,10 +997,9 @@ async fn parquet_recursive_projection_pushdown() -> Result<()> {
     SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false]
       RecursiveQueryExec: name=number_series, is_distinct=false
         CoalescePartitionsExec
-          ProjectionExec: expr=[id@0 as id, 1 as level]
-            FilterExec: id@0 = 1
-              RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1
-                DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)]
+          FilterExec: id@0 = level@1
+            RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1
+              DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id, 1 as level], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)]
         CoalescePartitionsExec
           ProjectionExec: expr=[id@0 + 1 as ns.id + Int64(1), level@1 + 1 as ns.level + Int64(1)]
             FilterExec: id@0 < 10
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index c7d825ce1d52f..d15edd5ae4527 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1883,6 +1883,28 @@ impl Expr {
         }
     }
 
+    /// Returns true if this expression is trivial (cheap to evaluate).
+    ///
+    /// Trivial expressions include column references, literals, and nested
+    /// field access via `get_field`.
+    ///
+    /// # Example
+    /// ```
+    /// # use datafusion_expr::col;
+    /// let expr = col("foo");
+    /// assert!(expr.is_trivial());
+    /// ```
+    pub fn is_trivial(&self) -> bool {
+        match self {
+            Expr::Column(_) | Expr::Literal(_, _) => true,
+            Expr::ScalarFunction(func) => {
+                func.func.is_trivial()
+                    && func.args.first().is_some_and(|arg| arg.is_trivial())
+            }
+            _ => false,
+        }
+    }
+
     /// Return all references to columns in this expression.
     ///
     /// # Example
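To make the semantics of the new `Expr::is_trivial` concrete, here is a small sketch of how it composes over nested accessors. It uses the `datafusion` prelude re-exports referenced elsewhere in this PR and is not part of the patch itself:

```rust
use datafusion::prelude::{col, get_field, lit};

fn main() {
    // Columns and literals are trivial by definition.
    assert!(col("foo").is_trivial());
    assert!(lit(1).is_trivial());

    // Chained field access stays trivial: `get_field` reports itself as
    // trivial, and the first argument (the inner accessor) is checked
    // recursively.
    let nested = get_field(get_field(col("s"), "a"), "b"); // s['a']['b']
    assert!(nested.is_trivial());

    // Anything involving real computation is not trivial.
    assert!(!(col("x") + lit(1)).is_trivial());
}
```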
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 26d7fc99cb17c..9643a821b7bbb 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -122,6 +122,11 @@ impl ScalarUDF {
         Self { inner: fun }
     }
 
+    /// Returns true if this function is trivial (cheap to evaluate).
+    pub fn is_trivial(&self) -> bool {
+        self.inner.is_trivial()
+    }
+
     /// Return the underlying [`ScalarUDFImpl`] trait object for this function
     pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
         &self.inner
@@ -846,6 +851,18 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
     fn documentation(&self) -> Option<&Documentation> {
         None
     }
+
+    /// Returns true if this function is trivial (cheap to evaluate).
+    ///
+    /// Trivial functions are lightweight accessor functions like `get_field`
+    /// (struct field access) that simply access nested data within a column
+    /// without significant computation.
+    ///
+    /// This is used to identify expressions that are cheap to duplicate or
+    /// that don't benefit from caching/partitioning optimizations.
+    fn is_trivial(&self) -> bool {
+        false
+    }
 }
 
 /// ScalarUDF that adds an alias to the underlying function. It is better to
@@ -964,6 +981,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
     fn documentation(&self) -> Option<&Documentation> {
         self.inner.documentation()
     }
+
+    fn is_trivial(&self) -> bool {
+        self.inner.is_trivial()
+    }
 }
 
 #[cfg(test)]
diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs
index 3e961e4da4e75..7c0df516ed599 100644
--- a/datafusion/functions/src/core/getfield.rs
+++ b/datafusion/functions/src/core/getfield.rs
@@ -499,6 +499,10 @@ impl ScalarUDFImpl for GetFieldFunc {
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
+
+    fn is_trivial(&self) -> bool {
+        true
+    }
}

#[cfg(test)]
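Since the trait method defaults to `false`, existing UDFs keep their current behavior; an accessor-style UDF can opt in the same way `GetFieldFunc` does above. A minimal sketch, where the `MapValueAccess` function and its pass-through `invoke_with_args` are hypothetical:

```rust
use std::any::Any;

use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};

/// Hypothetical accessor UDF that only pulls a value out of its input column.
#[derive(Debug, PartialEq, Eq, Hash)]
struct MapValueAccess {
    signature: Signature,
}

impl MapValueAccess {
    fn new() -> Self {
        Self {
            signature: Signature::any(2, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for MapValueAccess {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "map_value_access"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // Real accessor logic elided; assume it only indexes into the input.
        Ok(args.args[0].clone())
    }

    // Opt in: this function is a cheap accessor, so expressions built from it
    // (with trivial arguments) are treated as trivial by the optimizer rules
    // introduced in this PR.
    fn is_trivial(&self) -> bool {
        true
    }
}
```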
diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs
index 548eadffa242e..200f3f799287b 100644
--- a/datafusion/optimizer/src/optimize_projections/mod.rs
+++ b/datafusion/optimizer/src/optimize_projections/mod.rs
@@ -535,10 +535,8 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<LogicalPlan>> {
     if column_referral_map.into_iter().any(|(col, count)| {
         count > 1
-            && !is_expr_trivial(
-                &prev_projection.expr
-                    [prev_projection.schema.index_of_column(col).unwrap()],
-            )
+            && !prev_projection.expr[prev_projection.schema.index_of_column(col).unwrap()]
+                .is_trivial()
     }) {
         // no change
         return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
@@ -591,11 +589,6 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<LogicalPlan>> {
     Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes)
 }
 
-/// Checks if the given expression is trivial.
-/// An expression is considered trivial if it is either a `Column` or a `Literal`.
-fn is_expr_trivial(expr: &Expr) -> bool {
-    matches!(expr, Expr::Column(_) | Expr::Literal(_, _))
-}
-
 /// Rewrites a projection expression using the projection before it (i.e. its input)
 /// This is a subroutine to the `merge_consecutive_projections` function.
 ///
diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs
index 2358a21940912..5dc12693ea250 100644
--- a/datafusion/physical-expr-common/src/physical_expr.rs
+++ b/datafusion/physical-expr-common/src/physical_expr.rs
@@ -430,6 +430,20 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
     fn is_volatile_node(&self) -> bool {
         false
     }
+
+    /// Returns true if this expression is trivial (cheap to evaluate).
+    ///
+    /// Trivial expressions include:
+    /// - Column references
+    /// - Literal values
+    /// - Struct field access via `get_field`
+    /// - Nested combinations of field accessors (e.g., `col['a']['b']`)
+    ///
+    /// This is used to identify expressions that are cheap to duplicate or
+    /// that don't benefit from caching/partitioning optimizations.
+    fn is_trivial(&self) -> bool {
+        false
+    }
 }
 
 #[deprecated(
diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs
index 8c7e8c319fff4..86dcc6fa87752 100644
--- a/datafusion/physical-expr/src/expressions/column.rs
+++ b/datafusion/physical-expr/src/expressions/column.rs
@@ -146,6 +146,10 @@ impl PhysicalExpr for Column {
     fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{}", self.name)
     }
+
+    fn is_trivial(&self) -> bool {
+        true
+    }
 }
 
 impl Column {
diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs
index 1f3fefc60b7ad..6aaa3b0c77575 100644
--- a/datafusion/physical-expr/src/expressions/literal.rs
+++ b/datafusion/physical-expr/src/expressions/literal.rs
@@ -134,6 +134,10 @@ impl PhysicalExpr for Literal {
     fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         std::fmt::Display::fmt(self, f)
     }
+
+    fn is_trivial(&self) -> bool {
+        true
+    }
 }
 
 /// Create a literal expression
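At the physical layer the same flag is computed per expression, and `ProjectionExprs::is_trivial` (added below) is just the conjunction over the projection's members. A short sketch with hand-built expressions (column names are illustrative):

```rust
use std::sync::Arc;

use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

fn main() {
    let col_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("s", 0));
    let lit_expr: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

    // Column and Literal override the new default-false trait method.
    assert!(col_expr.is_trivial());
    assert!(lit_expr.is_trivial());

    // ProjectionExprs::is_trivial reduces to exactly this check over its
    // member expressions, independent of whether the projection narrows
    // the schema:
    let exprs = vec![col_expr, lit_expr];
    assert!(exprs.iter().all(|e| e.is_trivial()));
}
```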
diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs
index 8d4afb5d19701..49e1a9468f11f 100644
--- a/datafusion/physical-expr/src/projection.rs
+++ b/datafusion/physical-expr/src/projection.rs
@@ -37,6 +37,7 @@ use datafusion_physical_expr_common::metrics::ExpressionEvaluatorMetrics;
 use datafusion_physical_expr_common::physical_expr::fmt_sql;
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
 use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays_with_metrics;
+use hashbrown::HashSet;
 use indexmap::IndexMap;
 use itertools::Itertools;
 
@@ -233,6 +234,11 @@ impl ProjectionExprs {
         self.exprs.iter()
     }
 
+    /// Checks if all of the projection expressions are trivial.
+    pub fn is_trivial(&self) -> bool {
+        self.exprs.iter().all(|p| p.expr.is_trivial())
+    }
+
     /// Creates a ProjectionMapping from this projection
     pub fn projection_mapping(
         &self,
@@ -747,7 +753,7 @@ pub fn update_expr(
     projected_exprs: &[ProjectionExpr],
     sync_with_child: bool,
 ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
-    #[derive(Debug, PartialEq)]
+    #[derive(PartialEq)]
     enum RewriteState {
         /// The expression is unchanged.
         Unchanged,
@@ -758,10 +764,46 @@ pub fn update_expr(
         RewrittenInvalid,
     }
 
+    // Track columns introduced by pass 1 (by name and index).
+    // These should not be modified by pass 2.
+    let mut valid_columns = HashSet::new();
+
+    // First pass: try to rewrite the expression in terms of the projected expressions.
+    // For example, if the expression is `a + b > 5` and the projection is `a + b AS sum_ab`,
+    // we can rewrite the expression to `sum_ab > 5` directly.
+    //
+    // This optimization only applies when sync_with_child=false, meaning we want the
+    // expression to use OUTPUT references (e.g., when pushing a projection down and the
+    // expression will sit above the projection). Pass 1 creates OUTPUT column references.
+    //
+    // When sync_with_child=true, we want INPUT references (expanding OUTPUT to INPUT),
+    // so pass 1 does not apply.
+    let new_expr = if !sync_with_child {
+        Arc::clone(expr)
+            .transform_down(&mut |expr: Arc<dyn PhysicalExpr>| {
+                // If expr is equal to one of the projected expressions, we can short-circuit the rewrite:
+                for (idx, projected_expr) in projected_exprs.iter().enumerate() {
+                    if expr.eq(&projected_expr.expr) {
+                        // Track this column so pass 2 doesn't modify it
+                        valid_columns.insert((projected_expr.alias.clone(), idx));
+                        return Ok(Transformed::yes(Arc::new(Column::new(
+                            &projected_expr.alias,
+                            idx,
+                        )) as _));
+                    }
+                }
+                Ok(Transformed::no(expr))
+            })?
+            .data
+    } else {
+        Arc::clone(expr)
+    };
+
+    // Second pass: rewrite remaining column references based on the projection.
+    // Skip columns that were introduced by pass 1.
     let mut state = RewriteState::Unchanged;
-
-    let new_expr = Arc::clone(expr)
-        .transform_up(|expr| {
+    let new_expr = new_expr
+        .transform_up(&mut |expr: Arc<dyn PhysicalExpr>| {
             if state == RewriteState::RewrittenInvalid {
                 return Ok(Transformed::no(expr));
             }
@@ -769,6 +811,15 @@ pub fn update_expr(
             let Some(column) = expr.as_any().downcast_ref::<Column>() else {
                 return Ok(Transformed::no(expr));
             };
+
+            // Skip columns introduced by pass 1 - they're already valid OUTPUT references.
+            // Mark the state as valid since pass 1 successfully handled this column.
+            if valid_columns.contains(&(column.name().to_string(), column.index()))
+            {
+                state = RewriteState::RewrittenValid;
+                return Ok(Transformed::no(expr));
+            }
+
             if sync_with_child {
                 state = RewriteState::RewrittenValid;
                 // Update the index of `column`:
@@ -2377,6 +2428,240 @@ pub(crate) mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_update_expr_matches_projected_expr() -> Result<()> {
+        // Test that when the filter expression exactly matches a projected expression,
+        // update_expr short-circuits and rewrites it to use the projected column.
+        // e.g., projection: a * 2 AS a_times_2, filter: a * 2 > 4
+        // should become: a_times_2 > 4
+
+        // Create the computed expression: a@0 * 2
+        let computed_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Multiply,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
+        ));
+
+        // Create a projection with the computed expression aliased as "a_times_2"
+        let projection = vec![ProjectionExpr {
+            expr: Arc::clone(&computed_expr),
+            alias: "a_times_2".to_string(),
+        }];
+
+        // Create the filter predicate: a * 2 > 4 (same expression as the projection)
+        let filter_predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::clone(&computed_expr),
+            Operator::Gt,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(4)))),
+        ));
+
+        // Update the expression - should rewrite a * 2 to a_times_2@0.
+        // sync_with_child=false because we want OUTPUT references
+        // (the filter will sit above the projection).
+        let result = update_expr(&filter_predicate, &projection, false)?;
+        assert!(result.is_some(), "Filter predicate should be valid");
+
+        let result_expr = result.unwrap();
+        let binary = result_expr
+            .as_any()
+            .downcast_ref::<BinaryExpr>()
+            .expect("Should be a BinaryExpr");
+
+        // The left side should now be a column reference to a_times_2@0
+        let left_col = binary
+            .left()
+            .as_any()
+            .downcast_ref::<Column>()
+            .expect("Left should be rewritten to a Column");
+        assert_eq!(left_col.name(), "a_times_2");
+        assert_eq!(left_col.index(), 0);
+
+        // The right side should still be the literal 4
+        let right_lit = binary
+            .right()
+            .as_any()
+            .downcast_ref::<Literal>()
+            .expect("Right should be a Literal");
+        assert_eq!(right_lit.value(), &ScalarValue::Int32(Some(4)));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_update_expr_partial_match() -> Result<()> {
+        // Test that when only part of an expression matches, we still handle
+        // the rest correctly. e.g., `a + b > 2 AND c > 3` with projection
+        // `a + b AS sum_ab, c AS c_out` should become `sum_ab > 2 AND c_out > 3`
+
+        // Create the computed expression: a@0 + b@1
+        let sum_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("b", 1)),
+        ));
+
+        // Projection: [a + b AS sum_ab, c AS c_out]
+        let projection = vec![
+            ProjectionExpr {
+                expr: Arc::clone(&sum_expr),
+                alias: "sum_ab".to_string(),
+            },
+            ProjectionExpr {
+                expr: Arc::new(Column::new("c", 2)),
+                alias: "c_out".to_string(),
+            },
+        ];
+
+        // Filter: (a + b > 2) AND (c > 3)
+        let filter_predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::new(BinaryExpr::new(
+                Arc::clone(&sum_expr),
+                Operator::Gt,
+                Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
+            )),
+            Operator::And,
+            Arc::new(BinaryExpr::new(
+                Arc::new(Column::new("c", 2)),
+                Operator::Gt,
+                Arc::new(Literal::new(ScalarValue::Int32(Some(3)))),
+            )),
+        ));
+
+        // With sync_with_child=false: columns reference the input schema and
+        // need to be mapped to the projection's output.
+        let result = update_expr(&filter_predicate, &projection, false)?;
+        assert!(result.is_some(), "Filter predicate should be valid");
+
+        let result_expr = result.unwrap();
+        // Should be: sum_ab@0 > 2 AND c_out@1 > 3
+        assert_eq!(result_expr.to_string(), "sum_ab@0 > 2 AND c_out@1 > 3");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_update_expr_partial_match_with_unresolved_column() -> Result<()> {
+        // Test that when part of an expression matches but other columns can't be
+        // resolved, we return None. e.g., `a + b > 2 AND c > 3` with projection
+        // `a + b AS sum_ab` (note: no 'c' column!) should return None.
+
+        // Create the computed expression: a@0 + b@1
+        let sum_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("b", 1)),
+        ));
+
+        // Projection: [a + b AS sum_ab] - note: NO 'c' column!
+        let projection = vec![ProjectionExpr {
+            expr: Arc::clone(&sum_expr),
+            alias: "sum_ab".to_string(),
+        }];
+
+        // Filter: (a + b > 2) AND (c > 3) - 'c' is not in the projection!
+        let filter_predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::new(BinaryExpr::new(
+                Arc::clone(&sum_expr),
+                Operator::Gt,
+                Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
+            )),
+            Operator::And,
+            Arc::new(BinaryExpr::new(
+                Arc::new(Column::new("c", 2)),
+                Operator::Gt,
+                Arc::new(Literal::new(ScalarValue::Int32(Some(3)))),
+            )),
+        ));
+
+        // With sync_with_child=false: should return None because 'c' can't be mapped
+        let result = update_expr(&filter_predicate, &projection, false)?;
+        assert!(
+            result.is_none(),
+            "Should return None when some columns can't be resolved"
+        );
+
+        // On the other hand, if the projection is `c AS c_out, a + b AS sum_ab`,
+        // the rewrite should succeed.
+        let projection = vec![
+            ProjectionExpr {
+                expr: Arc::new(Column::new("c", 2)),
+                alias: "c_out".to_string(),
+            },
+            ProjectionExpr {
+                expr: Arc::clone(&sum_expr),
+                alias: "sum_ab".to_string(),
+            },
+        ];
+        let result = update_expr(&filter_predicate, &projection, false)?;
+        assert!(result.is_some(), "Filter predicate should be valid now");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_update_expr_nested_match() -> Result<()> {
+        // Test matching a sub-expression within a larger expression.
+        // e.g., `(a + b) * 2 > 10` with projection `a + b AS sum_ab`
+        // should become `sum_ab * 2 > 10`
+
+        // Create the computed expression: a@0 + b@1
+        let sum_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("b", 1)),
+        ));
+
+        // Projection: [a + b AS sum_ab]
+        let projection = vec![ProjectionExpr {
+            expr: Arc::clone(&sum_expr),
+            alias: "sum_ab".to_string(),
+        }];
+
+        // Filter: (a + b) * 2 > 10
+        let filter_predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::new(BinaryExpr::new(
+                Arc::clone(&sum_expr),
+                Operator::Multiply,
+                Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
+            )),
+            Operator::Gt,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
+        ));
+
+        // With sync_with_child=false: should rewrite a + b to sum_ab
+        let result = update_expr(&filter_predicate, &projection, false)?;
+        assert!(result.is_some(), "Filter predicate should be valid");
+
+        let result_expr = result.unwrap();
+        // Should be: sum_ab@0 * 2 > 10
+        assert_eq!(result_expr.to_string(), "sum_ab@0 * 2 > 10");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_update_expr_no_match_returns_none() -> Result<()> {
+        // Test that when columns can't be resolved, we return None
+        // (with sync_with_child=false)
+
+        // Projection: [a AS a_out]
+        let projection = vec![ProjectionExpr {
+            expr: Arc::new(Column::new("a", 0)),
+            alias: "a_out".to_string(),
+        }];
+
+        // The filter references column 'd', which is not in the projection
+        let filter_predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("d", 3)), // Not in projection
+            Operator::Gt,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
+        ));
+
+        // With sync_with_child=false: should return None because 'd' can't be mapped
+        let result = update_expr(&filter_predicate, &projection, false)?;
+        assert!(
+            result.is_none(),
+            "Should return None when the column can't be resolved"
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn test_project_schema_simple_columns() -> Result<()> {
         // Input schema: [col0: Int64, col1: Utf8, col2: Float32]
diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs
index e6a6db75bebd7..8959bf57aa246 100644
--- a/datafusion/physical-expr/src/scalar_function.rs
+++ b/datafusion/physical-expr/src/scalar_function.rs
@@ -366,6 +366,13 @@ impl PhysicalExpr for ScalarFunctionExpr {
     fn is_volatile_node(&self) -> bool {
         self.fun.signature().volatility == Volatility::Volatile
     }
+
+    fn is_trivial(&self) -> bool {
+        if !self.fun.is_trivial() {
+            return false;
+        }
+        self.args.iter().all(|arg| arg.is_trivial())
+    }
 }
 
 #[cfg(test)]
diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs
index 0dc6a25fbc0b7..d1fbce1d21d7f 100644
--- a/datafusion/physical-optimizer/src/output_requirements.rs
+++ b/datafusion/physical-optimizer/src/output_requirements.rs
@@ -256,9 +256,8 @@ impl ExecutionPlan for OutputRequirementExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down:
-        let proj_exprs = projection.expr();
-        if proj_exprs.len() >= projection.input().schema().fields().len() {
+        // If the projection is not trivial, we should not try to push it down:
+        if !projection.projection_expr().is_trivial() {
             return Ok(None);
         }
 
@@ -267,7 +266,8 @@ impl ExecutionPlan for OutputRequirementExec {
         let mut updated_reqs = vec![];
         let (lexes, soft) = reqs.into_alternatives();
         for lex in lexes.into_iter() {
-            let Some(updated_lex) = update_ordering_requirement(lex, proj_exprs)?
+            let Some(updated_lex) =
+                update_ordering_requirement(lex, projection.expr())?
             else {
                 return Ok(None);
             };
diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs
index d83f90eb3d8c1..8e16ae2d17eee 100644
--- a/datafusion/physical-plan/src/coalesce_partitions.rs
+++ b/datafusion/physical-plan/src/coalesce_partitions.rs
@@ -249,8 +249,8 @@ impl ExecutionPlan for CoalescePartitionsExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down:
-        if projection.expr().len() >= projection.input().schema().fields().len() {
+        // If the projection is not trivial, we should not try to push it down:
+        if !projection.projection_expr().is_trivial() {
             return Ok(None);
         }
         // CoalescePartitionsExec always has a single child, so zero indexing is safe.
diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs
index 674fe6692adf5..0e177d1e6cce7 100644
--- a/datafusion/physical-plan/src/filter.rs
+++ b/datafusion/physical-plan/src/filter.rs
@@ -467,8 +467,8 @@ impl ExecutionPlan for FilterExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down:
-        if projection.expr().len() < projection.input().schema().fields().len() {
+        // Only try to push the projection down if it is trivial:
+        if projection.projection_expr().is_trivial() {
             // Each column in the predicate expression must exist after the projection.
             if let Some(new_predicate) =
                 update_expr(self.predicate(), projection.expr(), false)?
diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs
index e8608f17a1b20..2fce69cfd0254 100644
--- a/datafusion/physical-plan/src/projection.rs
+++ b/datafusion/physical-plan/src/projection.rs
@@ -20,7 +20,7 @@
 //! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the
 //! projection expressions. `SELECT` without `FROM` will only evaluate expressions.
 
-use super::expressions::{Column, Literal};
+use super::expressions::Column;
 use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use super::{
     DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
@@ -255,18 +255,10 @@ impl ExecutionPlan for ProjectionExec {
     }
 
     fn benefits_from_input_partitioning(&self) -> Vec<bool> {
-        let all_simple_exprs =
-            self.projector
-                .projection()
-                .as_ref()
-                .iter()
-                .all(|proj_expr| {
-                    proj_expr.expr.as_any().is::<Column>()
-                        || proj_expr.expr.as_any().is::<Literal>()
-                });
-        // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename,
-        // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false.
-        vec![!all_simple_exprs]
+        // If the expressions are all trivial (columns, literals, or field accessors),
+        // then this projection only reorders, renames, or cheaply accesses fields,
+        // and would not benefit from repartitioning.
+        vec![!self.projection_expr().is_trivial()]
     }
 
     fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
@@ -700,13 +692,6 @@ pub fn make_with_child(
         .map(|e| Arc::new(e) as _)
 }
 
-/// Returns `true` if all the expressions in the argument are `Column`s.
-pub fn all_columns(exprs: &[ProjectionExpr]) -> bool {
-    exprs
-        .iter()
-        .all(|proj_expr| proj_expr.expr.as_any().is::<Column>())
-}
-
 /// Updates the given lexicographic ordering according to given projected
 /// expressions using the [`update_expr`] function.
 pub fn update_ordering(
@@ -949,7 +934,7 @@ fn try_unifying_projections(
     // beneficial as caching mechanism for non-trivial computations.
     // See discussion in: https://github.com/apache/datafusion/issues/8296
     if column_ref_map.iter().any(|(column, count)| {
-        *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].expr))
+        *count > 1 && !child.expr()[column.index()].expr.is_trivial()
     }) {
         return Ok(None);
     }
@@ -1059,13 +1044,6 @@ fn new_columns_for_join_on(
     (new_columns.len() == hash_join_on.len()).then_some(new_columns)
 }
 
-/// Checks if the given expression is trivial.
-/// An expression is considered trivial if it is either a `Column` or a `Literal`.
-fn is_expr_trivial(expr: &Arc<dyn PhysicalExpr>) -> bool {
-    expr.as_any().downcast_ref::<Column>().is_some()
-        || expr.as_any().downcast_ref::<Literal>().is_some()
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index 1efdaaabc7d6a..cb609f656c6ae 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -34,7 +34,7 @@ use crate::coalesce::LimitedBatchCoalescer;
 use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
 use crate::hash_utils::create_hashes;
 use crate::metrics::{BaselineMetrics, SpillMetrics};
-use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr};
+use crate::projection::{ProjectionExec, make_with_child, update_expr};
 use crate::sorts::streaming_merge::StreamingMergeBuilder;
 use crate::spill::spill_manager::SpillManager;
 use crate::spill::spill_pool::{self, SpillPoolWriter};
@@ -1047,14 +1047,9 @@ impl ExecutionPlan for RepartitionExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down.
-        if projection.expr().len() >= projection.input().schema().fields().len() {
-            return Ok(None);
-        }
-
         // If pushdown is not beneficial or applicable, break it.
         if projection.benefits_from_input_partitioning()[0]
-            || !all_columns(projection.expr())
+            || !projection.projection_expr().is_trivial()
         {
             return Ok(None);
         }
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index 3e8fdf1f3ed7e..088f66dfc76c2 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -1391,8 +1391,8 @@ impl ExecutionPlan for SortExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down.
-        if projection.expr().len() >= projection.input().schema().fields().len() {
+        // If the projection is not trivial, we should not try to push it down:
+        if !projection.projection_expr().is_trivial() {
             return Ok(None);
         }
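All of the operator-level changes above swap the same guard: a schema-width heuristic becomes a per-expression cost check. Schematically, with an illustrative single-field struct input:

```rust
// Input schema:  [s: Struct<a: Int32>]      -- one field
// Projection:    [get_field(s, "a") AS a]   -- one expression
//
// Old gate: projection.expr().len() >= projection.input().schema().fields().len()
//   1 >= 1, so pushdown was rejected even though the projection is a cheap
//   accessor the scan could absorb.
//
// New gate: !projection.projection_expr().is_trivial()
//   get_field over a column is trivial, so pushdown proceeds -- the behavior
//   exercised by test_pushdown_projection_through_repartition_filter above.
```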
diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
index 4b26f84099505..93828897c0723 100644
--- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
+++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -378,8 +378,8 @@ impl ExecutionPlan for SortPreservingMergeExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down.
-        if projection.expr().len() >= projection.input().schema().fields().len() {
+        // If the projection is not trivial, we should not try to push it down:
+        if !projection.projection_expr().is_trivial() {
             return Ok(None);
         }
diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt
index 352056adbf813..8cb10909096de 100644
--- a/datafusion/sqllogictest/test_files/unnest.slt
+++ b/datafusion/sqllogictest/test_files/unnest.slt
@@ -659,8 +659,8 @@ logical_plan
 physical_plan
 01)ProjectionExec: expr=[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3]
 02)--UnnestExec
-03)----ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3]
-04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+04)------ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3]
 05)--------UnnestExec
 06)----------ProjectionExec: expr=[column3@0 as __unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3]
 07)------------DataSourceExec: partitions=1, partition_sizes=[1]