retriever/src/retriever/examples/batch_pipeline.py (41 changes: 26 additions & 15 deletions)
@@ -18,6 +18,7 @@
 from pathlib import Path
 from typing import Optional, TextIO

+import pandas as pd
 import ray
 import typer
 from retriever import create_ingestor
@@ -102,11 +103,11 @@ def _configure_logging(log_file: Optional[Path]) -> tuple[Optional[TextIO], Text
     return fh, original_stdout, original_stderr


-def _estimate_processed_pages(uri: str, table_name: str) -> Optional[int]:
+def _load_metadata_columns(uri: str, table_name: str) -> Optional[pd.DataFrame]:
     """
-    Estimate pages processed by counting unique (source_id, page_number) pairs.
+    Load only the metadata columns from LanceDB, skipping the large vector column.

-    Falls back to table row count if page-level fields are unavailable.
+    Returns a DataFrame with [source_id, page_number, metadata] or None on failure.
     """
     try:
         db = _lancedb().connect(uri)
@@ -115,15 +116,28 @@ def _estimate_processed_pages(uri: str, table_name: str) -> Optional[int]:
         return None

     try:
-        df = table.to_pandas()[["source_id", "page_number"]]
-        return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0])
+        return table.to_lance().to_table(columns=["source_id", "page_number", "metadata"]).to_pandas()
     except Exception:
         try:
-            return int(table.count_rows())
+            return table.to_pandas()[["source_id", "page_number", "metadata"]]
         except Exception:
             return None


+def _estimate_processed_pages(df: Optional[pd.DataFrame]) -> Optional[int]:
+    """
+    Estimate pages processed by counting unique (source_id, page_number) pairs.
+    """
+    if df is None:
+        return None
+    try:
+        return int(
+            df[["source_id", "page_number"]].dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0]
+        )
+    except Exception:
+        return None
+
+
 def _to_int(value: object, default: int = 0) -> int:
     try:
         if value is None:
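The split above is the substance of this hunk: `_load_metadata_columns` now does the single LanceDB read with a column projection (so the wide embedding vector column is never materialized into pandas), and `_estimate_processed_pages` becomes a pure function over the resulting frame. A minimal usage sketch; the URI and table name are illustrative stand-ins, not values from this file:

```python
import pandas as pd

# Illustrative stand-ins; the real pipeline passes lancedb_uri / LANCEDB_TABLE.
metadata_df = _load_metadata_columns("./lancedb", "chunks")

# The page estimate is now testable without a database: duplicate chunks from
# the same (source_id, page_number) pair collapse to a single page.
toy = pd.DataFrame(
    {
        "source_id": ["a.pdf", "a.pdf", "b.pdf"],
        "page_number": [1, 1, 2],
        "metadata": [None, None, None],
    }
)
assert _estimate_processed_pages(toy) == 2
assert _estimate_processed_pages(None) is None  # failed load degrades to None
```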
@@ -133,18 +147,14 @@ def _to_int(value: object, default: int = 0) -> int:
     return default


-def _collect_detection_summary(uri: str, table_name: str) -> Optional[dict]:
+def _collect_detection_summary(df: Optional[pd.DataFrame]) -> Optional[dict]:
     """
     Collect per-model detection totals deduped by (source_id, page_number).

     Counts are read from LanceDB row `metadata`, which is populated during batch
     ingestion by the Ray write stage.
     """
-    try:
-        db = _lancedb().connect(uri)
-        table = db.open_table(table_name)
-        df = table.to_pandas()[["source_id", "page_number", "metadata"]]
-    except Exception:
+    if df is None:
         return None

     # Deduplicate exploded rows by page key; keep max per-page counts.
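The trailing comment marks the part of the function this hunk does not show. Grounded only in that comment, here is a hypothetical sketch of the dedupe-then-sum shape it describes; the column name `table_count` and the exact aggregation are assumptions, not the file's actual code:

```python
import pandas as pd

# Hypothetical exploded rows: several chunks from the same page repeat that
# page's detection count in their metadata.
rows = pd.DataFrame(
    {
        "source_id": ["a.pdf", "a.pdf", "b.pdf"],
        "page_number": [1, 1, 2],
        "table_count": [3, 3, 1],  # assumed per-model count read from metadata
    }
)

# Collapse by page key, keeping the max count seen per page, then sum pages.
per_page = rows.groupby(["source_id", "page_number"])["table_count"].max()
total = int(per_page.sum())  # 3 + 1 = 4, not 7
```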
@@ -402,7 +412,7 @@ def main(
         help="Batch size for PDF extraction stage.",
     ),
     pdf_split_batch_size: int = typer.Option(
-        1,
+        4,
         "--pdf-split-batch-size",
         min=1,
         help="Batch size for PDF split stage.",
@@ -778,8 +788,9 @@ def main(
         )
     )
     ingest_elapsed_s = time.perf_counter() - ingest_start
-    processed_pages = _estimate_processed_pages(lancedb_uri, LANCEDB_TABLE)
-    detection_summary = _collect_detection_summary(lancedb_uri, LANCEDB_TABLE)
+    metadata_df = _load_metadata_columns(lancedb_uri, LANCEDB_TABLE)
+    processed_pages = _estimate_processed_pages(metadata_df)
+    detection_summary = _collect_detection_summary(metadata_df)
     print("Extraction complete.")
     _print_detection_summary(detection_summary)
     if detection_summary_file is not None:
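Net effect of the wiring change: one projected table scan feeds both metrics, where the old code opened and fully read the table once per helper. A sketch of the resulting call pattern (URI and table name illustrative), noting that every helper degrades to `None`, so a reporting failure never aborts the run:

```python
metadata_df = _load_metadata_columns("./lancedb", "chunks")  # one scan, no vectors
pages = _estimate_processed_pages(metadata_df)       # None if the load failed
summary = _collect_detection_summary(metadata_df)    # likewise None-tolerant
print(f"pages processed: {pages if pages is not None else 'unknown'}")
```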