Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 96 additions & 5 deletions src/query/parser.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::{anyhow, Result};
use sha3::{Digest, Keccak256};
use sqlparser::ast::{visit_expressions, BinaryOperator, Expr, Value};
use sqlparser::ast::{visit_expressions, BinaryOperator, Expr, SetExpr, Statement, Value};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -576,6 +576,10 @@ fn extract_ident_from_expr(expr: &Expr, columns: &mut HashSet<String>) {
/// Returns SQL fragments like `block_num >= 100`, `address = '0x...'` etc.
/// Only extracts simple comparisons (=, >=, <=, >, <) and IN lists on known
/// raw columns. Decoded event columns are NOT extracted.
///
/// IMPORTANT: Only top-level AND conjuncts are extracted. Predicates inside OR,
/// CASE, or other complex expressions are NOT pushed down, because converting
/// `WHERE a OR b` into `WHERE a AND b` would silently change query semantics.
pub fn extract_raw_column_predicates(sql: &str) -> Vec<String> {
let mut predicates = Vec::new();

Expand All @@ -585,15 +589,50 @@ pub fn extract_raw_column_predicates(sql: &str) -> Vec<String> {
};

for stmt in &statements {
let _ = visit_expressions(stmt, |expr| {
extract_raw_predicate(expr, &mut predicates);
ControlFlow::<()>::Continue(())
});
if let Some(where_expr) = extract_where_clause(stmt) {
collect_and_conjuncts(where_expr, &mut predicates);
}
}

predicates
}

/// Borrow the `WHERE` expression of a plain `SELECT` statement, if there is one.
///
/// Returns `Some` only when `stmt` is a `Statement::Query` whose body is
/// directly a `SetExpr::Select` carrying a selection. Every other statement
/// kind, every other query body shape, and a `SELECT` without a `WHERE`
/// clause all yield `None`.
fn extract_where_clause(stmt: &Statement) -> Option<&Expr> {
    match stmt {
        Statement::Query(query) => match query.body.as_ref() {
            SetExpr::Select(select) => select.selection.as_ref(),
            // Any non-SELECT body is left alone rather than guessed at.
            _ => None,
        },
        _ => None,
    }
}

/// Split `expr` into its top-level `AND` conjuncts and offer each one to
/// `extract_raw_predicate`.
///
/// Only `AND` nodes and parenthesised groups (`Expr::Nested`) are descended
/// into. Anything else (an `OR`, a `CASE`, a comparison, ...) is treated as a
/// single leaf and handed over as-is, so a disjunction is never broken apart
/// and re-joined as a conjunction.
fn collect_and_conjuncts(expr: &Expr, predicates: &mut Vec<String>) {
    // Explicit work stack instead of recursion. Operands are pushed
    // right-then-left so leaves are processed in left-to-right order.
    let mut pending: Vec<&Expr> = vec![expr];

    while let Some(node) = pending.pop() {
        match node {
            Expr::BinaryOp {
                left,
                op: BinaryOperator::And,
                right,
            } => {
                pending.push(right.as_ref());
                pending.push(left.as_ref());
            }
            // Parentheses do not change meaning; look through them.
            Expr::Nested(inner) => pending.push(inner.as_ref()),
            leaf => extract_raw_predicate(leaf, predicates),
        }
    }
}

/// Extract a single raw-column predicate from an expression.
fn extract_raw_predicate(expr: &Expr, predicates: &mut Vec<String>) {
match expr {
Expand Down Expand Up @@ -1532,4 +1571,56 @@ mod tests {
assert_eq!(without, with);
}

// ========================================================================
// Pushdown Safety: OR, CASE, and mixed expressions
// ========================================================================

#[test]
fn test_or_predicates_not_pushed_down() {
    // The whole WHERE clause is a single OR, so no fragment may be extracted.
    let sql = "SELECT * FROM Transfer WHERE block_num = 1 OR block_num = 2";
    let preds = extract_raw_column_predicates(sql);

    assert!(preds.is_empty(), "OR predicates must not be pushed down, got: {preds:?}");
}

#[test]
fn test_case_predicates_not_pushed_down() {
    // The compared value is a CASE expression, not a literal; the comparison
    // must not be extracted.
    let sql = "SELECT * FROM Transfer WHERE block_num = CASE WHEN 1=1 THEN 100 ELSE 200 END";
    let preds = extract_raw_column_predicates(sql);

    assert!(preds.is_empty(), "CASE predicates must not be pushed down, got: {preds:?}");
}

#[test]
fn test_simple_and_predicates_pushed_down() {
    // Three simple comparisons joined purely by AND: all three are expected.
    let sql = "SELECT * FROM Transfer WHERE block_num >= 100 AND block_num <= 200 AND address = '0xABC'";
    let preds = extract_raw_column_predicates(sql);

    assert_eq!(preds.len(), 3);
    for expected in ["block_num >= 100", "block_num <= 200", "address = '0xABC'"] {
        assert!(preds.contains(&expected.to_string()));
    }
}

#[test]
fn test_mixed_and_or_only_pushes_safe_conjuncts() {
    // Shape: `block_num >= 100 AND (address = '0xA' OR address = '0xB')`.
    // The left conjunct is a plain comparison and may be extracted; the
    // parenthesised OR on the right must be left out entirely.
    let sql = "SELECT * FROM Transfer WHERE block_num >= 100 AND (address = '0xA' OR address = '0xB')";
    let preds = extract_raw_column_predicates(sql);

    assert_eq!(preds.len(), 1, "Only safe AND conjuncts should be pushed, got: {preds:?}");
    let expected = "block_num >= 100".to_string();
    assert!(preds.contains(&expected));
}

#[test]
fn test_nested_or_inside_and_not_pushed_down() {
    // Mirror of the mixed AND/OR case with the parenthesised OR on the left:
    // only the plain comparison on the right may be extracted.
    let sql = "SELECT * FROM Transfer WHERE (block_num = 1 OR block_num = 2) AND address = '0xABC'";
    let preds = extract_raw_column_predicates(sql);

    assert_eq!(preds.len(), 1, "Only simple AND conjuncts should be pushed, got: {preds:?}");
    let expected = "address = '0xABC'".to_string();
    assert!(preds.contains(&expected));
}

}
Loading