@@ -182,9 +182,8 @@ fn parse_single_statement(
182182 semicolon = true ;
183183 }
184184 let mut params = ParameterExtractor :: extract_parameters ( & mut stmt, db_kind) ;
185- if let Some ( ( variable, mut value) ) = extract_set_variable ( & mut stmt, & mut params, db_kind) {
186- transform_to_positional_placeholders ( & mut value, db_kind) ;
187- return Some ( ParsedStatement :: SetVariable { variable, value } ) ;
185+ if let Some ( parsed) = extract_set_variable ( & mut stmt, & mut params, db_kind) {
186+ return Some ( parsed) ;
188187 }
189188 if let Some ( csv_import) = extract_csv_copy_statement ( & mut stmt) {
190189 return Some ( ParsedStatement :: CsvImport ( csv_import) ) ;
@@ -462,7 +461,7 @@ fn extract_set_variable(
462461 stmt : & mut Statement ,
463462 params : & mut Vec < StmtParam > ,
464463 db_kind : AnyKind ,
465- ) -> Option < ( StmtParam , StmtWithParams ) > {
464+ ) -> Option < ParsedStatement > {
466465 if let Statement :: SetVariable {
467466 variables : OneOrManyWithParens :: One ( ObjectName ( name) ) ,
468467 value,
@@ -479,29 +478,19 @@ fn extract_set_variable(
479478 let owned_expr = std:: mem:: replace ( value, Expr :: Value ( Value :: Null ) ) ;
480479 let mut select_stmt: Statement = expr_to_statement ( owned_expr) ;
481480 let delayed_functions = extract_toplevel_functions ( & mut select_stmt) ;
482- if let Err ( err) = validate_function_calls ( & mut select_stmt) {
483- return Some ( (
484- variable,
485- StmtWithParams {
486- query : format ! ( "SELECT '' WHERE false -- {}" , err) ,
487- query_position : extract_query_start ( & select_stmt) ,
488- params : std:: mem:: take ( params) ,
489- delayed_functions : vec ! [ ] ,
490- json_columns : vec ! [ ] ,
491- } ,
492- ) ) ;
481+ if let Err ( err) = validate_function_calls ( & select_stmt) {
482+ return Some ( ParsedStatement :: Error ( err) ) ;
493483 }
494484 let json_columns = extract_json_columns ( & select_stmt, db_kind) ;
495- return Some ( (
496- variable,
497- StmtWithParams {
498- query : select_stmt. to_string ( ) ,
499- query_position : extract_query_start ( & select_stmt) ,
500- params : std:: mem:: take ( params) ,
501- delayed_functions,
502- json_columns,
503- } ,
504- ) ) ;
485+ let mut value = StmtWithParams {
486+ query : select_stmt. to_string ( ) ,
487+ query_position : extract_query_start ( & select_stmt) ,
488+ params : std:: mem:: take ( params) ,
489+ delayed_functions,
490+ json_columns,
491+ } ;
492+ transform_to_positional_placeholders ( & mut value, db_kind) ;
493+ return Some ( ParsedStatement :: SetVariable { variable, value } ) ;
505494 }
506495 }
507496 None
@@ -607,7 +596,7 @@ impl ParameterExtractor {
607596 return index <= self . parameters . len ( ) + 1 ;
608597 }
609598 }
610- return false ;
599+ false
611600 }
612601}
613602
@@ -1419,18 +1408,18 @@ mod test {
14191408 json_array(1, 2, 3) AS json_col2,
14201409 (SELECT json_build_object('nested', subq.val)
14211410 FROM (SELECT AVG(x) AS val FROM generate_series(1, 5) x) subq
1422- ) AS json_col3, -- not supported because of the subquery
1423- CASE
1424- WHEN EXISTS (SELECT 1 FROM json_cte WHERE cte_json->>'a' = '2')
1425- THEN to_json(ARRAY(SELECT cte_json FROM json_cte))
1426- ELSE json_build_array()
1427- END AS json_col4, -- not supported because of the CASE
1428- json_unknown_fn(regular_column) AS non_json_col,
1429- CAST(json_col1 AS json) AS json_col6
1430- FROM some_table
1431- CROSS JOIN json_cte
1432- WHERE json_typeof(json_col1) = 'object'
1433- " ;
1411+ ) AS json_col3, -- not supported because of the subquery
1412+ CASE
1413+ WHEN EXISTS (SELECT 1 FROM json_cte WHERE cte_json->>'a' = '2')
1414+ THEN to_json(ARRAY(SELECT cte_json FROM json_cte))
1415+ ELSE json_build_array()
1416+ END AS json_col4, -- not supported because of the CASE
1417+ json_unknown_fn(regular_column) AS non_json_col,
1418+ CAST(json_col1 AS json) AS json_col6
1419+ FROM some_table
1420+ CROSS JOIN json_cte
1421+ WHERE json_typeof(json_col1) = 'object'
1422+ " ;
14341423
14351424 let stmt = parse_postgres_stmt ( sql) ;
14361425 let json_columns = extract_json_columns ( & stmt, AnyKind :: Sqlite ) ;
@@ -1570,4 +1559,21 @@ mod test {
15701559 ]
15711560 ) ;
15721561 }
1562+
1563+ #[ test]
1564+ fn test_set_variable_error_handling ( ) {
1565+ let sql = "set x = db_function(sqlpage.fetch(other_db_function()))" ;
1566+ for & ( dialect, db_kind) in ALL_DIALECTS {
1567+ let mut parser = Parser :: new ( dialect) . try_with_sql ( sql) . unwrap ( ) ;
1568+ let stmt = parse_single_statement ( & mut parser, db_kind, sql) ;
1569+ if let Some ( ParsedStatement :: Error ( err) ) = stmt {
1570+ assert ! (
1571+ err. to_string( ) . contains( "Invalid SQLPage function call" ) ,
1572+ "Expected error for invalid function, got: {err}"
1573+ ) ;
1574+ } else {
1575+ panic ! ( "Expected error for invalid function, got: {stmt:#?}" ) ;
1576+ }
1577+ }
1578+ }
15731579}
0 commit comments