@@ -31,7 +31,7 @@ use datafusion_common::{
3131 assert_eq_or_internal_err, assert_or_internal_err, internal_err, plan_err,
3232 qualified_name, Column , DFSchema , DataFusionError , Result ,
3333} ;
34- use datafusion_expr:: expr:: WindowFunction ;
34+ use datafusion_expr:: expr:: { Between , InList , ScalarFunction , WindowFunction } ;
3535use datafusion_expr:: expr_rewriter:: replace_col;
3636use datafusion_expr:: logical_plan:: { Join , JoinType , LogicalPlan , TableScan , Union } ;
3737use datafusion_expr:: utils:: {
@@ -418,6 +418,202 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
418418 predicate
419419}
420420
421+ /// Tracks coalesce predicates that can be pushed to each side of a FULL JOIN.
422+ struct PushDownCoalesceFilterHelper {
423+ join_keys : Vec < ( Column , Column ) > ,
424+ left_filters : Vec < Expr > ,
425+ right_filters : Vec < Expr > ,
426+ remaining_filters : Vec < Expr > ,
427+ }
428+
429+ impl PushDownCoalesceFilterHelper {
430+ fn new ( join_keys : & [ ( Expr , Expr ) ] ) -> Self {
431+ let join_keys = join_keys
432+ . iter ( )
433+ . filter_map ( |( lhs, rhs) | {
434+ Some ( ( lhs. try_as_col ( ) ?. clone ( ) , rhs. try_as_col ( ) ?. clone ( ) ) )
435+ } )
436+ . collect ( ) ;
437+ Self {
438+ join_keys,
439+ left_filters : Vec :: new ( ) ,
440+ right_filters : Vec :: new ( ) ,
441+ remaining_filters : Vec :: new ( ) ,
442+ }
443+ }
444+
445+ fn push_columns < F : FnMut ( Expr ) -> Expr > (
446+ & mut self ,
447+ columns : ( Column , Column ) ,
448+ mut build_filter : F ,
449+ ) {
450+ self . left_filters
451+ . push ( build_filter ( Expr :: Column ( columns. 0 ) ) ) ;
452+ self . right_filters
453+ . push ( build_filter ( Expr :: Column ( columns. 1 ) ) ) ;
454+ }
455+
456+ fn extract_join_columns ( & self , expr : & Expr ) -> Option < ( Column , Column ) > {
457+ if let Expr :: ScalarFunction ( ScalarFunction { func, args } ) = expr {
458+ if func. name ( ) != "coalesce" {
459+ return None ;
460+ }
461+ if let [ Expr :: Column ( lhs) , Expr :: Column ( rhs) ] = args. as_slice ( ) {
462+ for ( join_lhs, join_rhs) in & self . join_keys {
463+ if join_lhs == lhs && join_rhs == rhs {
464+ return Some ( ( lhs. clone ( ) , rhs. clone ( ) ) ) ;
465+ }
466+ if join_lhs == rhs && join_rhs == lhs {
467+ return Some ( ( rhs. clone ( ) , lhs. clone ( ) ) ) ;
468+ }
469+ }
470+ }
471+ }
472+ None
473+ }
474+
475+ fn push_term ( & mut self , term : & Expr ) {
476+ match term {
477+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) => {
478+ if let Some ( columns) = self . extract_join_columns ( left) {
479+ return self . push_columns ( columns, |replacement| {
480+ Expr :: BinaryExpr ( BinaryExpr {
481+ left : Box :: new ( replacement) ,
482+ op : * op,
483+ right : right. clone ( ) ,
484+ } )
485+ } ) ;
486+ }
487+ if let Some ( columns) = self . extract_join_columns ( right) {
488+ return self . push_columns ( columns, |replacement| {
489+ Expr :: BinaryExpr ( BinaryExpr {
490+ left : left. clone ( ) ,
491+ op : * op,
492+ right : Box :: new ( replacement) ,
493+ } )
494+ } ) ;
495+ }
496+ }
497+ Expr :: IsNull ( expr) => {
498+ if let Some ( columns) = self . extract_join_columns ( expr) {
499+ return self . push_columns ( columns, |replacement| {
500+ Expr :: IsNull ( Box :: new ( replacement) )
501+ } ) ;
502+ }
503+ }
504+ Expr :: IsNotNull ( expr) => {
505+ if let Some ( columns) = self . extract_join_columns ( expr) {
506+ return self . push_columns ( columns, |replacement| {
507+ Expr :: IsNotNull ( Box :: new ( replacement) )
508+ } ) ;
509+ }
510+ }
511+ Expr :: IsTrue ( expr) => {
512+ if let Some ( columns) = self . extract_join_columns ( expr) {
513+ return self . push_columns ( columns, |replacement| {
514+ Expr :: IsTrue ( Box :: new ( replacement) )
515+ } ) ;
516+ }
517+ }
518+ Expr :: IsFalse ( expr) => {
519+ if let Some ( columns) = self . extract_join_columns ( expr) {
520+ return self . push_columns ( columns, |replacement| {
521+ Expr :: IsFalse ( Box :: new ( replacement) )
522+ } ) ;
523+ }
524+ }
525+ Expr :: IsUnknown ( expr) => {
526+ if let Some ( columns) = self . extract_join_columns ( expr) {
527+ return self . push_columns ( columns, |replacement| {
528+ Expr :: IsUnknown ( Box :: new ( replacement) )
529+ } ) ;
530+ }
531+ }
532+ Expr :: IsNotTrue ( expr) => {
533+ if let Some ( columns) = self . extract_join_columns ( expr) {
534+ return self . push_columns ( columns, |replacement| {
535+ Expr :: IsNotTrue ( Box :: new ( replacement) )
536+ } ) ;
537+ }
538+ }
539+ Expr :: IsNotFalse ( expr) => {
540+ if let Some ( columns) = self . extract_join_columns ( expr) {
541+ return self . push_columns ( columns, |replacement| {
542+ Expr :: IsNotFalse ( Box :: new ( replacement) )
543+ } ) ;
544+ }
545+ }
546+ Expr :: IsNotUnknown ( expr) => {
547+ if let Some ( columns) = self . extract_join_columns ( expr) {
548+ return self . push_columns ( columns, |replacement| {
549+ Expr :: IsNotUnknown ( Box :: new ( replacement) )
550+ } ) ;
551+ }
552+ }
553+ Expr :: Between ( between) => {
554+ if let Some ( columns) = self . extract_join_columns ( & between. expr ) {
555+ return self . push_columns ( columns, |replacement| {
556+ Expr :: Between ( Between {
557+ expr : Box :: new ( replacement) ,
558+ negated : between. negated ,
559+ low : between. low . clone ( ) ,
560+ high : between. high . clone ( ) ,
561+ } )
562+ } ) ;
563+ }
564+ }
565+ Expr :: InList ( in_list) => {
566+ if let Some ( columns) = self . extract_join_columns ( & in_list. expr ) {
567+ return self . push_columns ( columns, |replacement| {
568+ Expr :: InList ( InList {
569+ expr : Box :: new ( replacement) ,
570+ list : in_list. list . clone ( ) ,
571+ negated : in_list. negated ,
572+ } )
573+ } ) ;
574+ }
575+ }
576+ _ => { }
577+ }
578+ self . remaining_filters . push ( term. clone ( ) ) ;
579+ }
580+
581+ fn push_predicate (
582+ mut self ,
583+ predicate : Expr ,
584+ ) -> Result < ( Option < Expr > , Option < Expr > , Vec < Expr > ) > {
585+ let predicates = split_conjunction_owned ( predicate) ;
586+ let terms = simplify_predicates ( predicates) ?;
587+ for term in terms {
588+ self . push_term ( & term) ;
589+ }
590+ Ok ( (
591+ conjunction ( self . left_filters ) ,
592+ conjunction ( self . right_filters ) ,
593+ self . remaining_filters ,
594+ ) )
595+ }
596+ }
597+
598+ fn push_full_join_coalesce_filters (
599+ join : & mut Join ,
600+ predicate : Expr ,
601+ ) -> Result < Option < Vec < Expr > > > {
602+ let ( Some ( left) , Some ( right) , remaining) =
603+ PushDownCoalesceFilterHelper :: new ( & join. on ) . push_predicate ( predicate) ?
604+ else {
605+ return Ok ( None ) ;
606+ } ;
607+
608+ let left_input = Arc :: clone ( & join. left ) ;
609+ join. left = Arc :: new ( make_filter ( left, left_input) ?) ;
610+
611+ let right_input = Arc :: clone ( & join. right ) ;
612+ join. right = Arc :: new ( make_filter ( right, right_input) ?) ;
613+
614+ Ok ( Some ( remaining) )
615+ }
616+
421617/// push down join/cross-join
422618fn push_down_all_join (
423619 predicates : Vec < Expr > ,
@@ -527,13 +723,21 @@ fn push_down_all_join(
527723}
528724
529725fn push_down_join (
530- join : Join ,
726+ mut join : Join ,
531727 parent_predicate : Option < & Expr > ,
532728) -> Result < Transformed < LogicalPlan > > {
533729 // Split the parent predicate into individual conjunctive parts.
534- let predicates = parent_predicate
730+ let mut predicates = parent_predicate
535731 . map_or_else ( Vec :: new, |pred| split_conjunction_owned ( pred. clone ( ) ) ) ;
536732
733+ if let Some ( parent_predicate) = parent_predicate {
734+ if let Some ( remaining_predicates) =
735+ push_full_join_coalesce_filters ( & mut join, parent_predicate. clone ( ) ) ?
736+ {
737+ predicates = remaining_predicates;
738+ }
739+ }
740+
537741 // Extract conjunctions from the JOIN's ON filter, if present.
538742 let on_filters = join
539743 . filter
@@ -1447,6 +1651,7 @@ mod tests {
14471651 use crate :: test:: * ;
14481652 use crate :: OptimizerContext ;
14491653 use datafusion_expr:: test:: function_stub:: sum;
1654+ use datafusion_functions:: core:: expr_fn:: coalesce;
14501655 use insta:: assert_snapshot;
14511656
14521657 use super :: * ;
@@ -2848,6 +3053,36 @@ mod tests {
28483053 )
28493054 }
28503055
/// A filter on `coalesce` of the FULL JOIN keys is pushed into both inputs.
#[test]
fn filter_full_join_on_coalesce() -> Result<()> {
    let left_scan = test_table_scan_with_name("t1")?;
    let right_scan = test_table_scan_with_name("t2")?;

    // coalesce(t1.a, t2.a) = 1
    let key_filter = coalesce(vec![col("t1.a"), col("t2.a")]).eq(lit(1i32));

    let plan = LogicalPlanBuilder::from(left_scan)
        .join(right_scan, JoinType::Full, (vec!["a"], vec!["a"]), None)?
        .filter(key_filter)?
        .build()?;

    // not part of the test, just good to know:
    assert_snapshot!(plan,
    @r"
    Filter: coalesce(t1.a, t2.a) = Int32(1)
      Full Join: t1.a = t2.a
        TableScan: t1
        TableScan: t2
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Full Join: t1.a = t2.a
      TableScan: t1, full_filters=[t1.a = Int32(1)]
      TableScan: t2, full_filters=[t2.a = Int32(1)]
    "
    )
}
3085+
28513086 /// join filter should be completely removed after pushdown
28523087 #[ test]
28533088 fn join_filter_removed ( ) -> Result < ( ) > {
0 commit comments