Skip to content

Commit ed3e3f3

Browse files
committed
Support push-down of the filter on coalesce over join keys
1 parent 7fa2a69 commit ed3e3f3

File tree

3 files changed

+240
-3
lines changed

3 files changed

+240
-3
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/optimizer/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ regex-syntax = "0.8.6"
6161
async-trait = { workspace = true }
6262
criterion = { workspace = true }
6363
ctor = { workspace = true }
64+
datafusion-functions = { workspace = true }
6465
datafusion-functions-aggregate = { workspace = true }
6566
datafusion-functions-window = { workspace = true }
6667
datafusion-functions-window-common = { workspace = true }

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 238 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use datafusion_common::{
3131
assert_eq_or_internal_err, assert_or_internal_err, internal_err, plan_err,
3232
qualified_name, Column, DFSchema, DataFusionError, Result,
3333
};
34-
use datafusion_expr::expr::WindowFunction;
34+
use datafusion_expr::expr::{Between, InList, ScalarFunction, WindowFunction};
3535
use datafusion_expr::expr_rewriter::replace_col;
3636
use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
3737
use datafusion_expr::utils::{
@@ -418,6 +418,202 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
418418
predicate
419419
}
420420

421+
/// Tracks coalesce predicates that can be pushed to each side of a FULL JOIN.
422+
struct PushDownCoalesceFilterHelper {
423+
join_keys: Vec<(Column, Column)>,
424+
left_filters: Vec<Expr>,
425+
right_filters: Vec<Expr>,
426+
remaining_filters: Vec<Expr>,
427+
}
428+
429+
impl PushDownCoalesceFilterHelper {
430+
fn new(join_keys: &[(Expr, Expr)]) -> Self {
431+
let join_keys = join_keys
432+
.iter()
433+
.filter_map(|(lhs, rhs)| {
434+
Some((lhs.try_as_col()?.clone(), rhs.try_as_col()?.clone()))
435+
})
436+
.collect();
437+
Self {
438+
join_keys,
439+
left_filters: Vec::new(),
440+
right_filters: Vec::new(),
441+
remaining_filters: Vec::new(),
442+
}
443+
}
444+
445+
fn push_columns<F: FnMut(Expr) -> Expr>(
446+
&mut self,
447+
columns: (Column, Column),
448+
mut build_filter: F,
449+
) {
450+
self.left_filters
451+
.push(build_filter(Expr::Column(columns.0)));
452+
self.right_filters
453+
.push(build_filter(Expr::Column(columns.1)));
454+
}
455+
456+
fn extract_join_columns(&self, expr: &Expr) -> Option<(Column, Column)> {
457+
if let Expr::ScalarFunction(ScalarFunction { func, args }) = expr {
458+
if func.name() != "coalesce" {
459+
return None;
460+
}
461+
if let [Expr::Column(lhs), Expr::Column(rhs)] = args.as_slice() {
462+
for (join_lhs, join_rhs) in &self.join_keys {
463+
if join_lhs == lhs && join_rhs == rhs {
464+
return Some((lhs.clone(), rhs.clone()));
465+
}
466+
if join_lhs == rhs && join_rhs == lhs {
467+
return Some((rhs.clone(), lhs.clone()));
468+
}
469+
}
470+
}
471+
}
472+
None
473+
}
474+
475+
fn push_term(&mut self, term: &Expr) {
476+
match term {
477+
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
478+
if let Some(columns) = self.extract_join_columns(left) {
479+
return self.push_columns(columns, |replacement| {
480+
Expr::BinaryExpr(BinaryExpr {
481+
left: Box::new(replacement),
482+
op: *op,
483+
right: right.clone(),
484+
})
485+
});
486+
}
487+
if let Some(columns) = self.extract_join_columns(right) {
488+
return self.push_columns(columns, |replacement| {
489+
Expr::BinaryExpr(BinaryExpr {
490+
left: left.clone(),
491+
op: *op,
492+
right: Box::new(replacement),
493+
})
494+
});
495+
}
496+
}
497+
Expr::IsNull(expr) => {
498+
if let Some(columns) = self.extract_join_columns(expr) {
499+
return self.push_columns(columns, |replacement| {
500+
Expr::IsNull(Box::new(replacement))
501+
});
502+
}
503+
}
504+
Expr::IsNotNull(expr) => {
505+
if let Some(columns) = self.extract_join_columns(expr) {
506+
return self.push_columns(columns, |replacement| {
507+
Expr::IsNotNull(Box::new(replacement))
508+
});
509+
}
510+
}
511+
Expr::IsTrue(expr) => {
512+
if let Some(columns) = self.extract_join_columns(expr) {
513+
return self.push_columns(columns, |replacement| {
514+
Expr::IsTrue(Box::new(replacement))
515+
});
516+
}
517+
}
518+
Expr::IsFalse(expr) => {
519+
if let Some(columns) = self.extract_join_columns(expr) {
520+
return self.push_columns(columns, |replacement| {
521+
Expr::IsFalse(Box::new(replacement))
522+
});
523+
}
524+
}
525+
Expr::IsUnknown(expr) => {
526+
if let Some(columns) = self.extract_join_columns(expr) {
527+
return self.push_columns(columns, |replacement| {
528+
Expr::IsUnknown(Box::new(replacement))
529+
});
530+
}
531+
}
532+
Expr::IsNotTrue(expr) => {
533+
if let Some(columns) = self.extract_join_columns(expr) {
534+
return self.push_columns(columns, |replacement| {
535+
Expr::IsNotTrue(Box::new(replacement))
536+
});
537+
}
538+
}
539+
Expr::IsNotFalse(expr) => {
540+
if let Some(columns) = self.extract_join_columns(expr) {
541+
return self.push_columns(columns, |replacement| {
542+
Expr::IsNotFalse(Box::new(replacement))
543+
});
544+
}
545+
}
546+
Expr::IsNotUnknown(expr) => {
547+
if let Some(columns) = self.extract_join_columns(expr) {
548+
return self.push_columns(columns, |replacement| {
549+
Expr::IsNotUnknown(Box::new(replacement))
550+
});
551+
}
552+
}
553+
Expr::Between(between) => {
554+
if let Some(columns) = self.extract_join_columns(&between.expr) {
555+
return self.push_columns(columns, |replacement| {
556+
Expr::Between(Between {
557+
expr: Box::new(replacement),
558+
negated: between.negated,
559+
low: between.low.clone(),
560+
high: between.high.clone(),
561+
})
562+
});
563+
}
564+
}
565+
Expr::InList(in_list) => {
566+
if let Some(columns) = self.extract_join_columns(&in_list.expr) {
567+
return self.push_columns(columns, |replacement| {
568+
Expr::InList(InList {
569+
expr: Box::new(replacement),
570+
list: in_list.list.clone(),
571+
negated: in_list.negated,
572+
})
573+
});
574+
}
575+
}
576+
_ => {}
577+
}
578+
self.remaining_filters.push(term.clone());
579+
}
580+
581+
fn push_predicate(
582+
mut self,
583+
predicate: Expr,
584+
) -> Result<(Option<Expr>, Option<Expr>, Vec<Expr>)> {
585+
let predicates = split_conjunction_owned(predicate);
586+
let terms = simplify_predicates(predicates)?;
587+
for term in terms {
588+
self.push_term(&term);
589+
}
590+
Ok((
591+
conjunction(self.left_filters),
592+
conjunction(self.right_filters),
593+
self.remaining_filters,
594+
))
595+
}
596+
}
597+
598+
fn push_full_join_coalesce_filters(
599+
join: &mut Join,
600+
predicate: Expr,
601+
) -> Result<Option<Vec<Expr>>> {
602+
let (Some(left), Some(right), remaining) =
603+
PushDownCoalesceFilterHelper::new(&join.on).push_predicate(predicate)?
604+
else {
605+
return Ok(None);
606+
};
607+
608+
let left_input = Arc::clone(&join.left);
609+
join.left = Arc::new(make_filter(left, left_input)?);
610+
611+
let right_input = Arc::clone(&join.right);
612+
join.right = Arc::new(make_filter(right, right_input)?);
613+
614+
Ok(Some(remaining))
615+
}
616+
421617
/// push down join/cross-join
422618
fn push_down_all_join(
423619
predicates: Vec<Expr>,
@@ -527,13 +723,21 @@ fn push_down_all_join(
527723
}
528724

529725
fn push_down_join(
530-
join: Join,
726+
mut join: Join,
531727
parent_predicate: Option<&Expr>,
532728
) -> Result<Transformed<LogicalPlan>> {
533729
// Split the parent predicate into individual conjunctive parts.
534-
let predicates = parent_predicate
730+
let mut predicates = parent_predicate
535731
.map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
536732

733+
if let Some(parent_predicate) = parent_predicate {
734+
if let Some(remaining_predicates) =
735+
push_full_join_coalesce_filters(&mut join, parent_predicate.clone())?
736+
{
737+
predicates = remaining_predicates;
738+
}
739+
}
740+
537741
// Extract conjunctions from the JOIN's ON filter, if present.
538742
let on_filters = join
539743
.filter
@@ -1447,6 +1651,7 @@ mod tests {
14471651
use crate::test::*;
14481652
use crate::OptimizerContext;
14491653
use datafusion_expr::test::function_stub::sum;
1654+
use datafusion_functions::core::expr_fn::coalesce;
14501655
use insta::assert_snapshot;
14511656

14521657
use super::*;
@@ -2848,6 +3053,36 @@ mod tests {
28483053
)
28493054
}
28503055

3056+
/// Filter on coalesce of join keys should be pushed to both join inputs
3057+
#[test]
3058+
fn filter_full_join_on_coalesce() -> Result<()> {
3059+
let table_scan_t1 = test_table_scan_with_name("t1")?;
3060+
let table_scan_t2 = test_table_scan_with_name("t2")?;
3061+
3062+
let plan = LogicalPlanBuilder::from(table_scan_t1)
3063+
.join(table_scan_t2, JoinType::Full, (vec!["a"], vec!["a"]), None)?
3064+
.filter(coalesce(vec![col("t1.a"), col("t2.a")]).eq(lit(1i32)))?
3065+
.build()?;
3066+
3067+
// not part of the test, just good to know:
3068+
assert_snapshot!(plan,
3069+
@r"
3070+
Filter: coalesce(t1.a, t2.a) = Int32(1)
3071+
Full Join: t1.a = t2.a
3072+
TableScan: t1
3073+
TableScan: t2
3074+
",
3075+
);
3076+
assert_optimized_plan_equal!(
3077+
plan,
3078+
@r"
3079+
Full Join: t1.a = t2.a
3080+
TableScan: t1, full_filters=[t1.a = Int32(1)]
3081+
TableScan: t2, full_filters=[t2.a = Int32(1)]
3082+
"
3083+
)
3084+
}
3085+
28513086
/// join filter should be completely removed after pushdown
28523087
#[test]
28533088
fn join_filter_removed() -> Result<()> {

0 commit comments

Comments
 (0)