Skip to content
104 changes: 6 additions & 98 deletions nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import os
import sys
import time
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Optional, TextIO
Expand Down Expand Up @@ -116,107 +115,16 @@ def _to_int(value: object, default: int = 0) -> int:


def _collect_detection_summary(uri: str, table_name: str) -> Optional[dict]:
"""
Collect per-model detection totals deduped by (source_id, page_number).

Counts are read from LanceDB row `metadata`, which is populated during batch
ingestion by the Ray write stage.
"""
try:
db = _lancedb().connect(uri)
table = db.open_table(table_name)
df = table.to_pandas()[["source_id", "page_number", "metadata"]]
except Exception:
return None

# Deduplicate exploded rows by page key; keep max per-page counts.
per_page: dict[tuple[str, int], dict] = {}
for row in df.itertuples(index=False):
source_id = str(getattr(row, "source_id", "") or "")
page_number = _to_int(getattr(row, "page_number", -1), default=-1)
key = (source_id, page_number)

raw_metadata = getattr(row, "metadata", None)
meta: dict = {}
if isinstance(raw_metadata, str) and raw_metadata.strip():
try:
parsed = json.loads(raw_metadata)
if isinstance(parsed, dict):
meta = parsed
except Exception:
meta = {}

entry = per_page.setdefault(
key,
{
"page_elements_total": 0,
"ocr_table_total": 0,
"ocr_chart_total": 0,
"ocr_infographic_total": 0,
"page_elements_by_label": defaultdict(int),
},
)

pe_total = _to_int(meta.get("page_elements_v3_num_detections"), default=0)
entry["page_elements_total"] = max(entry["page_elements_total"], pe_total)

ocr_table = _to_int(meta.get("ocr_table_detections"), default=0)
ocr_chart = _to_int(meta.get("ocr_chart_detections"), default=0)
ocr_infographic = _to_int(meta.get("ocr_infographic_detections"), default=0)
entry["ocr_table_total"] = max(entry["ocr_table_total"], ocr_table)
entry["ocr_chart_total"] = max(entry["ocr_chart_total"], ocr_chart)
entry["ocr_infographic_total"] = max(entry["ocr_infographic_total"], ocr_infographic)

label_counts = meta.get("page_elements_v3_counts_by_label")
if isinstance(label_counts, dict):
for label, count in label_counts.items():
if not isinstance(label, str):
continue
entry["page_elements_by_label"][label] = max(
entry["page_elements_by_label"][label],
_to_int(count, default=0),
)
"""Collect per-model detection totals deduped by (source_id, page_number)."""
from nemo_retriever.utils.detection_summary import collect_detection_summary_from_lancedb

pe_by_label_totals: dict[str, int] = defaultdict(int)
page_elements_total = 0
ocr_table_total = 0
ocr_chart_total = 0
ocr_infographic_total = 0
for page_entry in per_page.values():
page_elements_total += int(page_entry["page_elements_total"])
ocr_table_total += int(page_entry["ocr_table_total"])
ocr_chart_total += int(page_entry["ocr_chart_total"])
ocr_infographic_total += int(page_entry["ocr_infographic_total"])
for label, count in page_entry["page_elements_by_label"].items():
pe_by_label_totals[label] += int(count)

return {
"pages_seen": int(len(per_page)),
"page_elements_v3_total_detections": int(page_elements_total),
"page_elements_v3_counts_by_label": dict(sorted(pe_by_label_totals.items())),
"ocr_table_total_detections": int(ocr_table_total),
"ocr_chart_total_detections": int(ocr_chart_total),
"ocr_infographic_total_detections": int(ocr_infographic_total),
}
return collect_detection_summary_from_lancedb(uri, table_name)


def _print_detection_summary(summary: Optional[dict]) -> None:
if summary is None:
print("Detection summary: unavailable (could not read LanceDB metadata).")
return
print("\nDetection summary (deduped by source_id/page_number):")
print(f" Pages seen: {summary['pages_seen']}")
print(f" PageElements v3 total detections: {summary['page_elements_v3_total_detections']}")
print(f" OCR table detections: {summary['ocr_table_total_detections']}")
print(f" OCR chart detections: {summary['ocr_chart_total_detections']}")
print(f" OCR infographic detections: {summary['ocr_infographic_total_detections']}")
print(" PageElements v3 counts by label:")
by_label = summary.get("page_elements_v3_counts_by_label") or {}
if not by_label:
print(" (none)")
else:
for label, count in by_label.items():
print(f" {label}: {count}")
from nemo_retriever.utils.detection_summary import print_detection_summary

print_detection_summary(summary)


def _extract_error_payloads(v: object) -> list[object]:
Expand Down
128 changes: 6 additions & 122 deletions nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,140 +911,24 @@ def _process_chunk_cpu(chunk_df: pd.DataFrame, cpu_tasks: list) -> pd.DataFrame:


def _collect_summary_from_df(df: pd.DataFrame) -> dict:
"""Compute detection summary from a result DataFrame.
"""Compute detection summary from a result DataFrame."""
from nemo_retriever.utils.detection_summary import collect_detection_summary_from_df

Mirrors the batch pipeline's ``_collect_detection_summary`` but reads
directly from the in-memory DataFrame instead of LanceDB. Rows are
deduplicated by ``(path, page_number)`` so exploded content rows don't
inflate counts.
"""
per_page: dict[tuple, dict] = {}

for _, row in df.iterrows():
row_dict = row.to_dict()

path = str(row_dict.get("path") or row_dict.get("source_id") or "")
page_number = -1
try:
page_number = int(row_dict.get("page_number", -1))
except (TypeError, ValueError):
pass

key = (path, page_number)

meta = row_dict.get("metadata")
if isinstance(meta, str):
try:
meta = json.loads(meta)
except Exception:
meta = {}
if not isinstance(meta, dict):
meta = {}

entry = per_page.setdefault(
key,
{
"pe": 0,
"ge": 0,
"ocr_table": 0,
"ocr_chart": 0,
"ocr_infographic": 0,
"pe_by_label": defaultdict(int),
},
)

# Check metadata first, then fall back to direct DataFrame columns.
# The batch pipeline stores these inside the metadata JSON, but the
# inprocess pipeline keeps them as top-level DataFrame columns.
try:
pe = int(
meta.get("page_elements_v3_num_detections") or row_dict.get("page_elements_v3_num_detections") or 0
)
except (TypeError, ValueError):
pe = 0
entry["pe"] = max(entry["pe"], pe)

try:
ge = int(
meta.get("graphic_elements_v1_num_detections")
or row_dict.get("graphic_elements_v1_num_detections")
or 0
)
except (TypeError, ValueError):
ge = 0
entry["ge"] = max(entry["ge"], ge)

for field, meta_key, col_key in [
("ocr_table", "ocr_table_detections", "table"),
("ocr_chart", "ocr_chart_detections", "chart"),
("ocr_infographic", "ocr_infographic_detections", "infographic"),
]:
try:
val = int(meta.get(meta_key, 0) or 0)
except (TypeError, ValueError):
val = 0
# Fall back to counting direct list columns (e.g. row["table"]).
if val == 0:
col_val = row_dict.get(col_key)
if isinstance(col_val, list):
val = len(col_val)
entry[field] = max(entry[field], val)

label_counts = meta.get("page_elements_v3_counts_by_label") or row_dict.get("page_elements_v3_counts_by_label")
if isinstance(label_counts, dict):
for label, count in label_counts.items():
try:
c = int(count or 0)
except (TypeError, ValueError):
c = 0
entry["pe_by_label"][str(label)] = max(entry["pe_by_label"][str(label)], c)

pe_by_label_totals: dict[str, int] = defaultdict(int)
pe_total = ge_total = ocr_table_total = ocr_chart_total = ocr_infographic_total = 0
for e in per_page.values():
pe_total += e["pe"]
ge_total += e["ge"]
ocr_table_total += e["ocr_table"]
ocr_chart_total += e["ocr_chart"]
ocr_infographic_total += e["ocr_infographic"]
for label, count in e["pe_by_label"].items():
pe_by_label_totals[label] += count

return {
"pages_seen": len(per_page),
"page_elements_v3_total_detections": pe_total,
"graphic_elements_v1_total_detections": ge_total,
"page_elements_v3_counts_by_label": dict(sorted(pe_by_label_totals.items())),
"ocr_table_total_detections": ocr_table_total,
"ocr_chart_total_detections": ocr_chart_total,
"ocr_infographic_total_detections": ocr_infographic_total,
}
return collect_detection_summary_from_df(df)


def _print_ingest_summary(results: list, elapsed_s: float) -> None:
"""Print end-of-ingest summary matching batch pipeline output format."""
from nemo_retriever.utils.detection_summary import print_detection_summary

dfs = [r for r in results if isinstance(r, pd.DataFrame) and not r.empty]
if not dfs:
print(f"\nIngest time: {elapsed_s:.2f}s (no documents processed)")
return

combined = pd.concat(dfs, ignore_index=True) if len(dfs) > 1 else dfs[0]
summary = _collect_summary_from_df(combined)

print("\nDetection summary (deduped by source/page_number):")
print(f" Pages seen: {summary['pages_seen']}")
print(f" PageElements v3 total detections: {summary['page_elements_v3_total_detections']}")
print(f" Graphic elements v1 total detections: {summary['graphic_elements_v1_total_detections']}")
print(f" OCR table detections: {summary['ocr_table_total_detections']}")
print(f" OCR chart detections: {summary['ocr_chart_total_detections']}")
print(f" OCR infographic detections: {summary['ocr_infographic_total_detections']}")
print(" PageElements v3 counts by label:")
by_label = summary.get("page_elements_v3_counts_by_label", {})
if not by_label:
print(" (none)")
else:
for label, count in by_label.items():
print(f" {label}: {count}")
print_detection_summary(summary)

pages = summary["pages_seen"]
if elapsed_s > 0 and pages > 0:
Expand Down
11 changes: 0 additions & 11 deletions nemo_retriever/src/nemo_retriever/ingest_modes/lancedb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,11 @@ def _build_detection_metadata(row: Any) -> Dict[str, Any]:
str(k): int(v) for k, v in pe_counts.items() if isinstance(k, str) and v is not None
}

ge_num = getattr(row, "graphic_elements_v1_num_detections", None)
if ge_num is not None:
try:
out["graphic_elements_v1_num_detections"] = int(ge_num)
except Exception:
pass

for ocr_col in ("table", "chart", "infographic"):
entries = getattr(row, ocr_col, None)
if isinstance(entries, list):
out[f"ocr_{ocr_col}_detections"] = int(len(entries))

ct = getattr(row, "_content_type", None)
if isinstance(ct, str) and ct:
out["content_type"] = ct

return out


Expand Down
Loading
Loading