Skip to content

Commit 875a9b4

Browse files
committed
Support push-down of the filter on coalesce over join keys
1 parent 7fa2a69 commit 875a9b4

File tree

3 files changed

+242
-3
lines changed

3 files changed

+242
-3
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/optimizer/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ regex-syntax = "0.8.6"
6161
async-trait = { workspace = true }
6262
criterion = { workspace = true }
6363
ctor = { workspace = true }
64+
datafusion-functions = { workspace = true }
6465
datafusion-functions-aggregate = { workspace = true }
6566
datafusion-functions-window = { workspace = true }
6667
datafusion-functions-window-common = { workspace = true }

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 240 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use datafusion_common::{
3131
assert_eq_or_internal_err, assert_or_internal_err, internal_err, plan_err,
3232
qualified_name, Column, DFSchema, DataFusionError, Result,
3333
};
34-
use datafusion_expr::expr::WindowFunction;
34+
use datafusion_expr::expr::{Between, InList, ScalarFunction, WindowFunction};
3535
use datafusion_expr::expr_rewriter::replace_col;
3636
use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
3737
use datafusion_expr::utils::{
@@ -418,6 +418,204 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
418418
predicate
419419
}
420420

421+
/// Tracks coalesce predicates that can be pushed to each side of a FULL JOIN.
422+
struct PushDownCoalesceFilterHelper {
423+
join_keys: Vec<(Column, Column)>,
424+
left_filters: Vec<Expr>,
425+
right_filters: Vec<Expr>,
426+
remaining_filters: Vec<Expr>,
427+
}
428+
429+
impl PushDownCoalesceFilterHelper {
430+
fn new(join_keys: &[(Expr, Expr)]) -> Self {
431+
let join_keys = join_keys
432+
.iter()
433+
.filter_map(|(lhs, rhs)| {
434+
Some((lhs.try_as_col()?.clone(), rhs.try_as_col()?.clone()))
435+
})
436+
.collect();
437+
Self {
438+
join_keys,
439+
left_filters: Vec::new(),
440+
right_filters: Vec::new(),
441+
remaining_filters: Vec::new(),
442+
}
443+
}
444+
445+
fn push_columns<F: FnMut(Expr) -> Expr>(
446+
&mut self,
447+
columns: (Column, Column),
448+
mut build_filter: F,
449+
) {
450+
self.left_filters
451+
.push(build_filter(Expr::Column(columns.0)));
452+
self.right_filters
453+
.push(build_filter(Expr::Column(columns.1)));
454+
}
455+
456+
fn extract_join_columns(&self, expr: &Expr) -> Option<(Column, Column)> {
457+
if let Expr::ScalarFunction(ScalarFunction { func, args }) = expr {
458+
if func.name() != "coalesce" {
459+
return None;
460+
}
461+
if let [Expr::Column(lhs), Expr::Column(rhs)] = args.as_slice() {
462+
for (join_lhs, join_rhs) in &self.join_keys {
463+
if join_lhs == lhs && join_rhs == rhs {
464+
return Some((lhs.clone(), rhs.clone()));
465+
}
466+
if join_lhs == rhs && join_rhs == lhs {
467+
return Some((rhs.clone(), lhs.clone()));
468+
}
469+
}
470+
}
471+
}
472+
None
473+
}
474+
475+
fn push_term(&mut self, term: &Expr) {
476+
match term {
477+
Expr::BinaryExpr(BinaryExpr { left, op, right })
478+
if op.supports_propagation() =>
479+
{
480+
if let Some(columns) = self.extract_join_columns(left) {
481+
return self.push_columns(columns, |replacement| {
482+
Expr::BinaryExpr(BinaryExpr {
483+
left: Box::new(replacement),
484+
op: *op,
485+
right: right.clone(),
486+
})
487+
});
488+
}
489+
if let Some(columns) = self.extract_join_columns(right) {
490+
return self.push_columns(columns, |replacement| {
491+
Expr::BinaryExpr(BinaryExpr {
492+
left: left.clone(),
493+
op: *op,
494+
right: Box::new(replacement),
495+
})
496+
});
497+
}
498+
}
499+
Expr::IsNull(expr) => {
500+
if let Some(columns) = self.extract_join_columns(expr) {
501+
return self.push_columns(columns, |replacement| {
502+
Expr::IsNull(Box::new(replacement))
503+
});
504+
}
505+
}
506+
Expr::IsNotNull(expr) => {
507+
if let Some(columns) = self.extract_join_columns(expr) {
508+
return self.push_columns(columns, |replacement| {
509+
Expr::IsNotNull(Box::new(replacement))
510+
});
511+
}
512+
}
513+
Expr::IsTrue(expr) => {
514+
if let Some(columns) = self.extract_join_columns(expr) {
515+
return self.push_columns(columns, |replacement| {
516+
Expr::IsTrue(Box::new(replacement))
517+
});
518+
}
519+
}
520+
Expr::IsFalse(expr) => {
521+
if let Some(columns) = self.extract_join_columns(expr) {
522+
return self.push_columns(columns, |replacement| {
523+
Expr::IsFalse(Box::new(replacement))
524+
});
525+
}
526+
}
527+
Expr::IsUnknown(expr) => {
528+
if let Some(columns) = self.extract_join_columns(expr) {
529+
return self.push_columns(columns, |replacement| {
530+
Expr::IsUnknown(Box::new(replacement))
531+
});
532+
}
533+
}
534+
Expr::IsNotTrue(expr) => {
535+
if let Some(columns) = self.extract_join_columns(expr) {
536+
return self.push_columns(columns, |replacement| {
537+
Expr::IsNotTrue(Box::new(replacement))
538+
});
539+
}
540+
}
541+
Expr::IsNotFalse(expr) => {
542+
if let Some(columns) = self.extract_join_columns(expr) {
543+
return self.push_columns(columns, |replacement| {
544+
Expr::IsNotFalse(Box::new(replacement))
545+
});
546+
}
547+
}
548+
Expr::IsNotUnknown(expr) => {
549+
if let Some(columns) = self.extract_join_columns(expr) {
550+
return self.push_columns(columns, |replacement| {
551+
Expr::IsNotUnknown(Box::new(replacement))
552+
});
553+
}
554+
}
555+
Expr::Between(between) => {
556+
if let Some(columns) = self.extract_join_columns(&between.expr) {
557+
return self.push_columns(columns, |replacement| {
558+
Expr::Between(Between {
559+
expr: Box::new(replacement),
560+
negated: between.negated,
561+
low: between.low.clone(),
562+
high: between.high.clone(),
563+
})
564+
});
565+
}
566+
}
567+
Expr::InList(in_list) => {
568+
if let Some(columns) = self.extract_join_columns(&in_list.expr) {
569+
return self.push_columns(columns, |replacement| {
570+
Expr::InList(InList {
571+
expr: Box::new(replacement),
572+
list: in_list.list.clone(),
573+
negated: in_list.negated,
574+
})
575+
});
576+
}
577+
}
578+
_ => {}
579+
}
580+
self.remaining_filters.push(term.clone());
581+
}
582+
583+
fn push_predicate(
584+
mut self,
585+
predicate: Expr,
586+
) -> Result<(Option<Expr>, Option<Expr>, Vec<Expr>)> {
587+
let predicates = split_conjunction_owned(predicate);
588+
let terms = simplify_predicates(predicates)?;
589+
for term in terms {
590+
self.push_term(&term);
591+
}
592+
Ok((
593+
conjunction(self.left_filters),
594+
conjunction(self.right_filters),
595+
self.remaining_filters,
596+
))
597+
}
598+
}
599+
600+
fn push_full_join_coalesce_filters(
601+
join: &mut Join,
602+
predicate: Expr,
603+
) -> Result<Option<Vec<Expr>>> {
604+
let (Some(left), Some(right), remaining) =
605+
PushDownCoalesceFilterHelper::new(&join.on).push_predicate(predicate)?
606+
else {
607+
return Ok(None);
608+
};
609+
610+
let left_input = Arc::clone(&join.left);
611+
join.left = Arc::new(make_filter(left, left_input)?);
612+
613+
let right_input = Arc::clone(&join.right);
614+
join.right = Arc::new(make_filter(right, right_input)?);
615+
616+
Ok(Some(remaining))
617+
}
618+
421619
/// push down join/cross-join
422620
fn push_down_all_join(
423621
predicates: Vec<Expr>,
@@ -527,13 +725,21 @@ fn push_down_all_join(
527725
}
528726

529727
fn push_down_join(
530-
join: Join,
728+
mut join: Join,
531729
parent_predicate: Option<&Expr>,
532730
) -> Result<Transformed<LogicalPlan>> {
533731
// Split the parent predicate into individual conjunctive parts.
534-
let predicates = parent_predicate
732+
let mut predicates = parent_predicate
535733
.map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
536734

735+
if let Some(parent_predicate) = parent_predicate {
736+
if let Some(remaining_predicates) =
737+
push_full_join_coalesce_filters(&mut join, parent_predicate.clone())?
738+
{
739+
predicates = remaining_predicates;
740+
}
741+
}
742+
537743
// Extract conjunctions from the JOIN's ON filter, if present.
538744
let on_filters = join
539745
.filter
@@ -1447,6 +1653,7 @@ mod tests {
14471653
use crate::test::*;
14481654
use crate::OptimizerContext;
14491655
use datafusion_expr::test::function_stub::sum;
1656+
use datafusion_functions::core::expr_fn::coalesce;
14501657
use insta::assert_snapshot;
14511658

14521659
use super::*;
@@ -2848,6 +3055,36 @@ mod tests {
28483055
)
28493056
}
28503057

3058+
/// Filter on coalesce of join keys should be pushed to both join inputs
3059+
#[test]
3060+
fn filter_full_join_on_coalesce() -> Result<()> {
3061+
let table_scan_t1 = test_table_scan_with_name("t1")?;
3062+
let table_scan_t2 = test_table_scan_with_name("t2")?;
3063+
3064+
let plan = LogicalPlanBuilder::from(table_scan_t1)
3065+
.join(table_scan_t2, JoinType::Full, (vec!["a"], vec!["a"]), None)?
3066+
.filter(coalesce(vec![col("t1.a"), col("t2.a")]).eq(lit(1i32)))?
3067+
.build()?;
3068+
3069+
// not part of the test, just good to know:
3070+
assert_snapshot!(plan,
3071+
@r"
3072+
Filter: coalesce(t1.a, t2.a) = Int32(1)
3073+
Full Join: t1.a = t2.a
3074+
TableScan: t1
3075+
TableScan: t2
3076+
",
3077+
);
3078+
assert_optimized_plan_equal!(
3079+
plan,
3080+
@r"
3081+
Full Join: t1.a = t2.a
3082+
TableScan: t1, full_filters=[t1.a = Int32(1)]
3083+
TableScan: t2, full_filters=[t2.a = Int32(1)]
3084+
"
3085+
)
3086+
}
3087+
28513088
/// join filter should be completely removed after pushdown
28523089
#[test]
28533090
fn join_filter_removed() -> Result<()> {

0 commit comments

Comments
 (0)