Skip to content

Commit 8536ccd

Browse files
committed
implement parameter re-duplication for mysql
1 parent 91b096f commit 8536ccd

File tree

2 files changed

+171
-28
lines changed

2 files changed

+171
-28
lines changed

src/webserver/database/mod.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ mod syntax_tree;
99
mod error_highlighting;
1010
mod sql_to_json;
1111

12-
pub use sql::{make_placeholder, ParsedSqlFile};
12+
pub use sql::ParsedSqlFile;
13+
use sql::{DbPlaceHolder, DB_PLACEHOLDERS};
14+
use sqlx::any::AnyKind;
1315

1416
pub struct Database {
1517
pub connection: sqlx::AnyPool,
@@ -34,3 +36,18 @@ impl std::fmt::Display for Database {
3436
write!(f, "{:?}", self.connection.any_kind())
3537
}
3638
}
39+
40+
#[inline]
41+
#[must_use]
42+
pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String {
43+
if let Some((_, placeholder)) =
44+
DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == db_kind)
45+
{
46+
match placeholder {
47+
DbPlaceHolder::PrefixedNumber { prefix } => format!("{prefix}{arg_number}"),
48+
DbPlaceHolder::Positional { placeholder } => placeholder.to_string(),
49+
}
50+
} else {
51+
unreachable!("missing db_kind: {db_kind:?} in DB_PLACEHOLDERS ({DB_PLACEHOLDERS:?})")
52+
}
53+
}

src/webserver/database/sql.rs

Lines changed: 153 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,29 @@ fn parse_sql<'a>(
143143
}))
144144
}
145145

146+
fn transform_to_positional_placeholders(stmt: &mut StmtWithParams, db_kind: AnyKind) {
147+
if let Some((_, DbPlaceHolder::Positional { placeholder })) =
148+
DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == db_kind)
149+
{
150+
let mut new_params = Vec::new();
151+
let mut query = stmt.query.clone();
152+
while let Some(pos) = query.find(TEMP_PLACEHOLDER_PREFIX) {
153+
let start_of_number = pos + TEMP_PLACEHOLDER_PREFIX.len();
154+
let end = query[start_of_number..]
155+
.find(|c: char| !c.is_ascii_digit())
156+
.map_or(query.len(), |i| start_of_number + i);
157+
let param_idx = query[start_of_number..end]
158+
.parse::<usize>()
159+
.unwrap_or(1)
160+
- 1;
161+
query.replace_range(pos..end, placeholder);
162+
new_params.push(stmt.params[param_idx].clone());
163+
}
164+
stmt.query = query;
165+
stmt.params = new_params;
166+
}
167+
}
168+
146169
fn parse_single_statement(
147170
parser: &mut Parser<'_>,
148171
db_kind: AnyKind,
@@ -161,7 +184,8 @@ fn parse_single_statement(
161184
semicolon = true;
162185
}
163186
let mut params = ParameterExtractor::extract_parameters(&mut stmt, db_kind);
164-
if let Some((variable, value)) = extract_set_variable(&mut stmt, &mut params, db_kind) {
187+
if let Some((variable, mut value)) = extract_set_variable(&mut stmt, &mut params, db_kind) {
188+
transform_to_positional_placeholders(&mut value, db_kind);
165189
return Some(ParsedStatement::SetVariable { variable, value });
166190
}
167191
if let Some(csv_import) = extract_csv_copy_statement(&mut stmt) {
@@ -178,14 +202,16 @@ fn parse_single_statement(
178202
"{stmt}{semicolon}",
179203
semicolon = if semicolon { ";" } else { "" }
180204
);
181-
log::debug!("Final transformed statement: {stmt}");
182-
Some(ParsedStatement::StmtWithParams(StmtWithParams {
205+
let mut stmt_with_params = StmtWithParams {
183206
query,
184207
query_position: extract_query_start(&stmt),
185208
params,
186209
delayed_functions,
187210
json_columns,
188-
}))
211+
};
212+
transform_to_positional_placeholders(&mut stmt_with_params, db_kind);
213+
log::debug!("Final transformed statement: {}", stmt_with_params.query);
214+
Some(ParsedStatement::StmtWithParams(stmt_with_params))
189215
}
190216

191217
fn extract_query_start(stmt: &impl Spanned) -> SourceSpan {
@@ -473,12 +499,45 @@ struct ParameterExtractor {
473499
parameters: Vec<StmtParam>,
474500
}
475501

476-
const PLACEHOLDER_PREFIXES: [(AnyKind, &str); 3] = [
477-
(AnyKind::Sqlite, "?"),
478-
(AnyKind::Postgres, "$"),
479-
(AnyKind::Mssql, "@p"),
502+
#[derive(Debug)]
503+
pub enum DbPlaceHolder {
504+
PrefixedNumber { prefix: &'static str },
505+
Positional { placeholder: &'static str },
506+
}
507+
508+
pub const DB_PLACEHOLDERS: [(AnyKind, DbPlaceHolder); 4] = [
509+
(
510+
AnyKind::Sqlite,
511+
DbPlaceHolder::PrefixedNumber { prefix: "?" },
512+
),
513+
(
514+
AnyKind::Postgres,
515+
DbPlaceHolder::PrefixedNumber { prefix: "$" },
516+
),
517+
(
518+
AnyKind::MySql,
519+
DbPlaceHolder::Positional { placeholder: "?" },
520+
),
521+
(
522+
AnyKind::Mssql,
523+
DbPlaceHolder::PrefixedNumber { prefix: "@p" },
524+
),
480525
];
481-
const DEFAULT_PLACEHOLDER: &str = "?";
526+
527+
/// For positional parameters, we use a temporary placeholder during parameter extraction,
528+
/// And then replace it with the actual placeholder during statement rewriting.
529+
const TEMP_PLACEHOLDER_PREFIX: &str = "@SQLPAGE_TEMP";
530+
531+
fn get_placeholder_prefix(db_kind: AnyKind) -> &'static str {
532+
if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = DB_PLACEHOLDERS
533+
.iter()
534+
.find(|(kind, _prefix)| *kind == db_kind)
535+
{
536+
prefix
537+
} else {
538+
TEMP_PLACEHOLDER_PREFIX
539+
}
540+
}
482541

483542
impl ParameterExtractor {
484543
fn extract_parameters(
@@ -509,7 +568,7 @@ impl ParameterExtractor {
509568
}
510569

511570
fn make_placeholder_for_index(&self, index: usize) -> Expr {
512-
let name = make_placeholder(self.db_kind, index);
571+
let name = make_tmp_placeholder(self.db_kind, index);
513572
let data_type = match self.db_kind {
514573
AnyKind::MySql => DataType::Char(None),
515574
AnyKind::Mssql => DataType::Varchar(Some(CharacterLength::Max)),
@@ -529,18 +588,13 @@ impl ParameterExtractor {
529588
}
530589

531590
fn is_own_placeholder(&self, param: &str) -> bool {
532-
if let Some((_, prefix)) = PLACEHOLDER_PREFIXES
533-
.iter()
534-
.find(|(kind, _prefix)| *kind == self.db_kind)
535-
{
536-
if let Some(param) = param.strip_prefix(prefix) {
537-
if let Ok(index) = param.parse::<usize>() {
538-
return index <= self.parameters.len() + 1;
539-
}
591+
let prefix = get_placeholder_prefix(self.db_kind);
592+
if let Some(param) = param.strip_prefix(prefix) {
593+
if let Ok(index) = param.parse::<usize>() {
594+
return index <= self.parameters.len() + 1;
540595
}
541-
return false;
542596
}
543-
param == DEFAULT_PLACEHOLDER
597+
return false;
544598
}
545599
}
546600

@@ -728,14 +782,15 @@ fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> {
728782

729783
#[inline]
730784
#[must_use]
731-
pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String {
732-
if let Some((_, prefix)) = PLACEHOLDER_PREFIXES
733-
.iter()
734-
.find(|(kind, _)| *kind == db_kind)
785+
pub fn make_tmp_placeholder(db_kind: AnyKind, arg_number: usize) -> String {
786+
let prefix = if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) =
787+
DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == db_kind)
735788
{
736-
return format!("{prefix}{arg_number}");
737-
}
738-
DEFAULT_PLACEHOLDER.to_string()
789+
prefix
790+
} else {
791+
TEMP_PLACEHOLDER_PREFIX
792+
};
793+
format!("{prefix}{arg_number}")
739794
}
740795

741796
fn extract_ident_param(Ident { value, .. }: &mut Ident) -> Option<StmtParam> {
@@ -1415,4 +1470,75 @@ mod test {
14151470
assert!(json_columns.contains(&"item".to_string()));
14161471
assert!(!json_columns.contains(&"title".to_string()));
14171472
}
1473+
1474+
#[test]
1475+
fn test_positional_placeholders() {
1476+
let sql = "select \
1477+
@SQLPAGE_TEMP10 as a1, \
1478+
@SQLPAGE_TEMP9 as a2, \
1479+
@SQLPAGE_TEMP8 as a3, \
1480+
@SQLPAGE_TEMP7 as a4, \
1481+
@SQLPAGE_TEMP6 as a5, \
1482+
@SQLPAGE_TEMP5 as a6, \
1483+
@SQLPAGE_TEMP4 as a7, \
1484+
@SQLPAGE_TEMP3 as a8, \
1485+
@SQLPAGE_TEMP2 as a9, \
1486+
@SQLPAGE_TEMP1 as a10 \
1487+
@SQLPAGE_TEMP10 as a1bis \
1488+
from t";
1489+
let mut stmt = StmtWithParams {
1490+
query: sql.to_string(),
1491+
query_position: SourceSpan {
1492+
start: SourceLocation { line: 1, column: 1 },
1493+
end: SourceLocation { line: 1, column: 1 },
1494+
},
1495+
params: vec![
1496+
StmtParam::PostOrGet("x1".to_string()),
1497+
StmtParam::PostOrGet("x2".to_string()),
1498+
StmtParam::PostOrGet("x3".to_string()),
1499+
StmtParam::PostOrGet("x4".to_string()),
1500+
StmtParam::PostOrGet("x5".to_string()),
1501+
StmtParam::PostOrGet("x6".to_string()),
1502+
StmtParam::PostOrGet("x7".to_string()),
1503+
StmtParam::PostOrGet("x8".to_string()),
1504+
StmtParam::PostOrGet("x9".to_string()),
1505+
StmtParam::PostOrGet("x10".to_string()),
1506+
],
1507+
delayed_functions: vec![],
1508+
json_columns: vec![],
1509+
};
1510+
transform_to_positional_placeholders(&mut stmt, AnyKind::MySql);
1511+
assert_eq!(
1512+
stmt.query,
1513+
"select \
1514+
? as a1, \
1515+
? as a2, \
1516+
? as a3, \
1517+
? as a4, \
1518+
? as a5, \
1519+
? as a6, \
1520+
? as a7, \
1521+
? as a8, \
1522+
? as a9, \
1523+
? as a10 \
1524+
? as a1bis \
1525+
from t"
1526+
);
1527+
assert_eq!(
1528+
stmt.params,
1529+
vec![
1530+
StmtParam::PostOrGet("x10".to_string()),
1531+
StmtParam::PostOrGet("x9".to_string()),
1532+
StmtParam::PostOrGet("x8".to_string()),
1533+
StmtParam::PostOrGet("x7".to_string()),
1534+
StmtParam::PostOrGet("x6".to_string()),
1535+
StmtParam::PostOrGet("x5".to_string()),
1536+
StmtParam::PostOrGet("x4".to_string()),
1537+
StmtParam::PostOrGet("x3".to_string()),
1538+
StmtParam::PostOrGet("x2".to_string()),
1539+
StmtParam::PostOrGet("x1".to_string()),
1540+
StmtParam::PostOrGet("x10".to_string()),
1541+
]
1542+
);
1543+
}
14181544
}

0 commit comments

Comments
 (0)