Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 165 additions & 6 deletions src/query/validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::collections::HashSet;

use anyhow::{anyhow, Result};
use sqlparser::ast::{
Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, ObjectName, Query, SetExpr,
Statement, TableFactor, TableWithJoins,
Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArguments,
GroupByWithModifier, ObjectName, Query, SetExpr, Statement, TableFactor, TableWithJoins,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
Expand Down Expand Up @@ -207,10 +207,16 @@ fn validate_set_expr(
validate_expr(selection, cte_names, depth)?;
}

// Validate GROUP BY expressions
if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &select.group_by {
for expr in exprs {
validate_expr(expr, cte_names, depth)?;
// Validate GROUP BY expressions and modifiers (GROUPING SETS, ROLLUP, CUBE)
match &select.group_by {
sqlparser::ast::GroupByExpr::Expressions(exprs, modifiers) => {
for expr in exprs {
validate_expr(expr, cte_names, depth)?;
}
validate_group_by_modifiers(modifiers, cte_names, depth)?;
}
sqlparser::ast::GroupByExpr::All(modifiers) => {
validate_group_by_modifiers(modifiers, cte_names, depth)?;
}
}

Expand Down Expand Up @@ -610,13 +616,76 @@ fn is_allowed_function(name: &str) -> bool {
ALLOWED_FUNCTIONS.contains(&bare_name)
}

/// Validate the expressions carried by GROUP BY modifiers.
///
/// Only `GroupingSets` modifiers hold an expression that is handed to `validate_expr`;
/// every other modifier is passed over. Iteration stops at the first validation error,
/// which is returned to the caller.
fn validate_group_by_modifiers(
    modifiers: &[GroupByWithModifier],
    cte_names: &HashSet<String>,
    depth: usize,
) -> Result<()> {
    modifiers.iter().try_for_each(|modifier| match modifier {
        GroupByWithModifier::GroupingSets(grouping_expr) => {
            validate_expr(grouping_expr, cte_names, depth)
        }
        _ => Ok(()),
    })
}

/// Maximum allowed length argument for string amplification functions (lpad, rpad, repeat).
///
/// Declared as `i64` so it can be compared directly against a numeric literal parsed
/// from the query text.
const MAX_STRING_PAD_LENGTH: i64 = 100_000;

/// Functions whose first numeric argument (length/count) must be capped to prevent
/// memory exhaustion via string amplification.
///
/// Entries are lowercase bare function names, without any schema qualifier.
const STRING_AMPLIFICATION_FUNCTIONS: &[&str] = &["lpad", "rpad", "repeat"];

/// Reject calls to string amplification functions (`lpad`, `rpad`, `repeat`) whose
/// literal length/count argument exceeds `MAX_STRING_PAD_LENGTH`.
///
/// `func_name` may be schema-qualified; only the segment after the last `.` is matched
/// against `STRING_AMPLIFICATION_FUNCTIONS`. Calls to other functions, calls without an
/// argument list, and calls whose second argument is not a numeric literal are accepted.
///
/// # Errors
///
/// Returns an error when the second argument is a numeric literal that cannot be shown
/// to be within `MAX_STRING_PAD_LENGTH`. This includes literals too large for `i64`
/// (e.g. `99999999999999999999`) and decimal/exponent forms (e.g. `1e9`), which
/// previously slipped past the check because they failed integer parsing.
///
/// NOTE(review): only literal arguments are inspected. Computed lengths such as
/// `999999999 + 0`, casts, or nested amplification calls are not evaluated here.
fn validate_string_amplification(func_name: &str, func: &Function) -> Result<()> {
    // lpad(string, length, [fill]), rpad(string, length, [fill]) and repeat(string, count)
    // all take their size as the second positional argument.
    const LENGTH_ARG_INDEX: usize = 1;

    let bare_name = func_name.rsplit('.').next().unwrap_or(func_name);
    if !STRING_AMPLIFICATION_FUNCTIONS.contains(&bare_name) {
        return Ok(());
    }

    let FunctionArguments::List(arg_list) = &func.args else {
        return Ok(());
    };
    let Some(arg) = arg_list.args.get(LENGTH_ARG_INDEX) else {
        return Ok(());
    };

    if let FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(v)))
    | FunctionArg::Named {
        arg: FunctionArgExpr::Expr(Expr::Value(v)),
        ..
    } = arg
    {
        if let sqlparser::ast::Value::Number(n, _) = &v.value {
            let within_limit = match n.parse::<i64>() {
                Ok(num) => num <= MAX_STRING_PAD_LENGTH,
                // Not an i64: either the integer overflows or the literal uses a
                // decimal/exponent form. Compare as f64 and fail closed when the value
                // is too large or cannot be parsed at all. The cast is exact for the
                // small limit used here.
                Err(_) => matches!(
                    n.parse::<f64>(),
                    Ok(num) if num <= MAX_STRING_PAD_LENGTH as f64
                ),
            };

            if !within_limit {
                return Err(anyhow!(
                    "Function '{bare_name}' length argument {n} exceeds maximum ({MAX_STRING_PAD_LENGTH})"
                ));
            }
        }
    }

    Ok(())
}

fn validate_function(func: &Function, cte_names: &HashSet<String>, depth: usize) -> Result<()> {
let func_name = func.name.to_string().to_lowercase();

if !is_allowed_function(&func_name) {
return Err(anyhow!("Function '{}' is not allowed", func_name));
}

// Check string amplification functions for excessive length arguments
validate_string_amplification(&func_name, func)?;

if let FunctionArguments::List(arg_list) = &func.args {
for arg in &arg_list.args {
if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))
Expand All @@ -628,6 +697,15 @@ fn validate_function(func: &Function, cte_names: &HashSet<String>, depth: usize)
validate_expr(expr, cte_names, depth)?;
}
}

// Validate expressions inside function argument clauses (e.g. ORDER BY within aggregates)
for clause in &arg_list.clauses {
if let FunctionArgumentClause::OrderBy(order_exprs) = clause {
for order_expr in order_exprs {
validate_expr(&order_expr.expr, cte_names, depth)?;
}
}
}
}

// Validate FILTER (WHERE ...) clause
Expand Down Expand Up @@ -1014,4 +1092,85 @@ mod tests {
validate_query("SELECT * FROM blocks FETCH FIRST 10 ROWS ONLY").is_err()
);
}

// === Audit finding: GROUPING SETS bypasses function allowlist ===

#[test]
fn test_rejects_set_config_in_grouping_sets() {
    // set_config(...) placed inside GROUPING SETS after GROUP BY ALL must fail validation.
    let sql = "SELECT 1 FROM blocks GROUP BY ALL GROUPING SETS ((set_config('a','b',true)))";
    let outcome = validate_query(sql);
    assert!(outcome.is_err());
}

#[test]
fn test_rejects_dangerous_function_in_grouping_sets_with_expressions() {
    // pg_sleep(...) inside GROUPING SETS must fail even when regular GROUP BY
    // expressions precede the modifier.
    let sql = "SELECT num FROM blocks GROUP BY num GROUPING SETS ((pg_sleep(1)))";
    let outcome = validate_query(sql);
    assert!(outcome.is_err());
}

#[test]
fn test_allows_normal_group_by() {
    // Plain GROUP BY forms, with an explicit column and with ALL, must still validate.
    let accepted = [
        "SELECT num, COUNT(*) FROM blocks GROUP BY num",
        "SELECT num, COUNT(*) FROM blocks GROUP BY ALL",
    ];
    for sql in accepted {
        assert!(validate_query(sql).is_ok());
    }
}

// === Audit finding: Aggregate ORDER BY bypasses function allowlist ===

#[test]
fn test_rejects_set_config_in_aggregate_order_by() {
    // set_config(...) used as the ORDER BY key inside an aggregate call must fail validation.
    let sql = "SELECT string_agg(hash, ',' ORDER BY set_config('a','b',true)) FROM blocks";
    let outcome = validate_query(sql);
    assert!(outcome.is_err());
}

#[test]
fn test_rejects_pg_sleep_in_aggregate_order_by() {
    // pg_sleep(...) used as the ORDER BY key inside array_agg must fail validation.
    let sql = "SELECT array_agg(num ORDER BY pg_sleep(1)) FROM blocks";
    let outcome = validate_query(sql);
    assert!(outcome.is_err());
}

#[test]
fn test_allows_normal_aggregate_with_order_by() {
    // Aggregates ordered by a plain column (ascending or DESC) must still validate.
    let accepted = [
        "SELECT string_agg(hash, ',' ORDER BY num) FROM blocks",
        "SELECT array_agg(num ORDER BY num DESC) FROM blocks",
    ];
    for sql in accepted {
        assert!(validate_query(sql).is_ok());
    }
}

// === Audit finding: lpad/rpad/repeat memory exhaustion ===

#[test]
fn test_rejects_lpad_huge_length() {
    // An lpad length literal of 999999999 must fail validation.
    let outcome = validate_query("SELECT lpad('x', 999999999) FROM blocks");
    assert!(outcome.is_err());
}

#[test]
fn test_rejects_rpad_huge_length() {
    // An rpad length literal of 999999999 must fail validation.
    let outcome = validate_query("SELECT rpad('x', 999999999) FROM blocks");
    assert!(outcome.is_err());
}

#[test]
fn test_rejects_repeat_huge_count() {
    // A repeat count literal of 999999999 must fail validation.
    let outcome = validate_query("SELECT repeat('x', 999999999) FROM blocks");
    assert!(outcome.is_err());
}

#[test]
fn test_allows_lpad_rpad_small_length() {
    // Small length/count literals for lpad, rpad and repeat must still validate.
    let accepted = [
        "SELECT lpad(hash, 66, '0') FROM blocks",
        "SELECT rpad(hash, 66, '0') FROM blocks",
        "SELECT repeat('0', 10) FROM blocks",
    ];
    for sql in accepted {
        assert!(validate_query(sql).is_ok());
    }
}
}
Loading