@@ -31,7 +31,7 @@ use datafusion_common::{
3131 assert_eq_or_internal_err, assert_or_internal_err, internal_err, plan_err,
3232 qualified_name, Column , DFSchema , DataFusionError , Result ,
3333} ;
34- use datafusion_expr:: expr:: WindowFunction ;
34+ use datafusion_expr:: expr:: { Between , InList , ScalarFunction , WindowFunction } ;
3535use datafusion_expr:: expr_rewriter:: replace_col;
3636use datafusion_expr:: logical_plan:: { Join , JoinType , LogicalPlan , TableScan , Union } ;
3737use datafusion_expr:: utils:: {
@@ -418,6 +418,204 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
418418 predicate
419419}
420420
421+ /// Tracks coalesce predicates that can be pushed to each side of a FULL JOIN.
422+ struct PushDownCoalesceFilterHelper {
423+ join_keys : Vec < ( Column , Column ) > ,
424+ left_filters : Vec < Expr > ,
425+ right_filters : Vec < Expr > ,
426+ remaining_filters : Vec < Expr > ,
427+ }
428+
429+ impl PushDownCoalesceFilterHelper {
430+ fn new ( join_keys : & [ ( Expr , Expr ) ] ) -> Self {
431+ let join_keys = join_keys
432+ . iter ( )
433+ . filter_map ( |( lhs, rhs) | {
434+ Some ( ( lhs. try_as_col ( ) ?. clone ( ) , rhs. try_as_col ( ) ?. clone ( ) ) )
435+ } )
436+ . collect ( ) ;
437+ Self {
438+ join_keys,
439+ left_filters : Vec :: new ( ) ,
440+ right_filters : Vec :: new ( ) ,
441+ remaining_filters : Vec :: new ( ) ,
442+ }
443+ }
444+
445+ fn push_columns < F : FnMut ( Expr ) -> Expr > (
446+ & mut self ,
447+ columns : ( Column , Column ) ,
448+ mut build_filter : F ,
449+ ) {
450+ self . left_filters
451+ . push ( build_filter ( Expr :: Column ( columns. 0 ) ) ) ;
452+ self . right_filters
453+ . push ( build_filter ( Expr :: Column ( columns. 1 ) ) ) ;
454+ }
455+
456+ fn extract_join_columns ( & self , expr : & Expr ) -> Option < ( Column , Column ) > {
457+ if let Expr :: ScalarFunction ( ScalarFunction { func, args } ) = expr {
458+ if func. name ( ) != "coalesce" {
459+ return None ;
460+ }
461+ if let [ Expr :: Column ( lhs) , Expr :: Column ( rhs) ] = args. as_slice ( ) {
462+ for ( join_lhs, join_rhs) in & self . join_keys {
463+ if join_lhs == lhs && join_rhs == rhs {
464+ return Some ( ( lhs. clone ( ) , rhs. clone ( ) ) ) ;
465+ }
466+ if join_lhs == rhs && join_rhs == lhs {
467+ return Some ( ( rhs. clone ( ) , lhs. clone ( ) ) ) ;
468+ }
469+ }
470+ }
471+ }
472+ None
473+ }
474+
475+ fn push_term ( & mut self , term : & Expr ) {
476+ match term {
477+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } )
478+ if op. supports_propagation ( ) =>
479+ {
480+ if let Some ( columns) = self . extract_join_columns ( left) {
481+ return self . push_columns ( columns, |replacement| {
482+ Expr :: BinaryExpr ( BinaryExpr {
483+ left : Box :: new ( replacement) ,
484+ op : * op,
485+ right : right. clone ( ) ,
486+ } )
487+ } ) ;
488+ }
489+ if let Some ( columns) = self . extract_join_columns ( right) {
490+ return self . push_columns ( columns, |replacement| {
491+ Expr :: BinaryExpr ( BinaryExpr {
492+ left : left. clone ( ) ,
493+ op : * op,
494+ right : Box :: new ( replacement) ,
495+ } )
496+ } ) ;
497+ }
498+ }
499+ Expr :: IsNull ( expr) => {
500+ if let Some ( columns) = self . extract_join_columns ( expr) {
501+ return self . push_columns ( columns, |replacement| {
502+ Expr :: IsNull ( Box :: new ( replacement) )
503+ } ) ;
504+ }
505+ }
506+ Expr :: IsNotNull ( expr) => {
507+ if let Some ( columns) = self . extract_join_columns ( expr) {
508+ return self . push_columns ( columns, |replacement| {
509+ Expr :: IsNotNull ( Box :: new ( replacement) )
510+ } ) ;
511+ }
512+ }
513+ Expr :: IsTrue ( expr) => {
514+ if let Some ( columns) = self . extract_join_columns ( expr) {
515+ return self . push_columns ( columns, |replacement| {
516+ Expr :: IsTrue ( Box :: new ( replacement) )
517+ } ) ;
518+ }
519+ }
520+ Expr :: IsFalse ( expr) => {
521+ if let Some ( columns) = self . extract_join_columns ( expr) {
522+ return self . push_columns ( columns, |replacement| {
523+ Expr :: IsFalse ( Box :: new ( replacement) )
524+ } ) ;
525+ }
526+ }
527+ Expr :: IsUnknown ( expr) => {
528+ if let Some ( columns) = self . extract_join_columns ( expr) {
529+ return self . push_columns ( columns, |replacement| {
530+ Expr :: IsUnknown ( Box :: new ( replacement) )
531+ } ) ;
532+ }
533+ }
534+ Expr :: IsNotTrue ( expr) => {
535+ if let Some ( columns) = self . extract_join_columns ( expr) {
536+ return self . push_columns ( columns, |replacement| {
537+ Expr :: IsNotTrue ( Box :: new ( replacement) )
538+ } ) ;
539+ }
540+ }
541+ Expr :: IsNotFalse ( expr) => {
542+ if let Some ( columns) = self . extract_join_columns ( expr) {
543+ return self . push_columns ( columns, |replacement| {
544+ Expr :: IsNotFalse ( Box :: new ( replacement) )
545+ } ) ;
546+ }
547+ }
548+ Expr :: IsNotUnknown ( expr) => {
549+ if let Some ( columns) = self . extract_join_columns ( expr) {
550+ return self . push_columns ( columns, |replacement| {
551+ Expr :: IsNotUnknown ( Box :: new ( replacement) )
552+ } ) ;
553+ }
554+ }
555+ Expr :: Between ( between) => {
556+ if let Some ( columns) = self . extract_join_columns ( & between. expr ) {
557+ return self . push_columns ( columns, |replacement| {
558+ Expr :: Between ( Between {
559+ expr : Box :: new ( replacement) ,
560+ negated : between. negated ,
561+ low : between. low . clone ( ) ,
562+ high : between. high . clone ( ) ,
563+ } )
564+ } ) ;
565+ }
566+ }
567+ Expr :: InList ( in_list) => {
568+ if let Some ( columns) = self . extract_join_columns ( & in_list. expr ) {
569+ return self . push_columns ( columns, |replacement| {
570+ Expr :: InList ( InList {
571+ expr : Box :: new ( replacement) ,
572+ list : in_list. list . clone ( ) ,
573+ negated : in_list. negated ,
574+ } )
575+ } ) ;
576+ }
577+ }
578+ _ => { }
579+ }
580+ self . remaining_filters . push ( term. clone ( ) ) ;
581+ }
582+
583+ fn push_predicate (
584+ mut self ,
585+ predicate : Expr ,
586+ ) -> Result < ( Option < Expr > , Option < Expr > , Vec < Expr > ) > {
587+ let predicates = split_conjunction_owned ( predicate) ;
588+ let terms = simplify_predicates ( predicates) ?;
589+ for term in terms {
590+ self . push_term ( & term) ;
591+ }
592+ Ok ( (
593+ conjunction ( self . left_filters ) ,
594+ conjunction ( self . right_filters ) ,
595+ self . remaining_filters ,
596+ ) )
597+ }
598+ }
599+
600+ fn push_full_join_coalesce_filters (
601+ join : & mut Join ,
602+ predicate : Expr ,
603+ ) -> Result < Option < Vec < Expr > > > {
604+ let ( Some ( left) , Some ( right) , remaining) =
605+ PushDownCoalesceFilterHelper :: new ( & join. on ) . push_predicate ( predicate) ?
606+ else {
607+ return Ok ( None ) ;
608+ } ;
609+
610+ let left_input = Arc :: clone ( & join. left ) ;
611+ join. left = Arc :: new ( make_filter ( left, left_input) ?) ;
612+
613+ let right_input = Arc :: clone ( & join. right ) ;
614+ join. right = Arc :: new ( make_filter ( right, right_input) ?) ;
615+
616+ Ok ( Some ( remaining) )
617+ }
618+
421619/// push down join/cross-join
422620fn push_down_all_join (
423621 predicates : Vec < Expr > ,
@@ -527,13 +725,21 @@ fn push_down_all_join(
527725}
528726
529727fn push_down_join (
530- join : Join ,
728+ mut join : Join ,
531729 parent_predicate : Option < & Expr > ,
532730) -> Result < Transformed < LogicalPlan > > {
533731 // Split the parent predicate into individual conjunctive parts.
534- let predicates = parent_predicate
732+ let mut predicates = parent_predicate
535733 . map_or_else ( Vec :: new, |pred| split_conjunction_owned ( pred. clone ( ) ) ) ;
536734
735+ if let Some ( parent_predicate) = parent_predicate {
736+ if let Some ( remaining_predicates) =
737+ push_full_join_coalesce_filters ( & mut join, parent_predicate. clone ( ) ) ?
738+ {
739+ predicates = remaining_predicates;
740+ }
741+ }
742+
537743 // Extract conjunctions from the JOIN's ON filter, if present.
538744 let on_filters = join
539745 . filter
@@ -1447,6 +1653,7 @@ mod tests {
14471653 use crate :: test:: * ;
14481654 use crate :: OptimizerContext ;
14491655 use datafusion_expr:: test:: function_stub:: sum;
1656+ use datafusion_functions:: core:: expr_fn:: coalesce;
14501657 use insta:: assert_snapshot;
14511658
14521659 use super :: * ;
@@ -2848,6 +3055,36 @@ mod tests {
28483055 )
28493056 }
28503057
3058+ /// Filter on coalesce of join keys should be pushed to both join inputs
3059+ #[ test]
3060+ fn filter_full_join_on_coalesce ( ) -> Result < ( ) > {
3061+ let table_scan_t1 = test_table_scan_with_name ( "t1" ) ?;
3062+ let table_scan_t2 = test_table_scan_with_name ( "t2" ) ?;
3063+
3064+ let plan = LogicalPlanBuilder :: from ( table_scan_t1)
3065+ . join ( table_scan_t2, JoinType :: Full , ( vec ! [ "a" ] , vec ! [ "a" ] ) , None ) ?
3066+ . filter ( coalesce ( vec ! [ col( "t1.a" ) , col( "t2.a" ) ] ) . eq ( lit ( 1i32 ) ) ) ?
3067+ . build ( ) ?;
3068+
3069+ // not part of the test, just good to know:
3070+ assert_snapshot ! ( plan,
3071+ @r"
3072+ Filter: coalesce(t1.a, t2.a) = Int32(1)
3073+ Full Join: t1.a = t2.a
3074+ TableScan: t1
3075+ TableScan: t2
3076+ " ,
3077+ ) ;
3078+ assert_optimized_plan_equal ! (
3079+ plan,
3080+ @r"
3081+ Full Join: t1.a = t2.a
3082+ TableScan: t1, full_filters=[t1.a = Int32(1)]
3083+ TableScan: t2, full_filters=[t2.a = Int32(1)]
3084+ "
3085+ )
3086+ }
3087+
28513088 /// join filter should be completely removed after pushdown
28523089 #[ test]
28533090 fn join_filter_removed ( ) -> Result < ( ) > {
0 commit comments