Skip to content

Commit ee5d57c

Browse files
committed
simplify extract_set_variable
1 parent 9035580 commit ee5d57c

File tree

2 files changed

+45
-39
lines changed

2 files changed

+45
-39
lines changed

src/webserver/database/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl std::fmt::Display for Database {
4141
#[must_use]
4242
pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String {
4343
if let Some((_, placeholder)) = DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == db_kind) {
44-
match placeholder {
44+
match *placeholder {
4545
DbPlaceHolder::PrefixedNumber { prefix } => format!("{prefix}{arg_number}"),
4646
DbPlaceHolder::Positional { placeholder } => placeholder.to_string(),
4747
}

src/webserver/database/sql.rs

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,8 @@ fn parse_single_statement(
182182
semicolon = true;
183183
}
184184
let mut params = ParameterExtractor::extract_parameters(&mut stmt, db_kind);
185-
if let Some((variable, mut value)) = extract_set_variable(&mut stmt, &mut params, db_kind) {
186-
transform_to_positional_placeholders(&mut value, db_kind);
187-
return Some(ParsedStatement::SetVariable { variable, value });
185+
if let Some(parsed) = extract_set_variable(&mut stmt, &mut params, db_kind) {
186+
return Some(parsed);
188187
}
189188
if let Some(csv_import) = extract_csv_copy_statement(&mut stmt) {
190189
return Some(ParsedStatement::CsvImport(csv_import));
@@ -462,7 +461,7 @@ fn extract_set_variable(
462461
stmt: &mut Statement,
463462
params: &mut Vec<StmtParam>,
464463
db_kind: AnyKind,
465-
) -> Option<(StmtParam, StmtWithParams)> {
464+
) -> Option<ParsedStatement> {
466465
if let Statement::SetVariable {
467466
variables: OneOrManyWithParens::One(ObjectName(name)),
468467
value,
@@ -479,29 +478,19 @@ fn extract_set_variable(
479478
let owned_expr = std::mem::replace(value, Expr::Value(Value::Null));
480479
let mut select_stmt: Statement = expr_to_statement(owned_expr);
481480
let delayed_functions = extract_toplevel_functions(&mut select_stmt);
482-
if let Err(err) = validate_function_calls(&mut select_stmt) {
483-
return Some((
484-
variable,
485-
StmtWithParams {
486-
query: format!("SELECT '' WHERE false -- {}", err),
487-
query_position: extract_query_start(&select_stmt),
488-
params: std::mem::take(params),
489-
delayed_functions: vec![],
490-
json_columns: vec![],
491-
},
492-
));
481+
if let Err(err) = validate_function_calls(&select_stmt) {
482+
return Some(ParsedStatement::Error(err));
493483
}
494484
let json_columns = extract_json_columns(&select_stmt, db_kind);
495-
return Some((
496-
variable,
497-
StmtWithParams {
498-
query: select_stmt.to_string(),
499-
query_position: extract_query_start(&select_stmt),
500-
params: std::mem::take(params),
501-
delayed_functions,
502-
json_columns,
503-
},
504-
));
485+
let mut value = StmtWithParams {
486+
query: select_stmt.to_string(),
487+
query_position: extract_query_start(&select_stmt),
488+
params: std::mem::take(params),
489+
delayed_functions,
490+
json_columns,
491+
};
492+
transform_to_positional_placeholders(&mut value, db_kind);
493+
return Some(ParsedStatement::SetVariable { variable, value });
505494
}
506495
}
507496
None
@@ -607,7 +596,7 @@ impl ParameterExtractor {
607596
return index <= self.parameters.len() + 1;
608597
}
609598
}
610-
return false;
599+
false
611600
}
612601
}
613602

@@ -1419,18 +1408,18 @@ mod test {
14191408
json_array(1, 2, 3) AS json_col2,
14201409
(SELECT json_build_object('nested', subq.val)
14211410
FROM (SELECT AVG(x) AS val FROM generate_series(1, 5) x) subq
1422-
) AS json_col3, -- not supported because of the subquery
1423-
CASE
1424-
WHEN EXISTS (SELECT 1 FROM json_cte WHERE cte_json->>'a' = '2')
1425-
THEN to_json(ARRAY(SELECT cte_json FROM json_cte))
1426-
ELSE json_build_array()
1427-
END AS json_col4, -- not supported because of the CASE
1428-
json_unknown_fn(regular_column) AS non_json_col,
1429-
CAST(json_col1 AS json) AS json_col6
1430-
FROM some_table
1431-
CROSS JOIN json_cte
1432-
WHERE json_typeof(json_col1) = 'object'
1433-
";
1411+
) AS json_col3, -- not supported because of the subquery
1412+
CASE
1413+
WHEN EXISTS (SELECT 1 FROM json_cte WHERE cte_json->>'a' = '2')
1414+
THEN to_json(ARRAY(SELECT cte_json FROM json_cte))
1415+
ELSE json_build_array()
1416+
END AS json_col4, -- not supported because of the CASE
1417+
json_unknown_fn(regular_column) AS non_json_col,
1418+
CAST(json_col1 AS json) AS json_col6
1419+
FROM some_table
1420+
CROSS JOIN json_cte
1421+
WHERE json_typeof(json_col1) = 'object'
1422+
";
14341423

14351424
let stmt = parse_postgres_stmt(sql);
14361425
let json_columns = extract_json_columns(&stmt, AnyKind::Sqlite);
@@ -1570,4 +1559,21 @@ mod test {
15701559
]
15711560
);
15721561
}
1562+
1563+
#[test]
1564+
fn test_set_variable_error_handling() {
1565+
let sql = "set x = db_function(sqlpage.fetch(other_db_function()))";
1566+
for &(dialect, db_kind) in ALL_DIALECTS {
1567+
let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap();
1568+
let stmt = parse_single_statement(&mut parser, db_kind, sql);
1569+
if let Some(ParsedStatement::Error(err)) = stmt {
1570+
assert!(
1571+
err.to_string().contains("Invalid SQLPage function call"),
1572+
"Expected error for invalid function, got: {err}"
1573+
);
1574+
} else {
1575+
panic!("Expected error for invalid function, got: {stmt:#?}");
1576+
}
1577+
}
1578+
}
15731579
}

0 commit comments

Comments
 (0)