Merge branch 'main' into array-api2

NickCrews authored Jun 29, 2024
2 parents a655252 + 8e225ec commit 1162137
Showing 247 changed files with 6,831 additions and 1,861 deletions.
4 changes: 2 additions & 2 deletions compose.yaml
@@ -94,7 +94,7 @@ services:
       - trino
 
   minio:
-    image: bitnami/minio:2024.6.26
+    image: bitnami/minio:2024.6.28
     environment:
       MINIO_ROOT_USER: accesskey
       MINIO_ROOT_PASSWORD: secretkey
@@ -156,7 +156,7 @@ services:
     test:
       - CMD-SHELL
      - trino --output-format null --execute 'show schemas in hive; show schemas in memory'
-    image: trinodb/trino:450
+    image: trinodb/trino:451
     ports:
       - 8080:8080
     networks:
45 changes: 40 additions & 5 deletions ibis/backends/datafusion/tests/conftest.py
@@ -3,6 +3,7 @@
 from typing import Any
 
 import pytest
+import sqlglot as sg
 
 import ibis
 from ibis.backends.conftest import TEST_TABLES
@@ -20,7 +21,7 @@ class TestConf(BackendTest):
     stateful = False
     deps = ("datafusion",)
     # Query 1 seems to require a bit more room here
-    tpch_absolute_tolerance = 0.11
+    tpc_absolute_tolerance = 0.11
 
     def _load_data(self, **_: Any) -> None:
         con = self.connection
@@ -39,12 +40,46 @@ def connect(*, tmpdir, worker_id, **kw):
         return ibis.datafusion.connect(**kw)
 
     def load_tpch(self) -> None:
         """Load TPC-H data."""
+        self.tpch_tables = frozenset(self._load_tpc(suite="h", scale_factor="0.17"))
+
+    def _load_tpc(self, *, suite, scale_factor):
         con = self.connection
-        for path in self.data_dir.joinpath("tpch", "sf=0.17", "parquet").glob(
-            "*.parquet"
-        ):
+        schema = f"tpc{suite}"
+        con.create_database(schema)
+        tables = set()
+        for path in self.data_dir.joinpath(
+            schema, f"sf={scale_factor}", "parquet"
+        ).glob("*.parquet"):
             table_name = path.with_suffix("").name
-            con.read_parquet(path, table_name=table_name)
+            tables.add(table_name)
+            con.con.sql(
+                # datafusion can't create an external table in a specific schema it seems
+                # so hack around that by
+                #
+                # 1. creating an external table in the current schema
+                # 2. create an internal table in the desired schema using a
+                #    CTAS from the external table
+                # 3. drop the external table
+                f"CREATE EXTERNAL TABLE {table_name} STORED AS PARQUET LOCATION '{path}'"
+            )
+
+            con.con.sql(
+                f"CREATE TABLE {schema}.{table_name} AS SELECT * FROM {table_name}"
+            )
+            con.con.sql(f"DROP TABLE {table_name}")
+        return tables
+
+    def _transform_tpch_sql(self, parsed):
+        def add_catalog_and_schema(node):
+            if isinstance(node, sg.exp.Table) and node.name in self.tpch_tables:
+                return node.__class__(
+                    catalog="tpch",
+                    **{k: v for k, v in node.args.items() if k != "catalog"},
+                )
+            return node
+
+        return parsed.transform(add_catalog_and_schema)
 
 
 @pytest.fixture(scope="session")
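The `_transform_tpch_sql` hook rewrites every table reference in the raw DuckDB SQL so it targets the `tpch` schema that `_load_tpc` created. A minimal standalone sketch of the same sqlglot pattern, with an illustrative table set and query:

import sqlglot as sg

tpch_tables = frozenset({"lineitem", "orders"})

def add_catalog_and_schema(node):
    # Rewrite bare references like `lineitem` to `tpch.lineitem`,
    # preserving all of the node's other arguments.
    if isinstance(node, sg.exp.Table) and node.name in tpch_tables:
        return node.__class__(
            catalog="tpch",
            **{k: v for k, v in node.args.items() if k != "catalog"},
        )
    return node

parsed = sg.parse_one("SELECT count(*) FROM lineitem", read="duckdb")
print(parsed.transform(add_catalog_and_schema).sql(dialect="duckdb"))
# SELECT COUNT(*) FROM tpch.lineitem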
54 changes: 47 additions & 7 deletions ibis/backends/duckdb/tests/conftest.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING
 
 import pytest
+import sqlglot as sg
 
 import ibis
 from ibis.backends.conftest import TEST_TABLES
@@ -48,6 +49,7 @@ class TestConf(BackendTest):
     deps = ("duckdb",)
     stateful = False
     supports_tpch = True
+    supports_tpcds = True
     driver_supports_multiple_statements = True
 
     def preload(self):
@@ -107,15 +109,53 @@ def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
         kw["extension_directory"] = extension_directory
         return ibis.duckdb.connect(**kw)
 
-    def load_tpch(self) -> None:
-        """Load TPC-H data."""
+    def _load_tpc(self, *, suite, scale_factor):
         con = self.connection
-        for path in self.data_dir.joinpath("tpch", "sf=0.17", "parquet").glob(
-            "*.parquet"
-        ):
+        schema = f"tpc{suite}"
+        con.con.execute(f"CREATE OR REPLACE SCHEMA {schema}")
+        parquet_dir = self.data_dir.joinpath(schema, f"sf={scale_factor}", "parquet")
+        assert parquet_dir.exists(), parquet_dir
+        tables = set()
+        for path in parquet_dir.glob("*.parquet"):
             table_name = path.with_suffix("").name
-            # duckdb automatically infers the sf=0.17 as a hive partition
-            con.read_parquet(path, table_name=table_name, hive_partitioning=False)
+            tables.add(table_name)
+            # duckdb automatically infers the sf= as a hive partition so we
+            # need to disable it
+            con.con.execute(
+                f"CREATE OR REPLACE VIEW {schema}.{table_name} AS "
+                f"FROM read_parquet({str(path)!r}, hive_partitioning=false)"
+            )
+        return tables
+
+    def load_tpch(self) -> None:
+        """Load TPC-H data."""
+        self.tpch_tables = frozenset(self._load_tpc(suite="h", scale_factor="0.17"))
+
+    def load_tpcds(self) -> None:
+        """Load TPC-DS data."""
+        self.tpcds_tables = frozenset(self._load_tpc(suite="ds", scale_factor="0.2"))
+
+    def _transform_tpch_sql(self, parsed):
+        def add_catalog_and_schema(node):
+            if isinstance(node, sg.exp.Table) and node.name in self.tpch_tables:
+                return node.__class__(
+                    catalog="tpch",
+                    **{k: v for k, v in node.args.items() if k != "catalog"},
+                )
+            return node
+
+        return parsed.transform(add_catalog_and_schema)
+
+    def _transform_tpcds_sql(self, parsed):
+        def add_catalog_and_schema(node):
+            if isinstance(node, sg.exp.Table) and node.name in self.tpcds_tables:
+                return node.__class__(
+                    catalog="tpcds",
+                    **{k: v for k, v in node.args.items() if k != "catalog"},
+                )
+            return node
+
+        return parsed.transform(add_catalog_and_schema)
 
 
 @pytest.fixture(scope="session")
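The `hive_partitioning=false` flag matters because the parquet files live under a directory literally named `sf=0.17` or `sf=0.2`, which DuckDB's reader treats as a Hive partition key by default and would surface as a spurious `sf` column. A minimal sketch of the pitfall, assuming an illustrative on-disk layout:

import duckdb

con = duckdb.connect()
# Without hive_partitioning=false, DuckDB would infer "sf" as a partition
# column from the "sf=0.17" path segment and add it to the view's schema.
# Disabling it keeps only the parquet file's own columns.
con.sql(
    "CREATE OR REPLACE VIEW lineitem AS "
    "FROM read_parquet('tpch/sf=0.17/parquet/lineitem.parquet', hive_partitioning=false)"
)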
1 change: 0 additions & 1 deletion ibis/backends/exasol/tests/conftest.py
@@ -35,7 +35,6 @@ class TestConf(ServiceBackendTest):
     reduction_tolerance = 1e-7
     stateful = True
     service_name = "exasol"
-    supports_tpch = False
     force_sort = True
     deps = ("pyexasol",)
 
2 changes: 1 addition & 1 deletion ibis/backends/snowflake/tests/conftest.py
@@ -94,7 +94,7 @@ class TestConf(BackendTest):
     def load_tpch(self) -> None:
         """No-op, snowflake already defines these in `SNOWFLAKE_SAMPLE_DATA`."""
 
-    def _tpch_table(self, name: str):
+    def h(self, name: str):
         name = name.upper()
         t = self.connection.table(name, database="SNOWFLAKE_SAMPLE_DATA.TPCH_SF1")
         return t.rename("snake_case")
52 changes: 16 additions & 36 deletions ibis/backends/tests/base.py
@@ -51,14 +51,16 @@ class BackendTest(abc.ABC):
     "Whether special handling is needed for running a multi-process pytest run."
     supports_tpch: bool = False
     "Child class defines a `load_tpch` method that loads the required TPC-H tables into a connection."
+    supports_tpcds: bool = False
+    "Child class defines a `load_tpcds` method that loads the required TPC-DS tables into a connection."
     force_sort = False
     "Sort results before comparing against reference computation."
     rounding_method: Literal["away_from_zero", "half_to_even"] = "away_from_zero"
     "Name of round method to use for rounding test comparisons."
     driver_supports_multiple_statements: bool = False
     "Whether the driver supports executing multiple statements in a single call."
-    tpch_absolute_tolerance: float | None = None
-    "Absolute tolerance for floating point comparisons with pytest.approx in TPC-H correctness tests."
+    tpc_absolute_tolerance: float | None = None
+    "Absolute tolerance for floating point comparisons with pytest.approx in TPC correctness tests."
 
     @property
     @abc.abstractmethod
@@ -130,6 +132,8 @@ def stateless_load(self, **kw):
 
         if self.supports_tpch:
             self.load_tpch()
+        if self.supports_tpcds:
+            self.load_tpcds()
 
     def stateful_load(self, fn, **kw):
         if not fn.exists():
@@ -297,42 +301,18 @@ def api(self):
     def make_context(self, params: Mapping[ir.Value, Any] | None = None):
         return self.api.compiler.make_context(params=params)
 
-    @property
-    def customer(self):
-        return self._tpch_table("customer")
-
-    @property
-    def lineitem(self):
-        return self._tpch_table("lineitem")
-
-    @property
-    def nation(self):
-        return self._tpch_table("nation")
-
-    @property
-    def orders(self):
-        return self._tpch_table("orders")
-
-    @property
-    def part(self):
-        return self._tpch_table("part")
-
-    @property
-    def partsupp(self):
-        return self._tpch_table("partsupp")
-
-    @property
-    def region(self):
-        return self._tpch_table("region")
+    def _tpc_table(self, name: str, benchmark: Literal["h", "ds"]):
+        if not getattr(self, f"supports_tpc{benchmark}"):
+            pytest.skip(
+                f"{self.name()} backend does not support testing TPC-{benchmark.upper()}"
+            )
+        return self.connection.table(name, database=f"tpc{benchmark}")
 
-    @property
-    def supplier(self):
-        return self._tpch_table("supplier")
+    def h(self, name: str) -> ir.Table:
+        return self._tpc_table(name, "h")
 
-    def _tpch_table(self, name: str):
-        if not self.supports_tpch:
-            pytest.skip(f"{self.name()} backend does not support testing TPC-H")
-        return self.connection.table(name)
+    def ds(self, name: str) -> ir.Table:
+        return self._tpc_table(name, "ds")
 
 
 class ServiceBackendTest(BackendTest):
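With the per-table properties gone, tests ask for a table by name and suite through the `h` and `ds` accessors, which skip automatically on backends that lack the corresponding TPC support. A hypothetical test sketch using the new accessors (fixture wiring assumed):

def test_tpc_tables_nonempty(backend):
    # Equivalent of the old `backend.lineitem` property; calls pytest.skip
    # when the backend doesn't declare supports_tpch.
    lineitem = backend.h("lineitem")
    assert lineitem.count().execute() > 0

    # TPC-DS tables go through the same path with benchmark="ds".
    store_sales = backend.ds("store_sales")
    assert store_sales.count().execute() > 0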
File renamed without changes.
122 changes: 122 additions & 0 deletions ibis/backends/tests/tpc/conftest.py
@@ -0,0 +1,122 @@
from __future__ import annotations

import datetime
import functools
import re
from pathlib import Path
from typing import TYPE_CHECKING

import pytest
import sqlglot as sg
from dateutil.relativedelta import relativedelta

import ibis
from ibis.formats.pandas import PandasData

if TYPE_CHECKING:
    from collections.abc import Callable

    import ibis.expr.types as ir


def pytest_pyfunc_call(pyfuncitem):
    """Inject `backend` and `snapshot` fixtures to all TPC-DS test functions.

    Defining this hook here limits its scope to the TPC-DS tests.
    """
    testfunction = pyfuncitem.obj
    funcargs = pyfuncitem.funcargs
    testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
    result = testfunction(
        **testargs, backend=funcargs["backend"], snapshot=funcargs["snapshot"]
    )
    assert (
        result is None
    ), "test function should not return anything, did you mean to use assert?"
    return True


def tpc_test(suite_name):
    def inner(test: Callable[..., ir.Table]):
        """Decorator for TPC tests.

        Automates the process of loading the SQL query from the file system and
        asserting that the result of the ibis expression is equal to the expected
        result of executing the raw SQL.
        """
        name = f"tpc{suite_name}"

        @getattr(pytest.mark, name)
        @pytest.mark.usefixtures("backend", "snapshot")
        @pytest.mark.xdist_group(name)
        @functools.wraps(test)
        def wrapper(*args, backend, snapshot, **kwargs):
            backend_name = backend.name()
            if not getattr(backend, f"supports_{name}"):
                pytest.skip(
                    f"{backend_name} backend doesn't support testing {name} queries yet"
                )
            query_name_match = re.match(r"^test_(\d\d)$", test.__name__)
            assert query_name_match is not None

            query_number = query_name_match.group(1)
            sql_path_name = f"{query_number}.sql"

            path = Path(__file__).parent.joinpath(
                "queries", "duckdb", suite_name, sql_path_name
            )
            raw_sql = path.read_text()

            sql = sg.parse_one(raw_sql, read="duckdb")

            transform_method = getattr(
                backend, f"_transform_{name}_sql", lambda sql: sql
            )
            sql = transform_method(sql)

            raw_sql = sql.sql(dialect="duckdb", pretty=True)

            expected_expr = backend.connection.sql(raw_sql, dialect="duckdb")

            result_expr = test(*args, **kwargs)

            ibis_sql = ibis.to_sql(result_expr, dialect=backend_name)

            assert result_expr._find_backend(use_default=False) is backend.connection
            result = backend.connection.to_pandas(result_expr)
            assert not result.empty

            expected = expected_expr.to_pandas()
            assert list(map(str.lower, expected.columns)) == result.columns.tolist()
            expected.columns = result.columns

            expected = PandasData.convert_table(expected, result_expr.schema())
            assert not expected.empty

            assert len(expected) == len(result)
            assert result.columns.tolist() == expected.columns.tolist()
            for column in result.columns:
                left = result.loc[:, column]
                right = expected.loc[:, column]
                assert (
                    pytest.approx(
                        left.values.tolist(),
                        nan_ok=True,
                        abs=backend.tpc_absolute_tolerance,
                    )
                    == right.values.tolist()
                )

            # only write sql if the execution passes
            snapshot.assert_match(ibis_sql, sql_path_name)

        return wrapper

    return inner


def add_date(datestr: str, dy: int = 0, dm: int = 0, dd: int = 0) -> ir.DateScalar:
    dt = datetime.date.fromisoformat(datestr)
    dt += relativedelta(years=dy, months=dm, days=dd)
    return ibis.date(dt.isoformat())
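For reference, a hypothetical sketch of how a query test might use the decorator and `add_date`. The query body, the `lineitem` fixture, and the import path are illustrative; the test name must match `test_NN` so the decorator can locate the corresponding `queries/duckdb/h/NN.sql` file:

from ibis import _
from ibis.backends.tests.tpc.conftest import add_date, tpc_test

@tpc_test("h")
def test_01(lineitem):
    # TPC-H query 1-style aggregation; the decorator executes 01.sql via
    # DuckDB and compares its result against this ibis expression.
    return (
        lineitem.filter(_.l_shipdate <= add_date("1998-12-01", dd=-90))
        .group_by([_.l_returnflag, _.l_linestatus])
        .aggregate(sum_qty=_.l_quantity.sum())
        .order_by([_.l_returnflag, _.l_linestatus])
    )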
Empty file.