Skip to content
75 changes: 5 additions & 70 deletions nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,26 +257,13 @@ def _write_detection_summary(path: Path, summary: Optional[dict]) -> None:
target.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")


def _print_pages_per_second(processed_pages: Optional[int], ingest_elapsed_s: float) -> None:
if ingest_elapsed_s <= 0:
print("Pages/sec: unavailable (ingest elapsed time was non-positive).")
return
if processed_pages is None:
print("Pages/sec: unavailable (could not estimate processed pages). " f"Ingest time: {ingest_elapsed_s:.2f}s")
return

pps = processed_pages / ingest_elapsed_s
print(f"Pages processed: {processed_pages}")
print(f"Pages/sec (ingest only; excludes Ray startup and recall): {pps:.2f}")


def _ensure_lancedb_table(uri: str, table_name: str) -> None:
"""
Ensure the local LanceDB URI exists and table can be opened.
"""Ensure the local LanceDB URI exists and table can be opened.

Creates an empty table with the expected schema if it does not exist yet.
"""
# Local path URI in this pipeline.
from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema

Path(uri).mkdir(parents=True, exist_ok=True)

db = _lancedb().connect(uri)
Expand All @@ -288,63 +275,11 @@ def _ensure_lancedb_table(uri: str, table_name: str) -> None:

import pyarrow as pa # type: ignore

schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2048)),
pa.field("pdf_page", pa.string()),
pa.field("filename", pa.string()),
pa.field("pdf_basename", pa.string()),
pa.field("page_number", pa.int32()),
pa.field("source_id", pa.string()),
pa.field("path", pa.string()),
pa.field("text", pa.string()),
pa.field("metadata", pa.string()),
pa.field("source", pa.string()),
]
)
empty = pa.table(
{
"vector": [],
"pdf_page": [],
"filename": [],
"pdf_basename": [],
"page_number": [],
"source_id": [],
"path": [],
"text": [],
"metadata": [],
"source": [],
},
schema=schema,
)
schema = lancedb_schema(2048)
empty = pa.table({f.name: [] for f in schema}, schema=schema)
db.create_table(table_name, data=empty, schema=schema, mode="create")


def _gold_to_doc_page(golden_key: str) -> tuple[str, str]:
s = str(golden_key)
if "_" not in s:
return s, ""
doc, page = s.rsplit("_", 1)
return doc, page


def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]:
try:
res = json.loads(hit.get("metadata", "{}"))
source = json.loads(hit.get("source", "{}"))
except Exception:
return None, None

source_id = source.get("source_id")
page_number = res.get("page_number")
if not source_id or page_number is None:
return None, float(hit.get("_distance")) if "_distance" in hit else None

key = f"{Path(str(source_id)).stem}_{page_number}"
dist = float(hit["_distance"]) if "_distance" in hit else float(hit["_score"]) if "_score" in hit else None
return key, dist


@app.command()
def main(
ctx: typer.Context,
Expand Down
51 changes: 51 additions & 0 deletions nemo_retriever/src/nemo_retriever/examples/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Shared helpers used by multiple example pipeline scripts."""

from __future__ import annotations

from typing import Optional


def estimate_processed_pages(uri: str, table_name: str) -> Optional[int]:
    """Estimate pages processed by counting unique (source_id, page_number) pairs.

    Falls back to table row count if page-level fields are unavailable, and
    returns None when the table (or lancedb itself) cannot be opened.
    """
    try:
        import lancedb  # type: ignore

        table = lancedb.connect(uri).open_table(table_name)
    except Exception:
        # Missing dependency, bad URI, or absent table: nothing to count.
        return None

    try:
        pairs = table.to_pandas()[["source_id", "page_number"]]
        unique_pages = pairs.dropna(subset=["source_id", "page_number"]).drop_duplicates()
        return int(len(unique_pages))
    except Exception:
        try:
            return int(table.count_rows())
        except Exception:
            return None


def print_pages_per_second(
    processed_pages: Optional[int],
    ingest_elapsed_s: float,
    *,
    label: str = "ingest only",
) -> None:
    """Print a throughput summary line.

    Args:
        processed_pages: Estimated page count, or None when it could not be derived.
        ingest_elapsed_s: Wall-clock seconds spent ingesting.
        label: Qualifier shown in the "Pages/sec (...)" line.
    """
    if ingest_elapsed_s <= 0:
        # Guard against division by zero or a meaningless negative rate.
        print("Pages/sec: unavailable (ingest elapsed time was non-positive).")
    elif processed_pages is None:
        print(f"Pages/sec: unavailable (could not estimate processed pages). Ingest time: {ingest_elapsed_s:.2f}s")
    else:
        print(f"Pages processed: {processed_pages}")
        print(f"Pages/sec ({label}): {processed_pages / ingest_elapsed_s:.2f}")
30 changes: 16 additions & 14 deletions nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,20 @@
from nemo_retriever.examples.batch_pipeline import (
LANCEDB_TABLE,
LANCEDB_URI,
_collect_detection_summary,
_configure_logging,
_ensure_lancedb_table,
_estimate_processed_pages,
_gold_to_doc_page,
_hit_key_and_distance,
_is_hit_at_k,
_print_detection_summary,
_print_pages_per_second,
_write_detection_summary,
_collect_detection_summary,
)
from nemo_retriever.recall.core import RecallConfig, retrieve_and_score
from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second
from nemo_retriever.recall.core import (
RecallConfig,
gold_to_doc_page,
hit_key_and_distance,
is_hit_at_k,
retrieve_and_score,
)

app = typer.Typer()

Expand Down Expand Up @@ -242,7 +244,7 @@ def main(
)
)
ingest_elapsed_s = time.perf_counter() - ingest_start
processed_pages = _estimate_processed_pages(lancedb_uri, LANCEDB_TABLE)
processed_pages = estimate_processed_pages(lancedb_uri, LANCEDB_TABLE)
detection_summary = _collect_detection_summary(lancedb_uri, LANCEDB_TABLE)
print("Extraction complete.")
_print_detection_summary(detection_summary)
Expand All @@ -255,7 +257,7 @@ def main(
query_csv = Path(query_csv)
if not query_csv.exists():
print(f"Query CSV not found at {query_csv}; skipping recall evaluation.")
_print_pages_per_second(processed_pages, ingest_elapsed_s)
print_pages_per_second(processed_pages, ingest_elapsed_s)
return

db = lancedb.connect(lancedb_uri)
Expand All @@ -277,7 +279,7 @@ def main(
try:
if int(table.count_rows()) == 0:
print(f"LanceDB table {LANCEDB_TABLE!r} exists but is empty; skipping recall evaluation.")
_print_pages_per_second(processed_pages, ingest_elapsed_s)
print_pages_per_second(processed_pages, ingest_elapsed_s)
return
except Exception:
pass
Expand Down Expand Up @@ -305,16 +307,16 @@ def main(
_raw_hits,
)
):
doc, page = _gold_to_doc_page(g)
doc, page = gold_to_doc_page(g)

scored_hits: list[tuple[str, float | None]] = []
for h in hits:
key, dist = _hit_key_and_distance(h)
key, dist = hit_key_and_distance(h)
if key:
scored_hits.append((key, dist))

top_keys = [k for (k, _d) in scored_hits]
hit = _is_hit_at_k(g, top_keys, cfg.top_k)
hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page")

if not no_recall_details:
print(f"\nQuery {i}: {q}")
Expand Down Expand Up @@ -345,7 +347,7 @@ def main(
print("\nRecall metrics (matching nemo_retriever.recall.core):")
for k, v in metrics.items():
print(f" {k}: {v:.4f}")
_print_pages_per_second(processed_pages, ingest_elapsed_s)
print_pages_per_second(processed_pages, ingest_elapsed_s)
finally:
# Restore real stdio before closing the mirror file so exception hooks
# and late flushes never write to a closed stream wrapper.
Expand Down
90 changes: 14 additions & 76 deletions nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,95 +7,33 @@
Run with: uv run python -m nemo_retriever.examples.inprocess_pipeline <input-dir>
"""

import json
import time
from pathlib import Path
from typing import Optional

import lancedb
import typer
from nemo_retriever import create_ingestor
from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second
from nemo_retriever.params import EmbedParams
from nemo_retriever.params import ExtractParams
from nemo_retriever.params import IngestExecuteParams
from nemo_retriever.params import TextChunkParams
from nemo_retriever.params import VdbUploadParams
from nemo_retriever.recall.core import RecallConfig, retrieve_and_score
from nemo_retriever.recall.core import (
RecallConfig,
gold_to_doc_page,
hit_key_and_distance,
is_hit_at_k,
retrieve_and_score,
)

app = typer.Typer()

LANCEDB_URI = "lancedb"
LANCEDB_TABLE = "nv-ingest"


def _estimate_processed_pages(uri: str, table_name: str) -> Optional[int]:
"""
Estimate pages processed by counting unique (source_id, page_number) pairs.

Falls back to table row count if page-level fields are unavailable.
"""
try:
db = lancedb.connect(uri)
table = db.open_table(table_name)
except Exception:
return None

try:
df = table.to_pandas()[["source_id", "page_number"]]
return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0])
except Exception:
try:
return int(table.count_rows())
except Exception:
return None


def _print_pages_per_second(processed_pages: Optional[int], ingest_elapsed_s: float) -> None:
if ingest_elapsed_s <= 0:
print("Pages/sec: unavailable (ingest elapsed time was non-positive).")
return
if processed_pages is None:
print("Pages/sec: unavailable (could not estimate processed pages). " f"Ingest time: {ingest_elapsed_s:.2f}s")
return

pps = processed_pages / ingest_elapsed_s
print(f"Pages processed: {processed_pages}")
print(f"Pages/sec (ingest only): {pps:.2f}")


def _gold_to_doc_page(golden_key: str) -> tuple[str, str]:
s = str(golden_key)
if "_" not in s:
return s, ""
doc, page = s.rsplit("_", 1)
return doc, page


def _is_hit_at_k(golden_key: str, retrieved_keys: list[str], k: int) -> bool:
doc, page = _gold_to_doc_page(golden_key)
specific_page = f"{doc}_{page}"
entire_document = f"{doc}_-1"
top = (retrieved_keys or [])[: int(k)]
return (specific_page in top) or (entire_document in top)


def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]:
try:
res = json.loads(hit.get("metadata", "{}"))
source = json.loads(hit.get("source", "{}"))
except Exception:
return None, None

source_id = source.get("source_id")
page_number = res.get("page_number")
if not source_id or page_number is None:
return None, float(hit.get("_distance")) if "_distance" in hit else None

key = f"{Path(str(source_id)).stem}_{page_number}"
dist = float(hit.get("_distance")) if "_distance" in hit else None
return key, dist


@app.command()
def main(
input_path: Path = typer.Argument(
Expand Down Expand Up @@ -402,7 +340,7 @@ def main(
)
)
ingest_elapsed_s = time.perf_counter() - ingest_start
processed_pages = _estimate_processed_pages(LANCEDB_URI, LANCEDB_TABLE)
processed_pages = estimate_processed_pages(LANCEDB_URI, LANCEDB_TABLE)
print("Extraction complete.")

# ---------------------------------------------------------------------------
Expand All @@ -411,7 +349,7 @@ def main(
query_csv = Path(query_csv)
if not query_csv.exists():
print(f"Query CSV not found at {query_csv}; skipping recall evaluation.")
_print_pages_per_second(processed_pages, ingest_elapsed_s)
print_pages_per_second(processed_pages, ingest_elapsed_s)
return

db = lancedb.connect(f"./{LANCEDB_URI}")
Expand Down Expand Up @@ -446,16 +384,16 @@ def main(
_raw_hits,
)
):
doc, page = _gold_to_doc_page(g)
doc, page = gold_to_doc_page(g)

scored_hits: list[tuple[str, float | None]] = []
for h in hits:
key, dist = _hit_key_and_distance(h)
key, dist = hit_key_and_distance(h)
if key:
scored_hits.append((key, dist))

top_keys = [k for (k, _d) in scored_hits]
hit = _is_hit_at_k(g, top_keys, cfg.top_k)
hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page")

if not no_recall_details:
ext = (
Expand Down Expand Up @@ -496,7 +434,7 @@ def main(
print("\nRecall metrics (matching nemo_retriever.recall.core):")
for k, v in metrics.items():
print(f" {k}: {v:.4f}")
_print_pages_per_second(processed_pages, ingest_elapsed_s)
print_pages_per_second(processed_pages, ingest_elapsed_s)


if __name__ == "__main__":
Expand Down
Loading
Loading