Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ version = "0.22"
features = ["extension-module"]

[dependencies.sqlparser]
version = "0.56.0"
version = "0.60.0"
features = ["serde", "visitor"]
10 changes: 3 additions & 7 deletions examples/depgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
"""

import argparse
import json
import os
from glob import glob
from typing import List

Expand All @@ -16,6 +14,7 @@
parser.add_argument("--path", "-p", type=str, help="The path to process queries for.")
parser.add_argument("--dialect", "-d", type=str, help="The dialect to use.")


def get_sql_files(path: str) -> List[str]:
return glob(path + "/**/*.sql")

Expand All @@ -31,7 +30,6 @@ def get_key_recursive(search_dict, field):
fields_found = []

for key, value in search_dict.items():

if key == field:
fields_found.append(value)

Expand All @@ -51,7 +49,6 @@ def get_key_recursive(search_dict, field):


def get_tables_in_query(SQL: str, dialect: str) -> List[str]:

res = sqloxide.parse_sql(sql=SQL, dialect=dialect)
tables = get_key_recursive(res[0]["Query"], "Table")

Expand All @@ -64,11 +61,10 @@ def get_tables_in_query(SQL: str, dialect: str) -> List[str]:


if __name__ == "__main__":

args = parser.parse_args()

files = get_sql_files(args.path)
print(f'Parsing using dialect: {args.dialect}')
print(f"Parsing using dialect: {args.dialect}")

result_dict = dict()

Expand All @@ -87,7 +83,7 @@ def get_tables_in_query(SQL: str, dialect: str) -> List[str]:
dot = Digraph(engine="dot")
dot.attr(rankdir="LR")
dot.attr(splines="ortho")
dot.node_attr['shape'] = 'box'
dot.node_attr["shape"] = "box"

for view, tables in result_dict.items():
view = view[:-4]
Expand Down
14 changes: 9 additions & 5 deletions justfile
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
benchmark: build
uvx poetry run pytest tests/benchmark.py
benchmark:
uv sync
uv run maturin develop --release
uv run pytest tests/benchmark.py

test:
uvx poetry run pytest tests/
test:
uv sync
uv run maturin develop
uv run pytest tests/

build:
uvx poetry build
uv run maturin build --release
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@ classifiers = [
"License :: OSI Approved :: MIT License",
]

[dependency-groups]
dev = [
# build
"maturin",
# test
"pytest",
"pytest-benchmark",
"pytest-subtests",
# benchmark
"sqlglot",
"sqlparse",
]

[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
6 changes: 4 additions & 2 deletions sqloxide.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class Select(TypedDict("Select", {"from": list[TableWithJoins]})):
class Insert(TypedDict("Insert", {"or": Any | None})):
"""
An INSERT statement.

See https://docs.rs/sqlparser/0.51.0/sqlparser/ast/struct.Insert.html
"""

Expand All @@ -163,7 +163,9 @@ class Insert(TypedDict("Insert", {"or": Any | None})):
partitioned: Any | None
after_columns: list[Any]
table: bool
on: dict[str, Any] | None # e.g. {"OnConflict": {"conflict_target": None, "action": "DoNothing"}},
on: (
dict[str, Any] | None
) # e.g. {"OnConflict": {"conflict_target": None, "action": "DoNothing"}},
returning: Any | None
replace_into: bool
priority: Any | None
Expand Down
2 changes: 1 addition & 1 deletion sqloxide/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .sqloxide import *
from .sqloxide import * # noqa: F403
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use visitor::{extract_expressions, extract_relations, mutate_expressions, mutate
/// Available `dialects`: https://github.com/sqlparser-rs/sqlparser-rs/blob/main/src/dialect/mod.rs#L189-L206
#[pyfunction]
#[pyo3(text_signature = "(sql, dialect)")]
fn parse_sql(py: Python, sql: String, dialect: String) -> PyResult<PyObject> {
fn parse_sql(py: Python, sql: String, dialect: String) -> PyResult<Py<PyAny>> {
let chosen_dialect = dialect_from_str(dialect).unwrap_or_else(|| {
println!("The dialect you chose was not recognized, falling back to 'generic'");
Box::new(GenericDialect {})
Expand Down
37 changes: 19 additions & 18 deletions src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ where

#[pyfunction]
#[pyo3(text_signature = "(parsed_query)")]
pub fn extract_relations(py: Python, parsed_query: &Bound<'_, PyAny>) -> PyResult<PyObject> {
pub fn extract_relations(py: Python, parsed_query: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
let statements = depythonize_query(parsed_query)?;

let mut relations = Vec::new();
for statement in statements {
visit_relations(&statement, |relation| {
let _ = visit_relations(&statement, |relation| {
relations.push(relation.clone());
ControlFlow::<()>::Continue(())
});
Expand All @@ -59,20 +59,21 @@ pub fn mutate_relations(_py: Python, parsed_query: &Bound<'_, PyAny>, func: &Bou
let mut statements = depythonize_query(parsed_query)?;

for statement in &mut statements {
visit_relations_mut(statement, |table| {
let _ = visit_relations_mut(statement, |table| {
for section in &mut table.0 {
let ObjectNamePart::Identifier(ident) = section;
let val = match func.call1((ident.value.clone(),)) {
Ok(val) => val,
Err(e) => {
let msg = e.to_string();
return ControlFlow::Break(PyValueError::new_err(format!(
"Python object serialization failed.\n\t{msg}"
)));
}
};

ident.value = val.to_string();
if let ObjectNamePart::Identifier(ident) = section {
let val = match func.call1((ident.value.clone(),)) {
Ok(val) => val,
Err(e) => {
let msg = e.to_string();
return ControlFlow::Break(PyValueError::new_err(format!(
"Python object serialization failed.\n\t{msg}"
)));
}
};

ident.value = val.to_string();
}
}
ControlFlow::Continue(())
});
Expand All @@ -90,7 +91,7 @@ pub fn mutate_expressions(py: Python, parsed_query: &Bound<'_, PyAny>, func: &Bo
let mut statements: Vec<Statement> = depythonize_query(parsed_query)?;

for statement in &mut statements {
visit_expressions_mut(statement, |expr| {
let _ = visit_expressions_mut(statement, |expr| {
let converted_expr = match pythonize::pythonize(py, expr) {
Ok(val) => val,
Err(e) => {
Expand Down Expand Up @@ -133,12 +134,12 @@ pub fn mutate_expressions(py: Python, parsed_query: &Bound<'_, PyAny>, func: &Bo

#[pyfunction]
#[pyo3(text_signature = "(parsed_query)")]
pub fn extract_expressions(py: Python, parsed_query: &Bound<'_, PyAny>) -> PyResult<PyObject> {
pub fn extract_expressions(py: Python, parsed_query: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
let statements: Vec<Statement> = depythonize_query(parsed_query)?;

let mut expressions = Vec::new();
for statement in statements {
visit_expressions(&statement, |expr| {
let _ = visit_expressions(&statement, |expr| {
expressions.push(expr.clone());
ControlFlow::<()>::Continue(())
});
Expand Down
Empty file added tests/__init__.py
Empty file.
12 changes: 0 additions & 12 deletions tests/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import pytest

from sqloxide import parse_sql
import sqlparse
import sqlglot
import json
import moz_sql_parser

TEST_SQL = """
SELECT employee.first_name, employee.last_name,
Expand All @@ -24,10 +20,6 @@ def bench_sqlparser():
return sqlparse.parse(TEST_SQL)[0]


def bench_mozsqlparser():
return json.dumps(moz_sql_parser.parse(TEST_SQL))


def bench_sqlglot():
return sqlglot.parse(TEST_SQL, error_level=sqlglot.ErrorLevel.IGNORE)

Expand All @@ -40,9 +32,5 @@ def test_sqlparser(benchmark):
benchmark(bench_sqlparser)


def test_mozsqlparser(benchmark):
benchmark(bench_mozsqlparser)


def test_sqlglot(benchmark):
benchmark(bench_sqlglot)
4 changes: 2 additions & 2 deletions tests/test_sqloxide.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def func(x):

ast = parse_sql(sql=SQL, dialect="ansi")
assert mutate_relations(parsed_query=ast, func=func) == [
'SELECT employee.first_name, employee.last_name, c.start_time, c.end_time, call_outcome.outcome_text FROM employee INNER JOIN "call2"."call2"."call2" AS c ON c.employee_id = employee.id INNER JOIN call2_outcome ON c.call_outcome_id = call_outcome.id ORDER BY c.start_time ASC'
'SELECT employee.first_name, employee.last_name, c.start_time, c.end_time, call_outcome.outcome_text FROM employee INNER JOIN "call2"."call2"."call2" c ON c.employee_id = employee.id INNER JOIN call2_outcome ON c.call_outcome_id = call_outcome.id ORDER BY c.start_time ASC'
]


Expand All @@ -87,7 +87,7 @@ def func(x):
ast = parse_sql(sql=SQL, dialect="ansi")
result = mutate_expressions(parsed_query=ast, func=func)
assert result == [
'SELECT EMPLOYEE.FIRST_NAME, EMPLOYEE.LAST_NAME, C.START_TIME, C.END_TIME, CALL_OUTCOME.OUTCOME_TEXT FROM employee INNER JOIN "call"."call"."call" AS c ON C.EMPLOYEE_ID = EMPLOYEE.ID INNER JOIN call_outcome ON C.CALL_OUTCOME_ID = CALL_OUTCOME.ID ORDER BY C.START_TIME ASC'
'SELECT EMPLOYEE.FIRST_NAME, EMPLOYEE.LAST_NAME, C.START_TIME, C.END_TIME, CALL_OUTCOME.OUTCOME_TEXT FROM employee INNER JOIN "call"."call"."call" c ON C.EMPLOYEE_ID = EMPLOYEE.ID INNER JOIN call_outcome ON C.CALL_OUTCOME_ID = CALL_OUTCOME.ID ORDER BY C.START_TIME ASC'
]


Expand Down