diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index a8696e24b..952f160dc 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -257,26 +257,13 @@ def _write_detection_summary(path: Path, summary: Optional[dict]) -> None: target.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") -def _print_pages_per_second(processed_pages: Optional[int], ingest_elapsed_s: float) -> None: - if ingest_elapsed_s <= 0: - print("Pages/sec: unavailable (ingest elapsed time was non-positive).") - return - if processed_pages is None: - print("Pages/sec: unavailable (could not estimate processed pages). " f"Ingest time: {ingest_elapsed_s:.2f}s") - return - - pps = processed_pages / ingest_elapsed_s - print(f"Pages processed: {processed_pages}") - print(f"Pages/sec (ingest only; excludes Ray startup and recall): {pps:.2f}") - - def _ensure_lancedb_table(uri: str, table_name: str) -> None: - """ - Ensure the local LanceDB URI exists and table can be opened. + """Ensure the local LanceDB URI exists and table can be opened. Creates an empty table with the expected schema if it does not exist yet. """ - # Local path URI in this pipeline. + from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema + Path(uri).mkdir(parents=True, exist_ok=True) db = _lancedb().connect(uri) @@ -288,63 +275,11 @@ def _ensure_lancedb_table(uri: str, table_name: str) -> None: import pyarrow as pa # type: ignore - schema = pa.schema( - [ - pa.field("vector", pa.list_(pa.float32(), 2048)), - pa.field("pdf_page", pa.string()), - pa.field("filename", pa.string()), - pa.field("pdf_basename", pa.string()), - pa.field("page_number", pa.int32()), - pa.field("source_id", pa.string()), - pa.field("path", pa.string()), - pa.field("text", pa.string()), - pa.field("metadata", pa.string()), - pa.field("source", pa.string()), - ] - ) - empty = pa.table( - { - "vector": [], - "pdf_page": [], - "filename": [], - "pdf_basename": [], - "page_number": [], - "source_id": [], - "path": [], - "text": [], - "metadata": [], - "source": [], - }, - schema=schema, - ) + schema = lancedb_schema(2048) + empty = pa.table({f.name: [] for f in schema}, schema=schema) db.create_table(table_name, data=empty, schema=schema, mode="create") -def _gold_to_doc_page(golden_key: str) -> tuple[str, str]: - s = str(golden_key) - if "_" not in s: - return s, "" - doc, page = s.rsplit("_", 1) - return doc, page - - -def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: - try: - res = json.loads(hit.get("metadata", "{}")) - source = json.loads(hit.get("source", "{}")) - except Exception: - return None, None - - source_id = source.get("source_id") - page_number = res.get("page_number") - if not source_id or page_number is None: - return None, float(hit.get("_distance")) if "_distance" in hit else None - - key = f"{Path(str(source_id)).stem}_{page_number}" - dist = float(hit["_distance"]) if "_distance" in hit else float(hit["_score"]) if "_score" in hit else None - return key, dist - - @app.command() def main( ctx: typer.Context, diff --git a/nemo_retriever/src/nemo_retriever/examples/common.py b/nemo_retriever/src/nemo_retriever/examples/common.py new file mode 100644 index 000000000..70a165bf6 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/examples/common.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared helpers used by multiple example pipeline scripts.""" + +from __future__ import annotations + +from typing import Optional + + +def estimate_processed_pages(uri: str, table_name: str) -> Optional[int]: + """Estimate pages processed by counting unique (source_id, page_number) pairs. + + Falls back to table row count if page-level fields are unavailable. + """ + try: + import lancedb # type: ignore + + db = lancedb.connect(uri) + table = db.open_table(table_name) + except Exception: + return None + + try: + df = table.to_pandas()[["source_id", "page_number"]] + return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0]) + except Exception: + try: + return int(table.count_rows()) + except Exception: + return None + + +def print_pages_per_second( + processed_pages: Optional[int], + ingest_elapsed_s: float, + *, + label: str = "ingest only", +) -> None: + """Print a throughput summary line.""" + if ingest_elapsed_s <= 0: + print("Pages/sec: unavailable (ingest elapsed time was non-positive).") + return + if processed_pages is None: + print(f"Pages/sec: unavailable (could not estimate processed pages). Ingest time: {ingest_elapsed_s:.2f}s") + return + + pps = processed_pages / ingest_elapsed_s + print(f"Pages processed: {processed_pages}") + print(f"Pages/sec ({label}): {pps:.2f}") diff --git a/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py index b378ba69f..da327ff1f 100644 --- a/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py @@ -27,18 +27,20 @@ from nemo_retriever.examples.batch_pipeline import ( LANCEDB_TABLE, LANCEDB_URI, + _collect_detection_summary, _configure_logging, _ensure_lancedb_table, - _estimate_processed_pages, - _gold_to_doc_page, - _hit_key_and_distance, - _is_hit_at_k, _print_detection_summary, - _print_pages_per_second, _write_detection_summary, - _collect_detection_summary, ) -from nemo_retriever.recall.core import RecallConfig, retrieve_and_score +from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second +from nemo_retriever.recall.core import ( + RecallConfig, + gold_to_doc_page, + hit_key_and_distance, + is_hit_at_k, + retrieve_and_score, +) app = typer.Typer() @@ -242,7 +244,7 @@ def main( ) ) ingest_elapsed_s = time.perf_counter() - ingest_start - processed_pages = _estimate_processed_pages(lancedb_uri, LANCEDB_TABLE) + processed_pages = estimate_processed_pages(lancedb_uri, LANCEDB_TABLE) detection_summary = _collect_detection_summary(lancedb_uri, LANCEDB_TABLE) print("Extraction complete.") _print_detection_summary(detection_summary) @@ -255,7 +257,7 @@ def main( query_csv = Path(query_csv) if not query_csv.exists(): print(f"Query CSV not found at {query_csv}; skipping recall evaluation.") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) return db = lancedb.connect(lancedb_uri) @@ -277,7 +279,7 @@ def main( try: if int(table.count_rows()) == 0: print(f"LanceDB table {LANCEDB_TABLE!r} exists but is empty; skipping recall evaluation.") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) return except Exception: pass @@ -305,16 +307,16 @@ def main( _raw_hits, ) ): - doc, page = _gold_to_doc_page(g) + doc, page = gold_to_doc_page(g) scored_hits: list[tuple[str, float | None]] = [] for h in hits: - key, dist = _hit_key_and_distance(h) + key, dist = hit_key_and_distance(h) if key: scored_hits.append((key, dist)) top_keys = [k for (k, _d) in scored_hits] - hit = _is_hit_at_k(g, top_keys, cfg.top_k) + hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") if not no_recall_details: print(f"\nQuery {i}: {q}") @@ -345,7 +347,7 @@ def main( print("\nRecall metrics (matching nemo_retriever.recall.core):") for k, v in metrics.items(): print(f" {k}: {v:.4f}") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) finally: # Restore real stdio before closing the mirror file so exception hooks # and late flushes never write to a closed stream wrapper. diff --git a/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py index d08a5c786..3b3768308 100644 --- a/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py @@ -7,7 +7,6 @@ Run with: uv run python -m nemo_retriever.examples.inprocess_pipeline """ -import json import time from pathlib import Path from typing import Optional @@ -15,12 +14,19 @@ import lancedb import typer from nemo_retriever import create_ingestor +from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second from nemo_retriever.params import EmbedParams from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import TextChunkParams from nemo_retriever.params import VdbUploadParams -from nemo_retriever.recall.core import RecallConfig, retrieve_and_score +from nemo_retriever.recall.core import ( + RecallConfig, + gold_to_doc_page, + hit_key_and_distance, + is_hit_at_k, + retrieve_and_score, +) app = typer.Typer() @@ -28,74 +34,6 @@ LANCEDB_TABLE = "nv-ingest" -def _estimate_processed_pages(uri: str, table_name: str) -> Optional[int]: - """ - Estimate pages processed by counting unique (source_id, page_number) pairs. - - Falls back to table row count if page-level fields are unavailable. - """ - try: - db = lancedb.connect(uri) - table = db.open_table(table_name) - except Exception: - return None - - try: - df = table.to_pandas()[["source_id", "page_number"]] - return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0]) - except Exception: - try: - return int(table.count_rows()) - except Exception: - return None - - -def _print_pages_per_second(processed_pages: Optional[int], ingest_elapsed_s: float) -> None: - if ingest_elapsed_s <= 0: - print("Pages/sec: unavailable (ingest elapsed time was non-positive).") - return - if processed_pages is None: - print("Pages/sec: unavailable (could not estimate processed pages). " f"Ingest time: {ingest_elapsed_s:.2f}s") - return - - pps = processed_pages / ingest_elapsed_s - print(f"Pages processed: {processed_pages}") - print(f"Pages/sec (ingest only): {pps:.2f}") - - -def _gold_to_doc_page(golden_key: str) -> tuple[str, str]: - s = str(golden_key) - if "_" not in s: - return s, "" - doc, page = s.rsplit("_", 1) - return doc, page - - -def _is_hit_at_k(golden_key: str, retrieved_keys: list[str], k: int) -> bool: - doc, page = _gold_to_doc_page(golden_key) - specific_page = f"{doc}_{page}" - entire_document = f"{doc}_-1" - top = (retrieved_keys or [])[: int(k)] - return (specific_page in top) or (entire_document in top) - - -def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: - try: - res = json.loads(hit.get("metadata", "{}")) - source = json.loads(hit.get("source", "{}")) - except Exception: - return None, None - - source_id = source.get("source_id") - page_number = res.get("page_number") - if not source_id or page_number is None: - return None, float(hit.get("_distance")) if "_distance" in hit else None - - key = f"{Path(str(source_id)).stem}_{page_number}" - dist = float(hit.get("_distance")) if "_distance" in hit else None - return key, dist - - @app.command() def main( input_path: Path = typer.Argument( @@ -402,7 +340,7 @@ def main( ) ) ingest_elapsed_s = time.perf_counter() - ingest_start - processed_pages = _estimate_processed_pages(LANCEDB_URI, LANCEDB_TABLE) + processed_pages = estimate_processed_pages(LANCEDB_URI, LANCEDB_TABLE) print("Extraction complete.") # --------------------------------------------------------------------------- @@ -411,7 +349,7 @@ def main( query_csv = Path(query_csv) if not query_csv.exists(): print(f"Query CSV not found at {query_csv}; skipping recall evaluation.") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) return db = lancedb.connect(f"./{LANCEDB_URI}") @@ -446,16 +384,16 @@ def main( _raw_hits, ) ): - doc, page = _gold_to_doc_page(g) + doc, page = gold_to_doc_page(g) scored_hits: list[tuple[str, float | None]] = [] for h in hits: - key, dist = _hit_key_and_distance(h) + key, dist = hit_key_and_distance(h) if key: scored_hits.append((key, dist)) top_keys = [k for (k, _d) in scored_hits] - hit = _is_hit_at_k(g, top_keys, cfg.top_k) + hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") if not no_recall_details: ext = ( @@ -496,7 +434,7 @@ def main( print("\nRecall metrics (matching nemo_retriever.recall.core):") for k, v in metrics.items(): print(f" {k}: {v:.4f}") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) if __name__ == "__main__": diff --git a/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py index d72fea7cf..f2da17b5d 100644 --- a/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py @@ -15,7 +15,6 @@ --run-mode online --base-url http://localhost:7670 """ -import json from pathlib import Path import lancedb @@ -27,7 +26,13 @@ from nemo_retriever.params import IngestorCreateParams from nemo_retriever.params import TextChunkParams from nemo_retriever.params import VdbUploadParams -from nemo_retriever.recall.core import RecallConfig, retrieve_and_score +from nemo_retriever.recall.core import ( + RecallConfig, + gold_to_doc_page, + hit_key_and_distance, + is_hit_at_k, + retrieve_and_score, +) app = typer.Typer() @@ -35,39 +40,6 @@ LANCEDB_TABLE = "nv-ingest" -def _gold_to_doc_page(golden_key: str) -> tuple[str, str]: - s = str(golden_key) - if "_" not in s: - return s, "" - doc, page = s.rsplit("_", 1) - return doc, page - - -def _is_hit_at_k(golden_key: str, retrieved_keys: list[str], k: int) -> bool: - doc, page = _gold_to_doc_page(golden_key) - specific_page = f"{doc}_{page}" - entire_document = f"{doc}_-1" - top = (retrieved_keys or [])[: int(k)] - return (specific_page in top) or (entire_document in top) - - -def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: - try: - res = json.loads(hit.get("metadata", "{}")) - source = json.loads(hit.get("source", "{}")) - except Exception: - return None, None - - source_id = source.get("source_id") - page_number = res.get("page_number") - if not source_id or page_number is None: - return None, float(hit.get("_distance")) if "_distance" in hit else None - - key = f"{Path(str(source_id)).stem}_{page_number}" - dist = float(hit.get("_distance")) if "_distance" in hit else None - return key, dist - - @app.command() def main( input_path: Path = typer.Argument( @@ -236,14 +208,14 @@ def main( _raw_hits, ) ): - doc, page = _gold_to_doc_page(g) + doc, page = gold_to_doc_page(g) scored_hits: list[tuple[str, float | None]] = [] for h in hits: - key, dist = _hit_key_and_distance(h) + key, dist = hit_key_and_distance(h) if key: scored_hits.append((key, dist)) top_keys = [k for (k, _d) in scored_hits] - hit = _is_hit_at_k(g, top_keys, cfg.top_k) + hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") if not no_recall_details: ext = ".txt" if input_type == "txt" else (".docx" if input_type == "doc" else ".pdf") typer.echo(f"\nQuery {i}: {q}") diff --git a/nemo_retriever/src/nemo_retriever/recall/core.py b/nemo_retriever/src/nemo_retriever/recall/core.py index b684a3f6d..95c540850 100644 --- a/nemo_retriever/src/nemo_retriever/recall/core.py +++ b/nemo_retriever/src/nemo_retriever/recall/core.py @@ -299,6 +299,37 @@ def is_hit_at_k(golden_key: str, retrieved: Sequence[str], k: int, *, match_mode return _is_hit(str(golden_key), list(retrieved), int(k), match_mode=str(match_mode)) +def gold_to_doc_page(golden_key: str) -> tuple[str, str]: + """Split a golden key like ``"docname_page"`` into ``(doc, page)``.""" + s = str(golden_key) + if "_" not in s: + return s, "" + doc, page = s.rsplit("_", 1) + return doc, page + + +def hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: + """Extract ``(pdf_page key, distance)`` from a single LanceDB hit dict. + + Supports both ``_distance`` and ``_score`` fields for compatibility across + LanceDB query types (vector vs hybrid). + """ + try: + res = json.loads(hit.get("metadata", "{}")) + source = json.loads(hit.get("source", "{}")) + except Exception: + return None, None + + source_id = source.get("source_id") + page_number = res.get("page_number") + if not source_id or page_number is None: + return None, float(hit.get("_distance")) if "_distance" in hit else None + + key = f"{Path(str(source_id)).stem}_{page_number}" + dist = float(hit["_distance"]) if "_distance" in hit else float(hit["_score"]) if "_score" in hit else None + return key, dist + + def _recall_at_k(gold: List[str], retrieved: List[List[str]], k: int, *, match_mode: str) -> float: hits = sum(is_hit_at_k(g, r, k, match_mode=match_mode) for g, r in zip(gold, retrieved)) return hits / max(1, len(gold))