diff --git a/.gitignore b/.gitignore index b06b641..5f03a2e 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ Thumbs.db # Seeds *.dump docker/postgres-pgduckdb/tidx-abi/target/ +.worktrees/ diff --git a/db/api_role.sql b/db/api_role.sql new file mode 100644 index 0000000..8348e22 --- /dev/null +++ b/db/api_role.sql @@ -0,0 +1,36 @@ +-- Create a read-only role for API query connections. +-- The API should connect as this role to provide defense-in-depth +-- against SQL injection, even if the query validator is bypassed. +-- +-- Modeled after golden-axe's uapi role: +-- https://github.com/indexsupply/golden-axe/blob/master/be/src/sql/roles.sql +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'tidx_api') THEN + CREATE ROLE tidx_api WITH LOGIN PASSWORD 'tidx_api' NOSUPERUSER NOCREATEDB NOCREATEROLE; + END IF; +END $$; + +-- Revoke all default privileges first (defense-in-depth) +REVOKE ALL ON ALL TABLES IN SCHEMA public FROM tidx_api; +REVOKE EXECUTE ON ALL FUNCTIONS IN SCHEMA public FROM tidx_api; + +-- Grant read-only access to indexed tables only +GRANT USAGE ON SCHEMA public TO tidx_api; +GRANT SELECT ON blocks, txs, logs, receipts TO tidx_api; + +-- Grant execute only on ABI decode helper functions +GRANT EXECUTE ON FUNCTION abi_uint(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_int(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_address(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_bool(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_bytes(bytea, int) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_string(bytea, int) TO tidx_api; +GRANT EXECUTE ON FUNCTION format_address(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION format_uint(bytea) TO tidx_api; + +-- Resource limits to prevent DoS even if validator is bypassed +ALTER ROLE tidx_api SET statement_timeout = '30s'; +ALTER ROLE tidx_api SET work_mem = '256MB'; +ALTER ROLE tidx_api SET temp_file_limit = '512MB'; +ALTER ROLE tidx_api CONNECTION LIMIT 64; diff --git 
a/src/api/mod.rs b/src/api/mod.rs index 07a05b6..2530017 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -477,7 +477,16 @@ async fn handle_query_live( } else { let catch_up_start = last_block_num + 1; for block_num in catch_up_start..=end { - let filtered_sql = inject_block_filter(&sql, block_num); + let filtered_sql = match inject_block_filter(&sql, block_num) { + Ok(s) => s, + Err(e) => { + yield Ok(SseEvent::default() + .event("error") + .json_data(serde_json::json!({ "ok": false, "error": e.to_string() })) + .unwrap()); + return; + } + }; match crate::service::execute_query_postgres(&pool, &filtered_sql, signature.as_deref(), &options).await { Ok(result) => { yield Ok(SseEvent::default() @@ -516,50 +525,85 @@ async fn handle_query_live( /// Inject a block number filter into SQL query for live streaming. /// Transforms queries to only return data for the specific block. /// Uses 'num' for blocks table, 'block_num' for txs/logs tables. +/// +/// Uses sqlparser AST manipulation to safely add the filter condition, +/// avoiding SQL injection risks from string-based splicing. #[doc(hidden)] -pub fn inject_block_filter(sql: &str, block_num: u64) -> String { - let sql_upper = sql.to_uppercase(); - - // Determine column name based on table being queried - let col = if sql_upper.contains("FROM BLOCKS") || sql_upper.contains("FROM \"BLOCKS\"") { - "num" - } else { - "block_num" +pub fn inject_block_filter(sql: &str, block_num: u64) -> Result<String, ApiError> { + use sqlparser::ast::{ + BinaryOperator, Expr, Ident, SetExpr, Statement, Value, }; - - // Find WHERE clause position - if let Some(where_pos) = sql_upper.find("WHERE") { - // Insert after WHERE - let insert_pos = where_pos + 5; - format!( - "{} {} = {} AND {}", - &sql[..insert_pos], - col, - block_num, - &sql[insert_pos..] - ) - } else if let Some(order_pos) = sql_upper.find("ORDER BY") { - // Insert WHERE before ORDER BY - format!( - "{} WHERE {} = {} {}", - &sql[..order_pos], - col, - block_num, - &sql[order_pos..]
- ) - } else if let Some(limit_pos) = sql_upper.find("LIMIT") { - // Insert WHERE before LIMIT - format!( - "{} WHERE {} = {} {}", - &sql[..limit_pos], - col, - block_num, - &sql[limit_pos..] - ) - } else { - // Append WHERE at end - format!("{sql} WHERE {col} = {block_num}") + use sqlparser::dialect::GenericDialect; + use sqlparser::parser::Parser; + + let dialect = GenericDialect {}; + let mut statements = Parser::parse_sql(&dialect, sql) + .map_err(|e| ApiError::BadRequest(format!("SQL parse error: {e}")))?; + + if statements.len() != 1 { + return Err(ApiError::BadRequest( + "Live mode requires exactly one SQL statement".to_string(), + )); } + + let stmt = &mut statements[0]; + let query = match stmt { + Statement::Query(q) => q, + _ => { + return Err(ApiError::BadRequest( + "Live mode requires a SELECT query".to_string(), + )) + } + }; + + let select = match query.body.as_mut() { + SetExpr::Select(s) => s, + _ => { + return Err(ApiError::BadRequest( + "Live mode requires a simple SELECT query (UNION/INTERSECT not supported)" + .to_string(), + )) + } + }; + + let table_name: String = select + .from + .first() + .and_then(|twj| match &twj.relation { + sqlparser::ast::TableFactor::Table { name, .. 
} => { + name.0.last().and_then(|part| part.as_ident()).map(|ident| ident.value.to_lowercase()) + } + _ => None, + }) + .ok_or_else(|| { + ApiError::BadRequest( + "Live mode requires a query with a FROM table clause".to_string(), + ) + })?; + + let col_name = if table_name == "blocks" { "num" } else { "block_num" }; + + let col_expr = Expr::CompoundIdentifier(vec![ + Ident::new(&table_name), + Ident::new(col_name), + ]); + + let block_filter = Expr::BinaryOp { + left: Box::new(col_expr), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number(block_num.to_string(), false).into())), + }; + + select.selection = Some(match select.selection.take() { + Some(existing) => Expr::BinaryOp { + left: Box::new(Expr::Nested(Box::new(existing))), + op: BinaryOperator::And, + right: Box::new(block_filter), + }, + None => block_filter, + }); + + Ok(stmt.to_string()) } /// Rewrite analytics table references to include chain-specific database prefix. @@ -599,6 +643,19 @@ pub enum ApiError { NotFound(String), } +impl std::fmt::Display for ApiError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ApiError::BadRequest(msg) => write!(f, "{msg}"), + ApiError::Timeout => write!(f, "Query timeout"), + ApiError::QueryError(msg) => write!(f, "{msg}"), + ApiError::Internal(msg) => write!(f, "{msg}"), + ApiError::Forbidden(msg) => write!(f, "{msg}"), + ApiError::NotFound(msg) => write!(f, "{msg}"), + } + } +} + impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { let (status, message) = match self { diff --git a/src/db/schema.rs b/src/db/schema.rs index 1c51511..9cc11af 100644 --- a/src/db/schema.rs +++ b/src/db/schema.rs @@ -39,6 +39,9 @@ pub async fn run_migrations(pool: &Pool) -> Result<()> { // Load any optional extensions conn.batch_execute(include_str!("../../db/extensions.sql")).await?; + // Create read-only API role with SELECT-only access to indexed tables + 
conn.batch_execute(include_str!("../../db/api_role.sql")).await?; + Ok(()) } diff --git a/src/query/validator.rs b/src/query/validator.rs index 87a2a88..373d053 100644 --- a/src/query/validator.rs +++ b/src/query/validator.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use anyhow::{anyhow, Result}; use sqlparser::ast::{ Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, ObjectName, Query, SetExpr, @@ -6,6 +8,13 @@ use sqlparser::ast::{ use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; +const ALLOWED_TABLES: &[&str] = &[ + "blocks", + "txs", + "logs", + "receipts", +]; + /// Validates that a SQL query is safe to execute. /// /// Rejects: @@ -15,6 +24,11 @@ use sqlparser::parser::Parser; /// - Dangerous functions (pg_sleep, read_csv, pg_read_file, etc.) /// - System catalog access pub fn validate_query(sql: &str) -> Result<()> { + const MAX_QUERY_LENGTH: usize = 16_384; + if sql.len() > MAX_QUERY_LENGTH { + return Err(anyhow!("Query too long (max {MAX_QUERY_LENGTH} bytes)")); + } + let dialect = GenericDialect {}; let statements = Parser::parse_sql(&dialect, sql) .map_err(|e| anyhow!("SQL parse error: {e}"))?; @@ -30,107 +44,197 @@ pub fn validate_query(sql: &str) -> Result<()> { let stmt = &statements[0]; match stmt { - Statement::Query(query) => validate_query_ast(query), + Statement::Query(query) => { + let cte_names = extract_cte_names(query); + validate_query_ast(query, &cte_names) + } _ => Err(anyhow!("Only SELECT queries are allowed")), } } -fn validate_query_ast(query: &Query) -> Result<()> { - // Check CTEs for data-modifying statements +fn extract_cte_names(query: &Query) -> HashSet<String> { + let mut names = HashSet::new(); + if let Some(with) = &query.with { + for cte in &with.cte_tables { + names.insert(cte.alias.name.value.to_lowercase()); + } + } + names +} + +fn validate_query_ast(query: &Query, cte_names: &HashSet<String>) -> Result<()> { + // Block recursive CTEs (can cause endless loops / resource exhaustion) + if let
Some(with) = &query.with { + if with.recursive { + return Err(anyhow!("Recursive CTEs are not allowed")); + } + } + + let mut all_cte_names = cte_names.clone(); + if let Some(with) = &query.with { + for cte in &with.cte_tables { + all_cte_names.insert(cte.alias.name.value.to_lowercase()); + } + } + for cte in &query.with.as_ref().map_or(vec![], |w| w.cte_tables.clone()) { - validate_query_ast(&cte.query)?; + validate_query_ast(&cte.query, &all_cte_names)?; } - validate_set_expr(&query.body) + validate_set_expr(&query.body, &all_cte_names)?; + + if let Some(order_by) = &query.order_by { + if let sqlparser::ast::OrderByKind::Expressions(exprs) = &order_by.kind { + for order_expr in exprs { + validate_expr(&order_expr.expr, &all_cte_names)?; + } + } + } + + if let Some(limit_clause) = &query.limit_clause { + match limit_clause { + sqlparser::ast::LimitClause::LimitOffset { limit, offset, limit_by } => { + if let Some(limit) = limit { + validate_expr(limit, &all_cte_names)?; + } + if let Some(offset) = offset { + validate_expr(&offset.value, &all_cte_names)?; + } + for expr in limit_by { + validate_expr(expr, &all_cte_names)?; + } + } + sqlparser::ast::LimitClause::OffsetCommaLimit { offset, limit } => { + validate_expr(offset, &all_cte_names)?; + validate_expr(limit, &all_cte_names)?; + } + } + } + + if !query.locks.is_empty() { + return Err(anyhow!("FOR UPDATE/FOR SHARE is not allowed")); + } + + Ok(()) } -fn validate_set_expr(set_expr: &SetExpr) -> Result<()> { +fn validate_set_expr(set_expr: &SetExpr, cte_names: &HashSet<String>) -> Result<()> { match set_expr { SetExpr::Select(select) => { - // Validate FROM clause + // Reject SELECT INTO (creates objects) + if select.into.is_some() { + return Err(anyhow!("SELECT INTO is not allowed")); + } + for table in &select.from { - validate_table_with_joins(table)?; + validate_table_with_joins(table, cte_names)?; } - // Validate SELECT expressions for item in &select.projection { if let sqlparser::ast::SelectItem::UnnamedExpr(expr)
| sqlparser::ast::SelectItem::ExprWithAlias { expr, .. } = item { - validate_expr(expr)?; + validate_expr(expr, cte_names)?; } } - // Validate WHERE clause if let Some(selection) = &select.selection { - validate_expr(selection)?; + validate_expr(selection, cte_names)?; + } + + // Validate GROUP BY expressions + if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &select.group_by { + for expr in exprs { + validate_expr(expr, cte_names)?; + } + } + + // Validate HAVING + if let Some(having) = &select.having { + validate_expr(having, cte_names)?; } Ok(()) } - SetExpr::Query(q) => validate_query_ast(q), + SetExpr::Query(q) => validate_query_ast(q, cte_names), SetExpr::SetOperation { left, right, .. } => { - validate_set_expr(left)?; - validate_set_expr(right) + validate_set_expr(left, cte_names)?; + validate_set_expr(right, cte_names) + } + SetExpr::Values(values) => { + // Validate all expressions in VALUES rows to prevent function call bypass + for row in &values.rows { + for expr in row { + validate_expr(expr, cte_names)?; + } + } + Ok(()) } - SetExpr::Values(_) => Ok(()), SetExpr::Insert(_) => Err(anyhow!("INSERT not allowed")), SetExpr::Update(_) => Err(anyhow!("UPDATE not allowed")), SetExpr::Delete(_) => Err(anyhow!("DELETE not allowed")), SetExpr::Merge(_) => Err(anyhow!("MERGE not allowed")), - SetExpr::Table(_) => Ok(()), + SetExpr::Table(_) => Err(anyhow!("TABLE statement is not allowed")), } } -fn validate_table_with_joins(table: &TableWithJoins) -> Result<()> { - validate_table_factor(&table.relation)?; +fn validate_table_with_joins(table: &TableWithJoins, cte_names: &HashSet<String>) -> Result<()> { + validate_table_factor(&table.relation, cte_names)?; for join in &table.joins { - validate_table_factor(&join.relation)?; + validate_table_factor(&join.relation, cte_names)?; + // Validate JOIN ON expressions + let constraint = match &join.join_operator { + sqlparser::ast::JoinOperator::Join(c) + | sqlparser::ast::JoinOperator::Inner(c) + |
sqlparser::ast::JoinOperator::Left(c) + | sqlparser::ast::JoinOperator::LeftOuter(c) + | sqlparser::ast::JoinOperator::Right(c) + | sqlparser::ast::JoinOperator::RightOuter(c) + | sqlparser::ast::JoinOperator::FullOuter(c) + | sqlparser::ast::JoinOperator::CrossJoin(c) + | sqlparser::ast::JoinOperator::Semi(c) + | sqlparser::ast::JoinOperator::LeftSemi(c) + | sqlparser::ast::JoinOperator::RightSemi(c) + | sqlparser::ast::JoinOperator::Anti(c) + | sqlparser::ast::JoinOperator::LeftAnti(c) + | sqlparser::ast::JoinOperator::RightAnti(c) => Some(c), + _ => None, + }; + if let Some(sqlparser::ast::JoinConstraint::On(expr)) = constraint { + validate_expr(expr, cte_names)?; + } } Ok(()) } -fn validate_table_factor(factor: &TableFactor) -> Result<()> { +fn validate_table_factor(factor: &TableFactor, cte_names: &HashSet<String>) -> Result<()> { match factor { TableFactor::Table { name, args, .. } => { - // Check if this is a table-valued function like read_csv(...) if args.is_some() { let func_name = name.to_string().to_lowercase(); if is_dangerous_table_function(&func_name) { return Err(anyhow!("Table function '{func_name}' is not allowed")); } } - validate_table_name(name) + validate_table_name(name, cte_names) } - TableFactor::Derived { subquery, .. } => validate_query_ast(subquery), - TableFactor::TableFunction { expr, .. } => { - // Block table functions that can read filesystem - if let Expr::Function(func) = expr { - let func_name = func.name.to_string().to_lowercase(); - if is_dangerous_table_function(&func_name) { - return Err(anyhow!("Table function '{func_name}' is not allowed")); - } - } - Ok(()) + TableFactor::Derived { subquery, .. } => validate_query_ast(subquery, cte_names), + TableFactor::TableFunction { .. } => { + Err(anyhow!("Table functions in FROM clause are not allowed")) } - TableFactor::Function { name, ..
} => { - let func_name = name.to_string().to_lowercase(); - if is_dangerous_table_function(&func_name) { - return Err(anyhow!("Table function '{func_name}' is not allowed")); - } - Ok(()) + TableFactor::Function { .. } => { + Err(anyhow!("Table functions in FROM clause are not allowed")) } TableFactor::NestedJoin { table_with_joins, .. } => { - validate_table_with_joins(table_with_joins) + validate_table_with_joins(table_with_joins, cte_names) } _ => Ok(()), } } -fn validate_table_name(name: &ObjectName) -> Result<()> { +fn validate_table_name(name: &ObjectName, cte_names: &HashSet<String>) -> Result<()> { let full_name = name.to_string().to_lowercase(); - // Block system catalogs const BLOCKED_SCHEMAS: &[&str] = &[ "pg_catalog", "information_schema", @@ -144,7 +248,6 @@ fn validate_table_name(name: &ObjectName) -> Result<()> { } } - // Block specific dangerous tables const BLOCKED_TABLES: &[&str] = &[ "pg_stat_activity", "pg_settings", @@ -160,65 +263,95 @@ fn validate_table_name(name: &ObjectName) -> Result<()> { } } - Ok(()) + let bare_name = name.0.last() + .and_then(|part| part.as_ident()) + .map(|ident| ident.value.to_lowercase()) + .unwrap_or_default(); + + if ALLOWED_TABLES.contains(&bare_name.as_str()) { + return Ok(()); + } + + if cte_names.contains(&bare_name) { + return Ok(()); + } + + Err(anyhow!("Access to table '{bare_name}' is not allowed")) } -fn validate_expr(expr: &Expr) -> Result<()> { +fn validate_expr(expr: &Expr, cte_names: &HashSet<String>) -> Result<()> { match expr { - Expr::Function(func) => validate_function(func), - Expr::Subquery(q) => validate_query_ast(q), - Expr::InSubquery { subquery, .. } => validate_query_ast(subquery), - Expr::Exists { subquery, .. } => validate_query_ast(subquery), + Expr::Function(func) => validate_function(func, cte_names), + Expr::Subquery(q) => validate_query_ast(q, cte_names), + Expr::InSubquery { subquery, .. } => validate_query_ast(subquery, cte_names), + Expr::Exists { subquery, ..
} => validate_query_ast(subquery, cte_names), Expr::BinaryOp { left, right, .. } => { - validate_expr(left)?; - validate_expr(right) + validate_expr(left, cte_names)?; + validate_expr(right, cte_names) } - Expr::UnaryOp { expr, .. } => validate_expr(expr), + Expr::UnaryOp { expr, .. } => validate_expr(expr, cte_names), Expr::Between { expr, low, high, .. } => { - validate_expr(expr)?; - validate_expr(low)?; - validate_expr(high) + validate_expr(expr, cte_names)?; + validate_expr(low, cte_names)?; + validate_expr(high, cte_names) } Expr::Case { operand, conditions, else_result, .. } => { if let Some(op) = operand { - validate_expr(op)?; + validate_expr(op, cte_names)?; } for case_when in conditions { - validate_expr(&case_when.condition)?; - validate_expr(&case_when.result)?; + validate_expr(&case_when.condition, cte_names)?; + validate_expr(&case_when.result, cte_names)?; } if let Some(else_r) = else_result { - validate_expr(else_r)?; + validate_expr(else_r, cte_names)?; } Ok(()) } - Expr::Cast { expr, .. } => validate_expr(expr), - Expr::Nested(e) => validate_expr(e), + Expr::Cast { expr, .. } => validate_expr(expr, cte_names), + Expr::Nested(e) => validate_expr(e, cte_names), Expr::InList { expr, list, .. } => { - validate_expr(expr)?; + validate_expr(expr, cte_names)?; for item in list { - validate_expr(item)?; + validate_expr(item, cte_names)?; } Ok(()) } - _ => Ok(()), + Expr::IsNull(e) + | Expr::IsNotNull(e) + | Expr::IsTrue(e) + | Expr::IsFalse(e) + | Expr::IsNotTrue(e) + | Expr::IsNotFalse(e) + | Expr::IsUnknown(e) + | Expr::IsNotUnknown(e) => validate_expr(e, cte_names), + Expr::Like { expr, pattern, .. } | Expr::ILike { expr, pattern, .. } => { + validate_expr(expr, cte_names)?; + validate_expr(pattern, cte_names) + } + Expr::AnyOp { right, .. } | Expr::AllOp { right, .. 
} => validate_expr(right, cte_names), + Expr::Value(_) + | Expr::Identifier(_) + | Expr::CompoundIdentifier(_) + | Expr::Wildcard(_) + | Expr::QualifiedWildcard(_, _) => Ok(()), + other => Err(anyhow!("Expression type not allowed: {other}")), } } -fn validate_function(func: &Function) -> Result<()> { +fn validate_function(func: &Function, cte_names: &HashSet<String>) -> Result<()> { let func_name = func.name.to_string().to_lowercase(); if is_dangerous_function(&func_name) { return Err(anyhow!("Function '{func_name}' is not allowed")); } - // Recursively validate function arguments if let FunctionArguments::List(arg_list) = &func.args { for arg in &arg_list.args { if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) | FunctionArg::Named { arg: FunctionArgExpr::Expr(expr), .. } = arg { - validate_expr(expr)?; + validate_expr(expr, cte_names)?; } } } @@ -250,6 +383,30 @@ fn is_dangerous_function(name: &str) -> bool { "lo_export", // PostgreSQL command execution "pg_execute_server_program", + // PostgreSQL dblink (remote connections) + "dblink", + "dblink_exec", + "dblink_connect", + "dblink_send_query", + "dblink_get_result", + // PostgreSQL large object access + "lo_get", + "lo_open", + "lo_close", + "loread", + "lowrite", + "lo_creat", + "lo_create", + "lo_unlink", + "lo_put", + // PostgreSQL set-returning functions (DoS via row generation) + "generate_series", + // PostgreSQL admin extension functions (file access) + "pg_file_read", + "pg_file_write", + "pg_file_rename", + "pg_file_unlink", + "pg_logdir_ls", // ClickHouse system functions "system.flush_logs", "system.reload_config", @@ -403,4 +560,147 @@ mod tests { fn test_rejects_nested_dangerous_function() { assert!(validate_query("SELECT COALESCE(pg_sleep(1), 0)").is_err()); } + + #[test] + fn test_rejects_sync_state() { + assert!(validate_query("SELECT * FROM sync_state").is_err()); + } + + #[test] + fn test_rejects_pg_tables() { + assert!(validate_query("SELECT * FROM pg_tables").is_err()); + } + + #[test] + fn
test_rejects_unknown_table() { + assert!(validate_query("SELECT * FROM some_random_table").is_err()); + } + + #[test] + fn test_allows_cte_defined_table() { + assert!(validate_query( + "WITH my_cte AS (SELECT * FROM blocks) SELECT * FROM my_cte" + ) + .is_ok()); + } + + #[test] + fn test_rejects_dblink() { + assert!(validate_query("SELECT * FROM dblink('host=evil dbname=secrets', 'SELECT * FROM passwords')").is_err()); + assert!(validate_query("SELECT dblink_connect('myconn', 'host=evil')").is_err()); + assert!(validate_query("SELECT dblink_exec('myconn', 'DROP TABLE blocks')").is_err()); + } + + #[test] + fn test_allows_schema_qualified_tables() { + assert!(validate_query("SELECT * FROM public.blocks").is_ok()); + } + + #[test] + fn test_rejects_analytics_tables_on_postgres() { + // Analytics tables (token_holders, token_balances) are ClickHouse-only + // and go through a separate code path that doesn't use this validator + assert!(validate_query("SELECT * FROM token_holders").is_err()); + assert!(validate_query("SELECT * FROM token_balances").is_err()); + } + + #[test] + fn test_rejects_recursive_cte() { + assert!(validate_query( + "WITH RECURSIVE r AS (SELECT 1 AS n UNION ALL SELECT n+1 FROM r) SELECT * FROM r" + ).is_err()); + } + + #[test] + fn test_rejects_generate_series() { + assert!(validate_query("SELECT generate_series(1, 1000000000)").is_err()); + assert!(validate_query("SELECT * FROM blocks WHERE num IN (SELECT generate_series(1, 1000000))").is_err()); + } + + #[test] + fn test_rejects_values_function_bypass() { + assert!(validate_query("VALUES (pg_sleep(10))").is_err()); + assert!(validate_query("VALUES (pg_read_file('/etc/passwd'))").is_err()); + } + + #[test] + fn test_rejects_table_statement() { + assert!(validate_query("TABLE blocks").is_err()); + assert!(validate_query("TABLE pg_shadow").is_err()); + } + + #[test] + fn test_rejects_select_into() { + assert!(validate_query("SELECT * INTO newtable FROM blocks").is_err()); + } + + #[test] + fn 
test_rejects_lo_functions() { + assert!(validate_query("SELECT lo_get(12345)").is_err()); + assert!(validate_query("SELECT lo_open(12345, 262144)").is_err()); + } + + #[test] + fn test_rejects_admin_file_functions() { + assert!(validate_query("SELECT pg_file_read('/etc/passwd', 0, 1000)").is_err()); + assert!(validate_query("SELECT pg_file_write('/tmp/evil', 'data', false)").is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_having() { + assert!(validate_query("SELECT COUNT(*) FROM blocks GROUP BY num HAVING pg_sleep(1) IS NOT NULL").is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_join_on() { + assert!(validate_query( + "SELECT * FROM blocks JOIN txs ON pg_sleep(1) IS NOT NULL" + ).is_err()); + } + + #[test] + fn test_allows_simple_values() { + assert!(validate_query("VALUES (1, 'hello'), (2, 'world')").is_ok()); + } + + #[test] + fn test_rejects_long_query() { + let long_query = format!("SELECT * FROM blocks WHERE num > {}", "1".repeat(20000)); + assert!(validate_query(&long_query).is_err()); + } + + #[test] + fn test_rejects_for_update() { + assert!(validate_query("SELECT * FROM blocks FOR UPDATE").is_err()); + assert!(validate_query("SELECT * FROM blocks FOR SHARE").is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_order_by() { + assert!(validate_query( + "SELECT * FROM blocks ORDER BY pg_sleep(1)" + ).is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_limit() { + assert!(validate_query( + "SELECT * FROM blocks LIMIT (SELECT pg_sleep(1))" + ).is_err()); + } + + #[test] + fn test_rejects_table_function_in_from() { + assert!(validate_query("SELECT * FROM generate_series(1, 100) AS t").is_err()); + } + + #[test] + fn test_allows_normal_order_by() { + assert!(validate_query("SELECT * FROM blocks ORDER BY num DESC").is_ok()); + } + + #[test] + fn test_allows_normal_limit_offset() { + assert!(validate_query("SELECT * FROM blocks LIMIT 10 OFFSET 5").is_ok()); + } } diff --git a/tests/api_live_test.rs 
b/tests/api_live_test.rs index 69e4de0..4608cd4 100644 --- a/tests/api_live_test.rs +++ b/tests/api_live_test.rs @@ -361,37 +361,57 @@ async fn test_query_live_returns_sse() { #[test] fn test_inject_block_filter_blocks_table() { let sql = "SELECT num, hash FROM blocks ORDER BY num DESC LIMIT 1"; - let filtered = inject_block_filter(sql, 100); - assert!(filtered.contains("num = 100"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 100).unwrap(); + assert!(filtered.contains("blocks.num = 100"), "got: {filtered}"); assert!(filtered.contains("ORDER BY"), "should preserve ORDER BY"); } #[test] fn test_inject_block_filter_txs_table() { let sql = "SELECT * FROM txs ORDER BY block_num DESC LIMIT 10"; - let filtered = inject_block_filter(sql, 200); - assert!(filtered.contains("block_num = 200"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 200).unwrap(); + assert!(filtered.contains("txs.block_num = 200"), "got: {filtered}"); } #[test] fn test_inject_block_filter_logs_table() { let sql = "SELECT * FROM logs WHERE address = '0x123' ORDER BY block_num DESC"; - let filtered = inject_block_filter(sql, 300); - assert!(filtered.contains("block_num = 300"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 300).unwrap(); + assert!(filtered.contains("logs.block_num = 300"), "got: {filtered}"); assert!(filtered.contains("address = '0x123'"), "should preserve existing WHERE"); } #[test] fn test_inject_block_filter_with_existing_where() { let sql = "SELECT * FROM txs WHERE gas_used > 21000 ORDER BY block_num DESC"; - let filtered = inject_block_filter(sql, 400); - assert!(filtered.contains("block_num = 400"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 400).unwrap(); + assert!(filtered.contains("txs.block_num = 400"), "got: {filtered}"); assert!(filtered.contains("gas_used > 21000"), "should preserve existing condition"); } #[test] fn test_inject_block_filter_no_order_by() { let sql = "SELECT COUNT(*) FROM blocks LIMIT 1"; - 
let filtered = inject_block_filter(sql, 500); - assert!(filtered.contains("num = 500"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 500).unwrap(); + assert!(filtered.contains("blocks.num = 500"), "got: {filtered}"); +} + +#[test] +fn test_inject_block_filter_rejects_union() { + let sql = "SELECT * FROM txs UNION SELECT * FROM logs"; + assert!(inject_block_filter(sql, 100).is_err()); +} + +#[test] +fn test_inject_block_filter_rejects_non_select() { + let sql = "INSERT INTO txs VALUES (1)"; + assert!(inject_block_filter(sql, 100).is_err()); +} + +#[test] +fn test_inject_block_filter_where_keyword_in_string_literal() { + let sql = "SELECT * FROM txs WHERE input = 'WHERE clause test'"; + let filtered = inject_block_filter(sql, 100).unwrap(); + assert!(filtered.contains("txs.block_num = 100"), "got: {filtered}"); + assert!(filtered.contains("'WHERE clause test'"), "should preserve string literal"); }