Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion benchmark_reporting_tools/post_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def _build_submission_payload(
is_official: bool,
asset_ids: list[int] | None = None,
concurrency_streams: int = 1,
validation_results: dict | None = None,
) -> dict:
"""Build a BenchmarkSubmission payload from parsed dataclasses.

Expand Down Expand Up @@ -396,10 +397,22 @@ def _query_sort_key(name: str):

query_names = sorted(raw_times.keys(), key=_query_sort_key)

per_query_validation = (validation_results or {}).get("queries", {})

def _get_validation_result(query_name):
# Look up validation result for this query (keys are lowercase e.g. "q1")
vkey = "q" + query_name.lstrip("Q").lower()
vdata = per_query_validation.get(vkey)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if per_query_validation is an empty dictionary (from line 400)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If per_query_validation is empty then per_query_validation.get(vkey) should return None for each query and the status will be returned as "not-validated".

if vdata:
return {"status": vdata["status"], "message": vdata.get("message")}
return {"status": "not-validated"}

for query_name in query_names:
times = raw_times[query_name]
is_failed = query_name in failed_queries

validation_result = _get_validation_result(query_name)

# Each execution becomes a separate query log entry
for exec_idx, runtime_ms in enumerate(times):
if is_failed:
Expand All @@ -416,11 +429,14 @@ def _query_sort_key(name: str):
"extra_info": {
"execution_number": exec_idx + 1,
},
"validation_result": validation_result,
}
)
execution_order += 1

# Handle failed queries that may not have times
# Handle failed queries that may not have times.
# Queries that failed before producing a result file are always "not-validated"
# since validate_results.py never ran for them (no parquet to compare against).
for query_name, error_info in failed_queries.items():
if query_name not in raw_times:
query_logs.append(
Expand All @@ -432,6 +448,7 @@ def _query_sort_key(name: str):
"extra_info": {
"error": str(error_info),
},
"validation_result": _get_validation_result(query_name),
}
)
execution_order += 1
Expand Down Expand Up @@ -469,6 +486,7 @@ def _query_sort_key(name: str):
"extra_info": extra_info,
"is_official": is_official,
"asset_ids": asset_ids,
"validation_status": (validation_results or {}).get("overall_status", "not-validated"),
}


Expand Down Expand Up @@ -616,6 +634,14 @@ async def _process_benchmark_dir(
print(" Warning: no config directory found. Use --config-dir to specify one.", file=sys.stderr)
engine_config = None

validation_results_path = benchmark_dir / "validation_results.json"
if validation_results_path.exists():
print(" Loading validation results...", file=sys.stderr)
validation_results = json.loads(validation_results_path.read_text())
else:
print(" No validation results found.", file=sys.stderr)
validation_results = None

# Resolve logs directory: explicit override → auto-detect from repo
effective_logs_dir = logs_dir
if effective_logs_dir is None:
Expand Down Expand Up @@ -669,6 +695,7 @@ async def _process_benchmark_dir(
is_official=is_official,
asset_ids=asset_ids,
concurrency_streams=concurrency_streams,
validation_results=validation_results,
)
except Exception as e:
print(f" Error building payload for '{bench_name}': {e}", file=sys.stderr)
Expand All @@ -680,6 +707,15 @@ async def _process_benchmark_dir(
print(f" Identifier hash: {payload['query_engine']['identifier_hash']}", file=sys.stderr)
print(f" Node count: {payload['node_count']}", file=sys.stderr)
print(f" Query logs: {len(payload['query_logs'])}", file=sys.stderr)
print(f" Validation status: {payload['validation_status']}", file=sys.stderr)
xfail_queries = [
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does the x prefix mean here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "x" prefix was short for "expected". As in this is an "expected failure". This convention was set in the Benchmarking DB where one of the possible validation states is "XFAIL" for this particular case.

ql["query_name"]
for ql in payload["query_logs"]
if ql.get("validation_result", {}).get("status") == "expected-failure"
]
if xfail_queries:
unique_xfail = sorted(set(xfail_queries), key=lambda x: int(x))
print(f" Expected-failure queries: {unique_xfail}", file=sys.stderr)

if dry_run:
print("\n [DRY RUN] Payload:", file=sys.stderr)
Expand Down
145 changes: 5 additions & 140 deletions common/testing/integration_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

import pandas as pd

from ..result_comparison import compare_result_frames
from ..test_utils import get_abs_file_path

sys.path.append(get_abs_file_path(__file__, "../../../benchmark_data_tools"))

import duckdb
import sqlglot
from duckdb_utils import create_table


Expand All @@ -39,12 +39,14 @@ def execute_query_and_compare_results(
if request_config.getoption("--store-reference-results"):
duckdb_relation.write_parquet(f"{output_dir}/reference_results/{result_file_name}")

duckdb_rows = duckdb_relation.fetchall()
if request_config.getoption("--show-reference-result-preview"):
duckdb_rows = duckdb_relation.fetchall()
show_result_preview(duckdb_relation.columns, duckdb_rows, preview_rows_count, "Reference", query_id)

if not request_config.getoption("--skip-reference-comparison"):
compare_results(query_engine_rows, duckdb_rows, duckdb_relation.types, query, duckdb_relation.columns)
engine_df = pd.DataFrame(query_engine_rows, columns=query_engine_columns)
duckdb_df = duckdb_relation.df()
compare_result_frames(engine_df, duckdb_df, query)


def show_result_preview(columns, rows, preview_rows_count, result_source, query_id):
Expand All @@ -62,147 +64,10 @@ def write_query_engine_rows(output_dir, result_file_name, rows, columns, query_e
df.to_parquet(f"{output_dir}/{query_engine}_results/{result_file_name}")


def get_is_sorted_query(query):
    """Return True when the parsed query carries a top-level ORDER BY clause."""
    parsed = sqlglot.parse_one(query)
    for child in parsed.iter_expressions():
        if isinstance(child, sqlglot.exp.Order):
            return True
    return False


def none_safe_sort_key(row):
    """Sort key that treats None as greater than any other value.

    Each cell maps to a (flag, value) pair — (0, value) for concrete values
    and (1, None) for None — so rows containing None sort after rows with
    values and the sort never raises TypeError comparing None to non-None.

    NOTE: the previous docstring claimed None sorted *less* than other
    values; the (1, None) flag actually sorts it last.
    """
    return tuple((0, x) if x is not None else (1, None) for x in row)


def compare_results(query_engine_rows, duckdb_rows, types, query, column_names):
    """Compare query-engine rows against DuckDB reference rows.

    Both row sets are normalized first (see normalize_rows: dates -> str,
    double/float/decimal -> float), compared order-insensitively via a full
    sort, and — when the query's ORDER BY columns can be resolved — the
    ordered columns are additionally compared in their original row order.

    Raises AssertionError (from assert_rows_equal or the count check below)
    on any mismatch.
    """
    # Cheap fail-fast before normalizing/sorting the full result sets.
    row_count = len(query_engine_rows)
    assert row_count == len(duckdb_rows)

    duckdb_rows = normalize_rows(duckdb_rows, types)
    query_engine_rows = normalize_rows(query_engine_rows, types)

    # We need a full sort for all non-ORDER BY columns because some ORDER BY comparison
    # will be equal and the resulting order of non-ORDER BY columns will be ambiguous.
    sorted_duckdb_rows = sorted(duckdb_rows, key=none_safe_sort_key)
    sorted_query_engine_rows = sorted(query_engine_rows, key=none_safe_sort_key)
    assert_rows_equal(sorted_query_engine_rows, sorted_duckdb_rows, types)

    # If we have an ORDER BY clause we want to test that the resulting order of those
    # columns is correct, in addition to overall values being correct.
    # However, we can only validate the ORDER BY if we can extract column indices.
    # For complex ORDER BY expressions (aggregates, CASE statements, etc.), we skip
    # the ORDER BY validation but still validate overall result correctness.
    order_indices = get_orderby_indices(query, column_names)
    if order_indices:
        # Project both results to ORDER BY columns and compare in original order
        duckdb_proj = [[row[i] for i in order_indices] for row in duckdb_rows]
        query_engine_proj = [[row[i] for i in order_indices] for row in query_engine_rows]
        projected_types = [types[i] for i in order_indices]
        assert_rows_equal(query_engine_proj, duckdb_proj, projected_types)


def get_orderby_indices(query, column_names):
    """Return 0-based output-column indices referenced by the query's ORDER BY.

    Supports positional ordinals (ORDER BY 1, 2) and simple column
    references by name.  Complex ORDER BY expressions (aggregates, CASE,
    arithmetic, ...) are skipped, so the result may be empty or cover only
    a subset of the ORDER BY keys.  An empty list signals "no validatable
    ORDER BY" to the caller.
    """
    expr = sqlglot.parse_one(query)
    # NOTE(review): find_all walks the whole tree, so this picks up the first
    # ORDER BY encountered — presumably the outermost one; confirm for
    # queries with ordered subqueries.
    order = next((e for e in expr.find_all(sqlglot.exp.Order)), None)
    if not order:
        return []

    indices = []
    for ordered in order.expressions:
        # `ordered` is an exp.Ordered wrapper; `.this` is the actual sort key.
        key = ordered.this

        # Handle numeric literals (e.g., ORDER BY 1, 2)
        if isinstance(key, sqlglot.exp.Literal):
            try:
                col_num = int(key.this)
                if 1 <= col_num <= len(column_names):
                    indices.append(col_num - 1)  # Convert to 0-based index
                    continue
            except (ValueError, TypeError):
                pass

        # Handle simple column references
        if isinstance(key, sqlglot.exp.Column):
            name = key.name
            if name in column_names:
                indices.append(column_names.index(name))
                continue

        # For complex expressions (CASE, SUM, etc.), skip ORDER BY validation
        # We still validate overall result correctness with full sorting
        # Just don't validate the specific ORDER BY column ordering
        pass

    return indices


def create_duckdb_table(table_name, data_path):
    # Load the data file (path resolved relative to this test module) into a
    # DuckDB table via the shared benchmark_data_tools helper.
    create_table(table_name, get_abs_file_path(__file__, data_path))


def normalize_rows(rows, types):
    """Normalize every row for comparison (see normalize_row for the rules)."""
    normalized = []
    for row in rows:
        normalized.append(normalize_row(row, types))
    return normalized


# DuckDB type ids whose values are coerced to float before comparison.
FLOATING_POINT_TYPES = ("double", "float", "decimal")


def normalize_row(row, types):
    """Return a copy of ``row`` with dates stringified and floating types coerced to float."""

    def _normalize_value(value, index):
        # None passes through untouched — the column type is never inspected
        # for it, so the types list is only consulted for concrete values.
        if value is None:
            return None
        type_id = types[index].id
        if type_id == "date":
            return str(value)
        if type_id in FLOATING_POINT_TYPES:
            return float(value)
        return value

    return [_normalize_value(value, index) for index, value in enumerate(row)]


def assert_rows_equal(rows_1, rows_2, types):
    """Assert two equally-sized row sets match cell-for-cell.

    Floating-point columns (per FLOATING_POINT_TYPES) compare with an
    absolute tolerance; everything else compares exactly.  Collection stops
    after ``limit`` differences and an AssertionError listing them is raised.
    """
    if len(rows_1) != len(rows_2):
        raise AssertionError(f"Row count mismatch: {len(rows_1)} vs {len(rows_2)}")

    float_columns = {idx for idx, col_type in enumerate(types) if col_type.id in FLOATING_POINT_TYPES}
    tolerance = 0.02
    limit = 5
    problems = []

    for row_idx, (left, right) in enumerate(zip(rows_1, rows_2)):
        # Ragged rows are reported once and skipped; no cell-level comparison.
        if len(left) != len(right):
            problems.append(f"Row: {row_idx} length mismatch: {len(left)} vs {len(right)}")
            if len(problems) >= limit:
                break
            continue

        for col_idx, (lhs, rhs) in enumerate(zip(left, right)):
            if lhs is None and rhs is None:
                continue
            if lhs is None or rhs is None:
                problems.append(f"Row: {row_idx}, Column: {col_idx}: {lhs} vs {rhs} (null mismatch)")
            elif col_idx in float_columns:
                delta = abs(lhs - rhs)
                if delta > tolerance:
                    problems.append(
                        f"Row: {row_idx}, Column: {col_idx}: {lhs} vs {rhs} "
                        f"(diff={delta:.6f}, tolerance={tolerance})"
                    )
            elif lhs != rhs:
                problems.append(f"Row: {row_idx}, Column: {col_idx}: {lhs} vs {rhs}")

            if len(problems) >= limit:
                break

        if len(problems) >= limit:
            break

    if problems:
        truncated_msg = f" (showing first {limit})" if len(problems) >= limit else ""
        mismatch_details = "\n ".join(problems)
        raise AssertionError(f"Found {len(problems)} mismatches{truncated_msg}:\n {mismatch_details}")


def initialize_output_dir(config, query_engine):
output_dir = Path(config.getoption("--output-dir"))
user_reference_results_dir = config.getoption("--reference-results-dir")
Expand Down
3 changes: 3 additions & 0 deletions common/testing/requirements.txt
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

common/testing is not a Python project, so I don't think requirements.txt should be here. The dependencies should probably be managed by the project that uses the shared modules/files.

Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pandas>=2.0
pyarrow>=10.0
sqlglot
Loading
Loading