Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 1 addition & 17 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
import contextvars
import decimal
import functools
import logging
Expand Down Expand Up @@ -65,7 +64,6 @@
InsertToTable,
IsDistinctFrom,
Join,
Param,
Random,
Root,
TableAlias,
Expand All @@ -81,7 +79,6 @@
from data_diff.utils import ArithString, ArithUUID, is_uuid, join_iter, safezip

logger = logging.getLogger("database")
cv_params = contextvars.ContextVar("params")


class CompileError(Exception):
Expand Down Expand Up @@ -115,10 +112,6 @@ class Compiler(AbstractCompiler):
def dialect(self) -> "BaseDialect":
return self.database.dialect

# TODO: DEPRECATED: Remove once the dialect is used directly in all places.
def compile(self, elem, params=None) -> str:
return self.dialect.compile(self, elem, params)

# Return a fresh name built from `prefix` followed by the next counter value.
# The counter lives in `self._counter[0]` and is advanced by one on every call.
def new_unique_name(self, prefix="tmp") -> str:
# Increment first, so the number embedded in the name is the post-increment value.
self._counter[0] += 1
return f"{prefix}{self._counter[0]}"
Expand Down Expand Up @@ -221,10 +214,7 @@ def parse_table_name(self, name: str) -> DbPath:
"Parse the given table name into a DbPath"
return parse_table_name(name)

def compile(self, compiler: Compiler, elem, params=None) -> str:
if params:
cv_params.set(params)

def compile(self, compiler: Compiler, elem) -> str:
if compiler.root and isinstance(elem, Compilable) and not isinstance(elem, Root):
from data_diff.queries.ast_classes import Select

Expand Down Expand Up @@ -268,8 +258,6 @@ def render_compilable(self, c: Compiler, elem: Compilable) -> str:
return self.render_cte(c, elem)
elif isinstance(elem, Commit):
return self.render_commit(c, elem)
elif isinstance(elem, Param):
return self.render_param(c, elem)
elif isinstance(elem, NormalizeAsString):
return self.render_normalizeasstring(c, elem)
elif isinstance(elem, ApplyFuncAndNormalizeAsString):
Expand Down Expand Up @@ -369,10 +357,6 @@ def render_cte(self, parent_c: Compiler, elem: Cte) -> str:
# Render a Commit node: yields the literal "COMMIT" when the compiler's database
# is not in autocommit mode, and SKIP when it is.
# NOTE(review): SKIP is defined outside this view; presumably a sentinel telling the
# caller to omit the statement -- confirm against its definition.
def render_commit(self, c: Compiler, elem: Commit) -> str:
return "COMMIT" if not c.database.is_autocommit else SKIP

def render_param(self, c: Compiler, elem: Param) -> str:
params = cv_params.get()
return self._compile(c, params[elem.name])

# Render a NormalizeAsString node: compile the wrapped expression, then hand the
# compiled text to `normalize_value_by_type` together with the expression's type.
def render_normalizeasstring(self, c: Compiler, elem: NormalizeAsString) -> str:
expr = self.compile(c, elem.expr)
# A truthy `elem.expr_type` takes precedence; otherwise fall back to `elem.expr.type`.
return self.normalize_value_by_type(expr, elem.expr_type or elem.expr.type)
Expand Down
11 changes: 2 additions & 9 deletions data_diff/queries/ast_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _dfs_values(self):
if k == "source_table":
# Skip data-sources, we're only interested in data-parameters
continue
if not isinstance(vs, (list, tuple)):
if not isinstance(vs, list | tuple):
vs = [vs]
for v in vs:
if isinstance(v, ExprNode):
Expand Down Expand Up @@ -689,7 +689,7 @@ def __getattr__(self, name):
return _ResolveColumn(name)

def __getitem__(self, name):
if isinstance(name, (list, tuple)):
if isinstance(name, list | tuple):
return [_ResolveColumn(n) for n in name]
return _ResolveColumn(name)

Expand Down Expand Up @@ -794,10 +794,3 @@ def returning(self, *exprs) -> Self:
@attrs.define(frozen=True, eq=False)
class Commit(Statement):
"""Generate a COMMIT statement unless the database is in auto-commit mode; in auto-commit mode the statement is skipped (SKIP)."""


@attrs.define(frozen=True, eq=False)
class Param(ExprNode, ITable): # TODO: Unused?
"""A value placeholder, to be specified at compilation time using the `cv_params` context variable."""

name: str
84 changes: 43 additions & 41 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,13 @@ def test_basic(self):

t = table("point")
t2 = t.select(x=this.x + 1, y=t["y"] + this.x)
assert c.compile(t2) == "SELECT (x + 1) AS x, (y + x) AS y FROM point"
assert c.dialect.compile(c, t2) == "SELECT (x + 1) AS x, (y + x) AS y FROM point"

t = table("point").where(this.x == 1, this.y == 2)
assert c.compile(t) == "SELECT * FROM point WHERE (x = 1) AND (y = 2)"
assert c.dialect.compile(c, t) == "SELECT * FROM point WHERE (x = 1) AND (y = 2)"

t = table("person").where(this.name == "Albert")
self.assertEqual(c.compile(t), "SELECT * FROM person WHERE (name = 'Albert')")
self.assertEqual(c.dialect.compile(c, t), "SELECT * FROM person WHERE (name = 'Albert')")

def test_outerjoin(self):
c = Compiler(MockDatabase())
Expand All @@ -127,7 +127,8 @@ def test_outerjoin(self):
j = outerjoin(a, b).on(a[k] == b[k] for k in keys)

self.assertEqual(
c.compile(j), "SELECT * FROM a tmp1 FULL OUTER JOIN b tmp2 ON (tmp1.x = tmp2.x) AND (tmp1.y = tmp2.y)"
c.dialect.compile(c, j),
"SELECT * FROM a tmp1 FULL OUTER JOIN b tmp2 ON (tmp1.x = tmp2.x) AND (tmp1.y = tmp2.y)",
)

def test_schema(self):
Expand All @@ -137,7 +138,7 @@ def test_schema(self):
# test table
t = table("a", schema=CaseInsensitiveDict(schema))
q = t.select(this.Id, t["COMMENT"])
assert c.compile(q) == "SELECT id, comment FROM a"
assert c.dialect.compile(c, q) == "SELECT id, comment FROM a"

t = table("a", schema=CaseSensitiveDict(schema))
self.assertRaises(KeyError, t.__getitem__, "Id")
Expand Down Expand Up @@ -174,67 +175,67 @@ def test_cte(self):
t3 = t2.select(this.x)

expected = "WITH tmp1 AS (SELECT x FROM a) SELECT x FROM tmp1"
assert normalize_spaces(c.compile(t3)) == expected
assert normalize_spaces(c.dialect.compile(c, t3)) == expected

# nested cte
c = Compiler(MockDatabase())
t4 = cte(t3).select(this.x)

expected = "WITH tmp1 AS (SELECT x FROM a), tmp2 AS (SELECT x FROM tmp1) SELECT x FROM tmp2"
assert normalize_spaces(c.compile(t4)) == expected
assert normalize_spaces(c.dialect.compile(c, t4)) == expected

# parameterized cte
c = Compiler(MockDatabase())
t2 = cte(t.select(this.x), params=["y"])
t3 = t2.select(this.y)

expected = "WITH tmp1(y) AS (SELECT x FROM a) SELECT y FROM tmp1"
assert normalize_spaces(c.compile(t3)) == expected
assert normalize_spaces(c.dialect.compile(c, t3)) == expected

def test_funcs(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.order_by(Random()).limit(10))
q = c.dialect.compile(c, t.order_by(Random()).limit(10))
self.assertEqual(q, "SELECT * FROM (SELECT * FROM a ORDER BY random()) AS LIMITED_SELECT LIMIT 10")

q = c.compile(t.select(coalesce(this.a, this.b)))
q = c.dialect.compile(c, t.select(coalesce(this.a, this.b)))
self.assertEqual(q, "SELECT COALESCE(a, b) FROM a")

def test_select_distinct(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.select(this.b, distinct=True))
q = c.dialect.compile(c, t.select(this.b, distinct=True))
assert q == "SELECT DISTINCT b FROM a"

# selects merge
q = c.compile(t.where(this.b > 10).select(this.b, distinct=True))
q = c.dialect.compile(c, t.where(this.b > 10).select(this.b, distinct=True))
self.assertEqual(q, "SELECT DISTINCT b FROM a WHERE (b > 10)")

# selects stay apart
q = c.compile(t.limit(10).select(this.b, distinct=True))
q = c.dialect.compile(c, t.limit(10).select(this.b, distinct=True))
self.assertEqual(q, "SELECT DISTINCT b FROM (SELECT * FROM (SELECT * FROM a) AS LIMITED_SELECT LIMIT 10) tmp1")

q = c.compile(t.select(this.b, distinct=True).select(distinct=False))
q = c.dialect.compile(c, t.select(this.b, distinct=True).select(distinct=False))
self.assertEqual(q, "SELECT * FROM (SELECT DISTINCT b FROM a) tmp2")

def test_select_with_optimizer_hints(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.select(this.b, optimizer_hints="PARALLEL(a 16)"))
q = c.dialect.compile(c, t.select(this.b, optimizer_hints="PARALLEL(a 16)"))
assert q == "SELECT /*+ PARALLEL(a 16) */ b FROM a"

q = c.compile(t.where(this.b > 10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
q = c.dialect.compile(c, t.where(this.b > 10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a WHERE (b > 10)")

q = c.compile(t.limit(10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
q = c.dialect.compile(c, t.limit(10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
self.assertEqual(
q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM (SELECT * FROM a) AS LIMITED_SELECT LIMIT 10) tmp1"
)

q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints="PARALLEL(a 16)"))
q = c.dialect.compile(c, t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints="PARALLEL(a 16)"))
self.assertEqual(
q, "SELECT /*+ PARALLEL(a 16) */ * FROM (SELECT b, c FROM (SELECT a FROM a) tmp2 GROUP BY 1) tmp3"
)
Expand All @@ -244,85 +245,86 @@ def test_table_ops(self):
a = table("a").select(this.x)
b = table("b").select(this.y)

q = c.compile(a.union(b))
q = c.dialect.compile(c, a.union(b))
assert q == "SELECT x FROM a UNION SELECT y FROM b"

q = c.compile(a.union_all(b))
q = c.dialect.compile(c, a.union_all(b))
assert q == "SELECT x FROM a UNION ALL SELECT y FROM b"

q = c.compile(a.minus(b))
q = c.dialect.compile(c, a.minus(b))
assert q == "SELECT x FROM a EXCEPT SELECT y FROM b"

q = c.compile(a.intersect(b))
q = c.dialect.compile(c, a.intersect(b))
assert q == "SELECT x FROM a INTERSECT SELECT y FROM b"

def test_ops(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.select(this.b + this.c))
q = c.dialect.compile(c, t.select(this.b + this.c))
self.assertEqual(q, "SELECT (b + c) FROM a")

q = c.compile(t.select(this.b.like(this.c)))
q = c.dialect.compile(c, t.select(this.b.like(this.c)))
self.assertEqual(q, "SELECT (b LIKE c) FROM a")

q = c.compile(t.select(-this.b.sum()))
q = c.dialect.compile(c, t.select(-this.b.sum()))
self.assertEqual(q, "SELECT (-SUM(b)) FROM a")

def test_group_by(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.group_by(this.b).agg(this.c))
q = c.dialect.compile(c, t.group_by(this.b).agg(this.c))
self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1")

q = c.compile(t.where(this.b > 1).group_by(this.b).agg(this.c))
q = c.dialect.compile(c, t.where(this.b > 1).group_by(this.b).agg(this.c))
self.assertEqual(q, "SELECT b, c FROM a WHERE (b > 1) GROUP BY 1")

self.assertRaises(CompileError, c.compile, t.select(this.b).group_by(this.b))
self.assertRaises(CompileError, c.dialect.compile, c, t.select(this.b).group_by(this.b))

q = c.compile(t.select(this.b).group_by(this.b).agg())
q = c.dialect.compile(c, t.select(this.b).group_by(this.b).agg())
self.assertEqual(q, "SELECT b FROM (SELECT b FROM a) tmp1 GROUP BY 1")

q = c.compile(t.group_by(this.b, this.c).agg(this.d, this.e))
q = c.dialect.compile(c, t.group_by(this.b, this.c).agg(this.d, this.e))
self.assertEqual(q, "SELECT b, c, d, e FROM a GROUP BY 1, 2")

# Having
q = c.compile(t.group_by(this.b).agg(this.c).having(this.b > 1))
q = c.dialect.compile(c, t.group_by(this.b).agg(this.c).having(this.b > 1))
self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1 HAVING (b > 1)")

q = c.compile(t.group_by(this.b).having(this.b > 1).agg(this.c))
q = c.dialect.compile(c, t.group_by(this.b).having(this.b > 1).agg(this.c))
self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1 HAVING (b > 1)")

q = c.compile(t.select(this.b).group_by(this.b).agg().having(this.b > 1))
q = c.dialect.compile(c, t.select(this.b).group_by(this.b).agg().having(this.b > 1))
self.assertEqual(q, "SELECT b FROM (SELECT b FROM a) tmp2 GROUP BY 1 HAVING (b > 1)")

# Having sum
q = c.compile(t.group_by(this.b).agg(this.c, this.d).having(this.b.sum() > 1))
q = c.dialect.compile(c, t.group_by(this.b).agg(this.c, this.d).having(this.b.sum() > 1))
self.assertEqual(q, "SELECT b, c, d FROM a GROUP BY 1 HAVING (SUM(b) > 1)")

# Select interaction
q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(this.c + 1))
q = c.dialect.compile(c, t.select(this.a).group_by(this.b).agg(this.c).select(this.c + 1))
self.assertEqual(q, "SELECT (c + 1) FROM (SELECT b, c FROM (SELECT a FROM a) tmp3 GROUP BY 1) tmp4")

def test_case_when(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.select(when(this.b).then(this.c)))
q = c.dialect.compile(c, t.select(when(this.b).then(this.c)))
self.assertEqual(q, "SELECT CASE WHEN b THEN c END FROM a")

q = c.compile(t.select(when(this.b).then(this.c).else_(this.d)))
q = c.dialect.compile(c, t.select(when(this.b).then(this.c).else_(this.d)))
self.assertEqual(q, "SELECT CASE WHEN b THEN c ELSE d END FROM a")

q = c.compile(
q = c.dialect.compile(
c,
t.select(
when(this.type == "text")
.then(this.text)
.when(this.type == "number")
.then(this.number)
.else_("unknown type")
)
),
)
self.assertEqual(
q,
Expand All @@ -333,13 +335,13 @@ def test_code(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.select(this.b, code("<x>")).where(code("<y>")))
q = c.dialect.compile(c, t.select(this.b, code("<x>")).where(code("<y>")))
self.assertEqual(q, "SELECT b, <x> FROM a WHERE <y>")

def tablesample(t, size):
return code("{t} TABLESAMPLE BERNOULLI ({size})", t=t, size=size)

nonzero = table("points").where(this.x > 0, this.y > 0)

q = c.compile(tablesample(nonzero, 10))
q = c.dialect.compile(c, tablesample(nonzero, 10))
self.assertEqual(q, "SELECT * FROM points WHERE (x > 0) AND (y > 0) TABLESAMPLE BERNOULLI (10)")
Loading
Loading