Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@ class Compiler(AbstractCompiler):
def dialect(self) -> "BaseDialect":
return self.database.dialect

# TODO: DEPRECATED: Remove once the dialect is used directly in all places.
def compile(self, elem, params=None) -> str:
return self.dialect.compile(self, elem, params)

def new_unique_name(self, prefix="tmp") -> str:
self._counter[0] += 1
return f"{prefix}{self._counter[0]}"
Expand Down
78 changes: 39 additions & 39 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,13 @@ def test_basic(self):

t = table("point")
t2 = t.select(x=this.x + 1, y=t["y"] + this.x)
assert c.compile(t2) == "SELECT (x + 1) AS x, (y + x) AS y FROM point"
assert c.dialect.compile(c, t2) == "SELECT (x + 1) AS x, (y + x) AS y FROM point"

t = table("point").where(this.x == 1, this.y == 2)
assert c.compile(t) == "SELECT * FROM point WHERE (x = 1) AND (y = 2)"
assert c.dialect.compile(c, t) == "SELECT * FROM point WHERE (x = 1) AND (y = 2)"

t = table("person").where(this.name == "Albert")
self.assertEqual(c.compile(t), "SELECT * FROM person WHERE (name = 'Albert')")
self.assertEqual(c.dialect.compile(c, t), "SELECT * FROM person WHERE (name = 'Albert')")

def test_outerjoin(self):
c = Compiler(MockDatabase())
Expand All @@ -127,7 +127,7 @@ def test_outerjoin(self):
j = outerjoin(a, b).on(a[k] == b[k] for k in keys)

self.assertEqual(
c.compile(j), "SELECT * FROM a tmp1 FULL OUTER JOIN b tmp2 ON (tmp1.x = tmp2.x) AND (tmp1.y = tmp2.y)"
c.dialect.compile(c, j), "SELECT * FROM a tmp1 FULL OUTER JOIN b tmp2 ON (tmp1.x = tmp2.x) AND (tmp1.y = tmp2.y)"
)

def test_schema(self):
Expand All @@ -137,7 +137,7 @@ def test_schema(self):
# test table
t = table("a", schema=CaseInsensitiveDict(schema))
q = t.select(this.Id, t["COMMENT"])
assert c.compile(q) == "SELECT id, comment FROM a"
assert c.dialect.compile(c, q) == "SELECT id, comment FROM a"

t = table("a", schema=CaseSensitiveDict(schema))
self.assertRaises(KeyError, t.__getitem__, "Id")
Expand Down Expand Up @@ -174,67 +174,67 @@ def test_cte(self):
t3 = t2.select(this.x)

expected = "WITH tmp1 AS (SELECT x FROM a) SELECT x FROM tmp1"
assert normalize_spaces(c.compile(t3)) == expected
assert normalize_spaces(c.dialect.compile(c, t3)) == expected

# nested cte
c = Compiler(MockDatabase())
t4 = cte(t3).select(this.x)

expected = "WITH tmp1 AS (SELECT x FROM a), tmp2 AS (SELECT x FROM tmp1) SELECT x FROM tmp2"
assert normalize_spaces(c.compile(t4)) == expected
assert normalize_spaces(c.dialect.compile(c, t4)) == expected

# parameterized cte
c = Compiler(MockDatabase())
t2 = cte(t.select(this.x), params=["y"])
t3 = t2.select(this.y)

expected = "WITH tmp1(y) AS (SELECT x FROM a) SELECT y FROM tmp1"
assert normalize_spaces(c.compile(t3)) == expected
assert normalize_spaces(c.dialect.compile(c, t3)) == expected

def test_funcs(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.order_by(Random()).limit(10))
q = c.dialect.compile(c, t.order_by(Random()).limit(10))
self.assertEqual(q, "SELECT * FROM (SELECT * FROM a ORDER BY random()) AS LIMITED_SELECT LIMIT 10")

q = c.compile(t.select(coalesce(this.a, this.b)))
q = c.dialect.compile(c, t.select(coalesce(this.a, this.b)))
self.assertEqual(q, "SELECT COALESCE(a, b) FROM a")

def test_select_distinct(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.select(this.b, distinct=True))
q = c.dialect.compile(c, t.select(this.b, distinct=True))
assert q == "SELECT DISTINCT b FROM a"

# selects merge
q = c.compile(t.where(this.b > 10).select(this.b, distinct=True))
q = c.dialect.compile(c, t.where(this.b > 10).select(this.b, distinct=True))
self.assertEqual(q, "SELECT DISTINCT b FROM a WHERE (b > 10)")

# selects stay apart
q = c.compile(t.limit(10).select(this.b, distinct=True))
q = c.dialect.compile(c, t.limit(10).select(this.b, distinct=True))
self.assertEqual(q, "SELECT DISTINCT b FROM (SELECT * FROM (SELECT * FROM a) AS LIMITED_SELECT LIMIT 10) tmp1")

q = c.compile(t.select(this.b, distinct=True).select(distinct=False))
q = c.dialect.compile(c, t.select(this.b, distinct=True).select(distinct=False))
self.assertEqual(q, "SELECT * FROM (SELECT DISTINCT b FROM a) tmp2")

def test_select_with_optimizer_hints(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.select(this.b, optimizer_hints="PARALLEL(a 16)"))
q = c.dialect.compile(c, t.select(this.b, optimizer_hints="PARALLEL(a 16)"))
assert q == "SELECT /*+ PARALLEL(a 16) */ b FROM a"

q = c.compile(t.where(this.b > 10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
q = c.dialect.compile(c, t.where(this.b > 10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a WHERE (b > 10)")

q = c.compile(t.limit(10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
q = c.dialect.compile(c, t.limit(10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
self.assertEqual(
q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM (SELECT * FROM a) AS LIMITED_SELECT LIMIT 10) tmp1"
)

q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints="PARALLEL(a 16)"))
q = c.dialect.compile(c, t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints="PARALLEL(a 16)"))
self.assertEqual(
q, "SELECT /*+ PARALLEL(a 16) */ * FROM (SELECT b, c FROM (SELECT a FROM a) tmp2 GROUP BY 1) tmp3"
)
Expand All @@ -244,78 +244,78 @@ def test_table_ops(self):
a = table("a").select(this.x)
b = table("b").select(this.y)

q = c.compile(a.union(b))
q = c.dialect.compile(c, a.union(b))
assert q == "SELECT x FROM a UNION SELECT y FROM b"

q = c.compile(a.union_all(b))
q = c.dialect.compile(c, a.union_all(b))
assert q == "SELECT x FROM a UNION ALL SELECT y FROM b"

q = c.compile(a.minus(b))
q = c.dialect.compile(c, a.minus(b))
assert q == "SELECT x FROM a EXCEPT SELECT y FROM b"

q = c.compile(a.intersect(b))
q = c.dialect.compile(c, a.intersect(b))
assert q == "SELECT x FROM a INTERSECT SELECT y FROM b"

def test_ops(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.select(this.b + this.c))
q = c.dialect.compile(c, t.select(this.b + this.c))
self.assertEqual(q, "SELECT (b + c) FROM a")

q = c.compile(t.select(this.b.like(this.c)))
q = c.dialect.compile(c, t.select(this.b.like(this.c)))
self.assertEqual(q, "SELECT (b LIKE c) FROM a")

q = c.compile(t.select(-this.b.sum()))
q = c.dialect.compile(c, t.select(-this.b.sum()))
self.assertEqual(q, "SELECT (-SUM(b)) FROM a")

def test_group_by(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.group_by(this.b).agg(this.c))
q = c.dialect.compile(c, t.group_by(this.b).agg(this.c))
self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1")

q = c.compile(t.where(this.b > 1).group_by(this.b).agg(this.c))
q = c.dialect.compile(c, t.where(this.b > 1).group_by(this.b).agg(this.c))
self.assertEqual(q, "SELECT b, c FROM a WHERE (b > 1) GROUP BY 1")

self.assertRaises(CompileError, c.compile, t.select(this.b).group_by(this.b))

q = c.compile(t.select(this.b).group_by(this.b).agg())
q = c.dialect.compile(c, t.select(this.b).group_by(this.b).agg())
self.assertEqual(q, "SELECT b FROM (SELECT b FROM a) tmp1 GROUP BY 1")

q = c.compile(t.group_by(this.b, this.c).agg(this.d, this.e))
q = c.dialect.compile(c, t.group_by(this.b, this.c).agg(this.d, this.e))
self.assertEqual(q, "SELECT b, c, d, e FROM a GROUP BY 1, 2")

# Having
q = c.compile(t.group_by(this.b).agg(this.c).having(this.b > 1))
q = c.dialect.compile(c, t.group_by(this.b).agg(this.c).having(this.b > 1))
self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1 HAVING (b > 1)")

q = c.compile(t.group_by(this.b).having(this.b > 1).agg(this.c))
q = c.dialect.compile(c, t.group_by(this.b).having(this.b > 1).agg(this.c))
self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1 HAVING (b > 1)")

q = c.compile(t.select(this.b).group_by(this.b).agg().having(this.b > 1))
q = c.dialect.compile(c, t.select(this.b).group_by(this.b).agg().having(this.b > 1))
self.assertEqual(q, "SELECT b FROM (SELECT b FROM a) tmp2 GROUP BY 1 HAVING (b > 1)")

# Having sum
q = c.compile(t.group_by(this.b).agg(this.c, this.d).having(this.b.sum() > 1))
q = c.dialect.compile(c, t.group_by(this.b).agg(this.c, this.d).having(this.b.sum() > 1))
self.assertEqual(q, "SELECT b, c, d FROM a GROUP BY 1 HAVING (SUM(b) > 1)")

# Select interaction
q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(this.c + 1))
q = c.dialect.compile(c, t.select(this.a).group_by(this.b).agg(this.c).select(this.c + 1))
self.assertEqual(q, "SELECT (c + 1) FROM (SELECT b, c FROM (SELECT a FROM a) tmp3 GROUP BY 1) tmp4")

def test_case_when(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.select(when(this.b).then(this.c)))
q = c.dialect.compile(c, t.select(when(this.b).then(this.c)))
self.assertEqual(q, "SELECT CASE WHEN b THEN c END FROM a")

q = c.compile(t.select(when(this.b).then(this.c).else_(this.d)))
q = c.dialect.compile(c, t.select(when(this.b).then(this.c).else_(this.d)))
self.assertEqual(q, "SELECT CASE WHEN b THEN c ELSE d END FROM a")

q = c.compile(
q = c.dialect.compile(c,
t.select(
when(this.type == "text")
.then(this.text)
Expand All @@ -333,13 +333,13 @@ def test_code(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.select(this.b, code("<x>")).where(code("<y>")))
q = c.dialect.compile(c, t.select(this.b, code("<x>")).where(code("<y>")))
self.assertEqual(q, "SELECT b, <x> FROM a WHERE <y>")

def tablesample(t, size):
return code("{t} TABLESAMPLE BERNOULLI ({size})", t=t, size=size)

nonzero = table("points").where(this.x > 0, this.y > 0)

q = c.compile(tablesample(nonzero, 10))
q = c.dialect.compile(c, tablesample(nonzero, 10))
self.assertEqual(q, "SELECT * FROM points WHERE (x > 0) AND (y > 0) TABLESAMPLE BERNOULLI (10)")
20 changes: 10 additions & 10 deletions tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def setUp(self):
self.compiler = Compiler(self.mysql)

def test_compile_string(self):
self.assertEqual("SELECT 1", self.compiler.compile(Code("SELECT 1")))
self.assertEqual("SELECT 1", self.compiler.dialect.compile(self.compiler, Code("SELECT 1")))

def test_compile_int(self):
self.assertEqual("1", self.compiler.compile(1))
self.assertEqual("1", self.compiler.dialect.compile(self.compiler, 1))

def test_compile_table_name(self):
compiler = attrs.evolve(self.compiler, root=False)
Expand All @@ -27,7 +27,7 @@ def test_compile_select(self):
expected_sql = "SELECT name FROM `marine_mammals`.`walrus`"
self.assertEqual(
expected_sql,
self.compiler.compile(
self.compiler.dialect.compile(self.compiler,
Select(
table("marine_mammals", "walrus"),
[Code("name")],
Expand All @@ -39,7 +39,7 @@ def test_compile_select(self):
# expected_sql = "(SELECT *, (row_number() over (ORDER BY id)) as idx FROM `walrus` ORDER BY id) tmp"
# self.assertEqual(
# expected_sql,
# self.compiler.compile(
# self.compiler.dialect.compile(self.compiler,
# Enum(
# ("walrus",),
# "id",
Expand All @@ -51,7 +51,7 @@ def test_compile_select(self):
# expected_sql = "SELECT name, sum(cast(conv(substring(md5(concat(cast(id as char), cast(timestamp as char))), 18), 16, 10) as unsigned)) FROM `marine_mammals`.`walrus`"
# self.assertEqual(
# expected_sql,
# self.compiler.compile(
# self.compiler.dialect.compile(self.compiler,
# Select(
# ["name", Checksum(["id", "timestamp"])],
# TableName(("marine_mammals", "walrus")),
Expand All @@ -63,7 +63,7 @@ def test_compare(self):
expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id <= 1000) AND (id > 1)"
self.assertEqual(
expected_sql,
self.compiler.compile(
self.compiler.dialect.compile(self.compiler,
Select(
table("marine_mammals", "walrus"),
[Code("name")],
Expand All @@ -76,7 +76,7 @@ def test_in(self):
expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))"
self.assertEqual(
expected_sql,
self.compiler.compile(
self.compiler.dialect.compile(self.compiler,
Select(table("marine_mammals", "walrus"), [Code("name")], [In(Code("id"), [1, 2, 3])])
),
)
Expand All @@ -85,14 +85,14 @@ def test_count(self):
expected_sql = "SELECT count(*) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))"
self.assertEqual(
expected_sql,
self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count()], [In(Code("id"), [1, 2, 3])])),
self.compiler.dialect.compile(self.compiler, Select(table("marine_mammals", "walrus"), [Count()], [In(Code("id"), [1, 2, 3])])),
)

def test_count_with_column(self):
expected_sql = "SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))"
self.assertEqual(
expected_sql,
self.compiler.compile(
self.compiler.dialect.compile(self.compiler,
Select(table("marine_mammals", "walrus"), [Count(Code("id"))], [In(Code("id"), [1, 2, 3])])
),
)
Expand All @@ -101,7 +101,7 @@ def test_explain(self):
expected_sql = "EXPLAIN FORMAT=TREE SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))"
self.assertEqual(
expected_sql,
self.compiler.compile(
self.compiler.dialect.compile(self.compiler,
Explain(Select(table("marine_mammals", "walrus"), [Count(Code("id"))], [In(Code("id"), [1, 2, 3])]))
),
)