From e69acd40abf1703a9dc17f653a7546854db6866f Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 17:22:50 -0500 Subject: [PATCH 1/6] Extract shared embedder factory into model/__init__.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Centralises the resolve → branch → construct pattern for local HF embedding models (VL and non-VL) that was duplicated across batch, inprocess, fused, gpu_pool, recall, retriever, and text_embed code paths into a single `create_local_embedder` factory function. Made-with: Cursor --- .../src/nemo_retriever/ingest_modes/batch.py | 41 ++----- .../src/nemo_retriever/ingest_modes/fused.py | 23 ++-- .../nemo_retriever/ingest_modes/gpu_pool.py | 21 +--- .../nemo_retriever/ingest_modes/inprocess.py | 33 ++---- .../src/nemo_retriever/model/__init__.py | 40 +++++++ .../src/nemo_retriever/recall/core.py | 15 +-- .../src/nemo_retriever/retriever.py | 21 +--- .../nemo_retriever/text_embed/processor.py | 9 +- .../nemo_retriever/text_embed/text_embed.py | 28 ++--- .../tests/test_create_local_embedder.py | 110 ++++++++++++++++++ 10 files changed, 195 insertions(+), 146 deletions(-) create mode 100644 nemo_retriever/tests/test_create_local_embedder.py diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index a265572ed..2b6f6d262 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -266,38 +266,15 @@ def __init__(self, params: EmbedParams) -> None: self._model = None return - device = self._kwargs.get("device") - hf_cache_dir = self._kwargs.get("hf_cache_dir") - normalize = bool(self._kwargs.get("normalize", True)) - max_length = int(self._kwargs.get("max_length", 8192)) - model_name_raw = self._kwargs.get("model_name") - - from nemo_retriever.model import is_vl_embed_model, resolve_embed_model - - model_id = resolve_embed_model(model_name_raw) - - if is_vl_embed_model(model_name_raw): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( - LlamaNemotronEmbedVL1BV2Embedder, - ) - - self._model = LlamaNemotronEmbedVL1BV2Embedder( - device=str(device) if device else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None, - model_id=model_id, - ) - else: - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import ( - LlamaNemotronEmbed1BV2Embedder, - ) - - self._model = LlamaNemotronEmbed1BV2Embedder( - device=str(device) if device else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None, - normalize=normalize, - max_length=max_length, - model_id=model_id, - ) + from nemo_retriever.model import create_local_embedder + + self._model = create_local_embedder( + self._kwargs.get("model_name"), + device=str(self._kwargs["device"]) if self._kwargs.get("device") else None, + hf_cache_dir=str(self._kwargs["hf_cache_dir"]) if self._kwargs.get("hf_cache_dir") else None, + normalize=bool(self._kwargs.get("normalize", True)), + max_length=int(self._kwargs.get("max_length", 8192)), + ) def __call__(self, batch_df: Any) -> Any: from nemo_retriever.ingest_modes.inprocess import embed_text_main_text_embed diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py b/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py index 842053bd0..cd4571aa9 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/fused.py @@ -55,10 +55,8 @@ class _FusedModelActor: def __init__(self, **kwargs: Any) -> None: _assert_no_remote_endpoints(dict(kwargs), context="actor init") + from nemo_retriever.model import create_local_embedder from nemo_retriever.model.local import NemotronOCRV1, NemotronPageElementsV3 - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import ( - LlamaNemotronEmbed1BV2Embedder, - ) self._detect_kwargs = { "inference_batch_size": int(kwargs.get("inference_batch_size", 8)), @@ -89,13 +87,6 @@ def __init__(self, **kwargs: Any) -> None: "has_embedding_column": str(kwargs.get("has_embedding_column", "text_embeddings_1b_v2_has_embedding")), } - device = kwargs.get("device") - hf_cache_dir = kwargs.get("hf_cache_dir") - normalize = bool(kwargs.get("normalize", True)) - max_length = int(kwargs.get("max_length", 8192)) - model_name_raw = kwargs.get("model_name") - model_id = model_name_raw if (isinstance(model_name_raw, str) and "/" in model_name_raw) else None - self._page_elements_model = NemotronPageElementsV3() self._ocr_model = NemotronOCRV1() self._table_structure_model = None @@ -103,12 +94,12 @@ def __init__(self, **kwargs: Any) -> None: from nemo_retriever.model.local import NemotronTableStructureV1 self._table_structure_model = NemotronTableStructureV1() - self._embed_model = LlamaNemotronEmbed1BV2Embedder( - device=str(device) if device else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None, - normalize=normalize, - max_length=max_length, - model_id=model_id, + self._embed_model = create_local_embedder( + kwargs.get("model_name"), + device=str(kwargs["device"]) if kwargs.get("device") else None, + hf_cache_dir=str(kwargs["hf_cache_dir"]) if kwargs.get("hf_cache_dir") else None, + normalize=bool(kwargs.get("normalize", True)), + max_length=int(kwargs.get("max_length", 8192)), ) def __call__(self, batch_df: Any) -> Any: diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/gpu_pool.py b/nemo_retriever/src/nemo_retriever/ingest_modes/gpu_pool.py index 775ded97a..cb1aa019a 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/gpu_pool.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/gpu_pool.py @@ -77,29 +77,14 @@ class EmbeddingModelConfig: model_id: Optional[str] = None def create(self) -> Any: - from nemo_retriever.model import is_vl_embed_model + from nemo_retriever.model import create_local_embedder - if is_vl_embed_model(self.model_id): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( - LlamaNemotronEmbedVL1BV2Embedder, - ) - - return LlamaNemotronEmbedVL1BV2Embedder( - device=self.device, - hf_cache_dir=self.hf_cache_dir, - model_id=self.model_id, - ) - - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import ( - LlamaNemotronEmbed1BV2Embedder, - ) - - return LlamaNemotronEmbed1BV2Embedder( + return create_local_embedder( + self.model_id, device=self.device, hf_cache_dir=self.hf_cache_dir, normalize=self.normalize, max_length=self.max_length, - model_id=self.model_id, ) diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index 087c10749..0e09dd787 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -29,7 +29,6 @@ import pandas as pd from nemo_retriever.model.local import NemotronOCRV1, NemotronPageElementsV3, NemotronParseV12 -from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder from nemo_retriever.page_elements import detect_page_elements_v3 from nemo_retriever.ocr.ocr import _crop_b64_image_by_norm_bbox, nemotron_parse_page_elements, ocr_page_elements from nemo_retriever.table.table_detection import table_structure_ocr_page_elements @@ -1507,38 +1506,22 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessI return self # Local HF embedder path. - # Allow callers to control device / max_length to avoid OOMs. device = embed_kwargs.pop("device", None) hf_cache_dir = embed_kwargs.pop("hf_cache_dir", None) normalize = bool(embed_kwargs.pop("normalize", True)) max_length = int(embed_kwargs.pop("max_length", 8192)) - model_name_raw = embed_kwargs.pop("model_name", None) - from nemo_retriever.model import is_vl_embed_model, resolve_embed_model - - model_id = resolve_embed_model(model_name_raw) + from nemo_retriever.model import create_local_embedder embed_kwargs.setdefault("input_type", "passage") - - if is_vl_embed_model(model_name_raw): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( - LlamaNemotronEmbedVL1BV2Embedder, - ) - - embed_kwargs["model"] = LlamaNemotronEmbedVL1BV2Embedder( - device=str(device) if device is not None else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, - model_id=model_id, - ) - else: - embed_kwargs["model"] = LlamaNemotronEmbed1BV2Embedder( - device=str(device) if device is not None else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, - normalize=normalize, - max_length=max_length, - model_id=model_id, - ) + embed_kwargs["model"] = create_local_embedder( + model_name_raw, + device=str(device) if device is not None else None, + hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, + normalize=normalize, + max_length=max_length, + ) self._tasks.append((embed_text_main_text_embed, embed_kwargs)) return self diff --git a/nemo_retriever/src/nemo_retriever/model/__init__.py b/nemo_retriever/src/nemo_retriever/model/__init__.py index dc763d548..cef002494 100644 --- a/nemo_retriever/src/nemo_retriever/model/__init__.py +++ b/nemo_retriever/src/nemo_retriever/model/__init__.py @@ -33,3 +33,43 @@ def resolve_embed_model(model_name: str | None) -> str: def is_vl_embed_model(model_name: str | None) -> bool: """Return True if *model_name* refers to the VL embedding model.""" return resolve_embed_model(model_name) in _VL_EMBED_MODEL_IDS + + +def create_local_embedder( + model_name: str | None = None, + *, + device: str | None = None, + hf_cache_dir: str | None = None, + normalize: bool = True, + max_length: int = 8192, +): + """Create the appropriate local embedding model (VL or non-VL). + + Centralises the resolve -> branch -> construct pattern that was previously + duplicated across batch, inprocess, fused, gpu_pool, recall, retriever, + and text_embed code paths. + """ + model_id = resolve_embed_model(model_name) + + if is_vl_embed_model(model_name): + from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( + LlamaNemotronEmbedVL1BV2Embedder, + ) + + return LlamaNemotronEmbedVL1BV2Embedder( + device=device, + hf_cache_dir=hf_cache_dir, + model_id=model_id, + ) + + from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import ( + LlamaNemotronEmbed1BV2Embedder, + ) + + return LlamaNemotronEmbed1BV2Embedder( + device=device, + hf_cache_dir=hf_cache_dir, + normalize=normalize, + max_length=max_length, + model_id=model_id, + ) diff --git a/nemo_retriever/src/nemo_retriever/recall/core.py b/nemo_retriever/src/nemo_retriever/recall/core.py index f5dbe1e68..b684a3f6d 100644 --- a/nemo_retriever/src/nemo_retriever/recall/core.py +++ b/nemo_retriever/src/nemo_retriever/recall/core.py @@ -168,25 +168,14 @@ def _embed_queries_local_hf( batch_size: int, model_name: Optional[str] = None, ) -> List[List[float]]: - # Lazy import: only load torch/HF when needed. - from nemo_retriever.model import is_vl_embed_model, resolve_embed_model + from nemo_retriever.model import create_local_embedder, is_vl_embed_model - model_id = resolve_embed_model(model_name) + embedder = create_local_embedder(model_name, device=device, hf_cache_dir=cache_dir) if is_vl_embed_model(model_name): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import LlamaNemotronEmbedVL1BV2Embedder - - embedder = LlamaNemotronEmbedVL1BV2Embedder(device=device, hf_cache_dir=cache_dir, model_id=model_id) - # VL model handles query formatting internally via encode_queries(). vecs = embedder.embed_queries(queries, batch_size=int(batch_size)) else: - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder - - embedder = LlamaNemotronEmbed1BV2Embedder( - device=device, hf_cache_dir=cache_dir, normalize=True, model_id=model_id - ) vecs = embedder.embed(["query: " + q for q in queries], batch_size=int(batch_size)) - # Ensure list-of-list floats. return vecs.detach().to("cpu").tolist() diff --git a/nemo_retriever/src/nemo_retriever/retriever.py b/nemo_retriever/src/nemo_retriever/retriever.py index 28ab35d4b..e018aa426 100644 --- a/nemo_retriever/src/nemo_retriever/retriever.py +++ b/nemo_retriever/src/nemo_retriever/retriever.py @@ -66,31 +66,14 @@ def _embed_queries_nim( return out def _embed_queries_local_hf(self, query_texts: list[str], *, model_name: str) -> list[list[float]]: - from nemo_retriever.model import is_vl_embed_model, resolve_embed_model + from nemo_retriever.model import create_local_embedder, is_vl_embed_model - model_id = resolve_embed_model(model_name) cache_dir = str(self.local_hf_cache_dir) if self.local_hf_cache_dir else None + embedder = create_local_embedder(model_name, device=self.local_hf_device, hf_cache_dir=cache_dir) if is_vl_embed_model(model_name): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( - LlamaNemotronEmbedVL1BV2Embedder, - ) - - embedder = LlamaNemotronEmbedVL1BV2Embedder( - device=self.local_hf_device, - hf_cache_dir=cache_dir, - model_id=model_id, - ) vectors = embedder.embed_queries(query_texts, batch_size=int(self.local_hf_batch_size)) else: - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder - - embedder = LlamaNemotronEmbed1BV2Embedder( - device=self.local_hf_device, - hf_cache_dir=cache_dir, - normalize=True, - model_id=model_id, - ) vectors = embedder.embed(["query: " + q for q in query_texts], batch_size=int(self.local_hf_batch_size)) return vectors.detach().to("cpu").tolist() diff --git a/nemo_retriever/src/nemo_retriever/text_embed/processor.py b/nemo_retriever/src/nemo_retriever/text_embed/processor.py index 555f7e28b..81dd4b8a6 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/processor.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/processor.py @@ -89,14 +89,17 @@ def maybe_inject_local_hf_embedder(task_config: Dict[str, Any], transform_config if has_endpoint or not use_local: return - # Lazy import: only load torch/HF when we truly need local embeddings. - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder + from nemo_retriever.model import create_local_embedder local_device = task_config.get("local_hf_device") local_cache_dir = task_config.get("local_hf_cache_dir") local_batch_size = int(task_config.get("local_hf_batch_size") or 64) - embedder = LlamaNemotronEmbed1BV2Embedder(device=local_device, hf_cache_dir=local_cache_dir, normalize=True) + embedder = create_local_embedder( + task_config.get("model_name"), + device=local_device, + hf_cache_dir=local_cache_dir, + ) def _embed(texts): prefix = f"{transform_config.input_type}: " if getattr(transform_config, "input_type", None) else "" diff --git a/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py b/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py index 6d5d30610..5889d7627 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py @@ -190,28 +190,16 @@ def __init__(self, **detect_kwargs: Any) -> None: hf_cache_dir = self.detect_kwargs.pop("hf_cache_dir", None) normalize = bool(self.detect_kwargs.pop("normalize", True)) max_length = self.detect_kwargs.pop("max_length", 4096) - model_name = self.detect_kwargs.get("model_name") - from nemo_retriever.model import is_vl_embed_model + from nemo_retriever.model import create_local_embedder - if is_vl_embed_model(model_name): - from nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import ( - LlamaNemotronEmbedVL1BV2Embedder, - ) - - self._model = LlamaNemotronEmbedVL1BV2Embedder( - device=str(device) if device is not None else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, - ) - else: - from nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder - - self._model = LlamaNemotronEmbed1BV2Embedder( - device=str(device) if device is not None else None, - hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, - normalize=normalize, - max_length=int(max_length), - ) + self._model = create_local_embedder( + self.detect_kwargs.get("model_name"), + device=str(device) if device is not None else None, + hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None, + normalize=normalize, + max_length=int(max_length), + ) def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: diff --git a/nemo_retriever/tests/test_create_local_embedder.py b/nemo_retriever/tests/test_create_local_embedder.py new file mode 100644 index 000000000..6ba3fb9c5 --- /dev/null +++ b/nemo_retriever/tests/test_create_local_embedder.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for nemo_retriever.model.create_local_embedder factory.""" + +import sys +from types import ModuleType +from unittest.mock import MagicMock + +import pytest + +from nemo_retriever.model import create_local_embedder + + +@pytest.fixture(autouse=True) +def _patch_embedders(monkeypatch): + """Prevent real model downloads by stubbing both embedder classes. + + The ``nemo_retriever.model.local`` package uses a custom ``__getattr__`` + that only exposes specific class names — not submodule names. Because + ``monkeypatch.setattr`` resolves each path segment via ``getattr``, it + cannot traverse to the submodule. We work around this by injecting fake + modules directly into ``sys.modules``, which Python checks first when + handling ``from … import`` statements. + """ + fake_text = MagicMock(name="LlamaNemotronEmbed1BV2Embedder") + fake_vl = MagicMock(name="LlamaNemotronEmbedVL1BV2Embedder") + + text_mod = ModuleType("nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder") + text_mod.LlamaNemotronEmbed1BV2Embedder = fake_text + + vl_mod = ModuleType("nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder") + vl_mod.LlamaNemotronEmbedVL1BV2Embedder = fake_vl + + monkeypatch.setitem(sys.modules, "nemo_retriever.model.local.llama_nemotron_embed_1b_v2_embedder", text_mod) + monkeypatch.setitem(sys.modules, "nemo_retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder", vl_mod) + + yield fake_text, fake_vl + + +def test_default_returns_text_embedder(_patch_embedders): + fake_text, _ = _patch_embedders + result = create_local_embedder() + fake_text.assert_called_once() + assert result is fake_text.return_value + + +def test_none_model_name_returns_text_embedder(_patch_embedders): + fake_text, _ = _patch_embedders + result = create_local_embedder(None) + fake_text.assert_called_once() + assert result is fake_text.return_value + + +def test_alias_resolved_to_text_embedder(_patch_embedders): + fake_text, _ = _patch_embedders + result = create_local_embedder("nemo_retriever_v1") + call_kwargs = fake_text.call_args + assert call_kwargs.kwargs["model_id"] == "nvidia/llama-3.2-nv-embedqa-1b-v2" + assert result is fake_text.return_value + + +def test_vl_model_returns_vl_embedder(_patch_embedders): + _, fake_vl = _patch_embedders + result = create_local_embedder("nvidia/llama-nemotron-embed-vl-1b-v2") + fake_vl.assert_called_once() + assert result is fake_vl.return_value + + +def test_vl_short_alias_returns_vl_embedder(_patch_embedders): + _, fake_vl = _patch_embedders + result = create_local_embedder("llama-nemotron-embed-vl-1b-v2") + fake_vl.assert_called_once() + assert result is fake_vl.return_value + + +def test_kwargs_forwarded_to_text_embedder(_patch_embedders): + fake_text, _ = _patch_embedders + create_local_embedder( + device="cuda:1", + hf_cache_dir="/tmp/cache", + normalize=False, + max_length=4096, + ) + kw = fake_text.call_args.kwargs + assert kw["device"] == "cuda:1" + assert kw["hf_cache_dir"] == "/tmp/cache" + assert kw["normalize"] is False + assert kw["max_length"] == 4096 + + +def test_kwargs_forwarded_to_vl_embedder(_patch_embedders): + _, fake_vl = _patch_embedders + create_local_embedder( + "nvidia/llama-nemotron-embed-vl-1b-v2", + device="cuda:0", + hf_cache_dir="/models", + ) + kw = fake_vl.call_args.kwargs + assert kw["device"] == "cuda:0" + assert kw["hf_cache_dir"] == "/models" + assert kw["model_id"] == "nvidia/llama-nemotron-embed-vl-1b-v2" + + +def test_unknown_model_passes_through(_patch_embedders): + fake_text, _ = _patch_embedders + create_local_embedder("custom-org/my-embed-model") + kw = fake_text.call_args.kwargs + assert kw["model_id"] == "custom-org/my-embed-model" From 7fe8d21c3b22ce99cb16c07a082cbbfa1385cefa Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 17:24:44 -0500 Subject: [PATCH 2/6] Consolidate LanceDB row construction, schema, and table creation Extracts duplicated LanceDB row-building, schema definition, and table-creation logic from batch.py and inprocess.py into a shared ingest_modes/lancedb_utils.py module. Made-with: Cursor --- .../src/nemo_retriever/ingest_modes/batch.py | 127 ++-------- .../nemo_retriever/ingest_modes/inprocess.py | 169 ++----------- .../ingest_modes/lancedb_utils.py | 226 ++++++++++++++++++ nemo_retriever/tests/test_lancedb_utils.py | 194 +++++++++++++++ 4 files changed, 452 insertions(+), 264 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/ingest_modes/lancedb_utils.py create mode 100644 nemo_retriever/tests/test_lancedb_utils.py diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index 2b6f6d262..b9b697fdd 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -86,11 +86,8 @@ class _LanceDBWriteActor: """ def __init__(self, params: VdbUploadParams | None = None) -> None: - import json - from pathlib import Path + from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema - self._json = json - self._Path = Path lancedb_params = (params or VdbUploadParams()).lancedb self._lancedb_uri = lancedb_params.lancedb_uri @@ -102,30 +99,13 @@ def __init__(self, params: VdbUploadParams | None = None) -> None: self._text_column = lancedb_params.text_column import lancedb # type: ignore - import pyarrow as pa # type: ignore - self._pa = pa self._db = lancedb.connect(uri=self._lancedb_uri) - self._table = None - self._schema = None - self._first_batch = True self._total_rows = 0 - self._table = None - mode = "overwrite" if self._overwrite else "create" - fields = [ - pa.field("vector", pa.list_(pa.float32(), 2048)), - pa.field("pdf_page", pa.string()), - pa.field("filename", pa.string()), - pa.field("pdf_basename", pa.string()), - pa.field("page_number", pa.int32()), - pa.field("source_id", pa.string()), - pa.field("path", pa.string()), - pa.field("text", pa.string()), - pa.field("metadata", pa.string()), - pa.field("source", pa.string()), - ] - self._schema = pa.schema(fields) + # Use a default dim for the initial empty table; rows are appended via add(). + self._schema = lancedb_schema(2048) + mode = "overwrite" if self._overwrite else "create" self._table = self._db.create_table( self._table_name, schema=self._schema, @@ -133,95 +113,16 @@ def __init__(self, params: VdbUploadParams | None = None) -> None: ) def _build_rows(self, df: Any) -> list: - """Build LanceDB rows from a pandas DataFrame batch. - - Mirrors the row-building logic from - ``upload_embeddings_to_lancedb_inprocess`` in inprocess.py. - """ - rows: list = [] - for row in df.itertuples(index=False): - # Extract embedding - emb = None - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): - emb = meta.get("embedding") - if not (isinstance(emb, list) and emb): - emb = None - if emb is None: - payload = getattr(row, self._embedding_column, None) - if isinstance(payload, dict): - emb = payload.get(self._embedding_key) - if not (isinstance(emb, list) and emb): - emb = None - if emb is None: - continue - - # Extract source path and page number - path = "" - page = -1 - v = getattr(row, "path", None) - if isinstance(v, str) and v.strip(): - path = v.strip() - v = getattr(row, "page_number", None) - try: - if v is not None: - page = int(v) - except Exception: - pass - if isinstance(meta, dict): - sp = meta.get("source_path") - if isinstance(sp, str) and sp.strip(): - path = sp.strip() - - p = self._Path(path) if path else None - filename = p.name if p is not None else "" - pdf_basename = p.stem if p is not None else "" - pdf_page = f"{pdf_basename}_{page}" if (pdf_basename and page >= 0) else "" - source_id = path or filename or pdf_basename - - metadata_obj = {"page_number": int(page) if page is not None else -1} - if pdf_page: - metadata_obj["pdf_page"] = pdf_page - # Persist per-page detection counters for end-of-run summaries. - # These may be duplicated across exploded content rows; downstream - # summary logic should dedupe by (source_id, page_number). - pe_num = getattr(row, "page_elements_v3_num_detections", None) - if pe_num is not None: - try: - metadata_obj["page_elements_v3_num_detections"] = int(pe_num) - except Exception: - pass - pe_counts = getattr(row, "page_elements_v3_counts_by_label", None) - if isinstance(pe_counts, dict): - metadata_obj["page_elements_v3_counts_by_label"] = { - str(k): int(v) for k, v in pe_counts.items() if isinstance(k, str) and v is not None - } - for ocr_col in ("table", "chart", "infographic"): - entries = getattr(row, ocr_col, None) - if isinstance(entries, list): - metadata_obj[f"ocr_{ocr_col}_detections"] = int(len(entries)) - source_obj = {"source_id": str(path)} - - row_out = { - "vector": emb, - "pdf_page": pdf_page, - "filename": filename, - "pdf_basename": pdf_basename, - "page_number": int(page) if page is not None else -1, - "source_id": str(source_id), - "path": str(path), - "metadata": self._json.dumps(metadata_obj, ensure_ascii=False), - "source": self._json.dumps(source_obj, ensure_ascii=False), - } - - if self._include_text: - t = getattr(row, self._text_column, None) - row_out["text"] = str(t) if isinstance(t, str) else "" - else: - row_out["text"] = "" - - rows.append(row_out) - return rows + """Build LanceDB rows from a pandas DataFrame batch.""" + from nemo_retriever.ingest_modes.lancedb_utils import build_lancedb_rows + + return build_lancedb_rows( + df, + embedding_column=self._embedding_column, + embedding_key=self._embedding_key, + text_column=self._text_column, + include_text=self._include_text, + ) def __call__(self, batch_df: Any) -> Any: rows = self._build_rows(batch_df) diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index 0e09dd787..f4ed21596 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -634,67 +634,14 @@ def save_dataframe_to_disk_json(df: Any, *, output_directory: str) -> Any: return df -def _extract_embedding_from_row( - row: Any, - *, - embedding_column: str = "text_embeddings_1b_v2", - embedding_key: str = "embedding", -) -> Optional[List[float]]: - """ - Extract an embedding vector from a row (namedtuple or pd.Series). - - Supports: - - `metadata.embedding` (preferred if present) - - `embedding_column` payloads like `{"embedding": [...], ...}` (from `embed_text_1b_v2`) - """ - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): - emb = meta.get("embedding") - if isinstance(emb, list) and emb: - return emb # type: ignore[return-value] - - payload = getattr(row, embedding_column, None) - if isinstance(payload, dict): - emb = payload.get(embedding_key) - if isinstance(emb, list) and emb: - return emb # type: ignore[return-value] - return None - - -def _extract_source_path_and_page(row: Any) -> Tuple[str, int]: - """ - Best-effort extract of source path and page number for LanceDB row metadata. - """ - path = "" - page = -1 - - v = getattr(row, "path", None) - if isinstance(v, str) and v.strip(): - path = v.strip() - - v = getattr(row, "page_number", None) - try: - if v is not None: - page = int(v) - except Exception: - pass - - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): - sp = meta.get("source_path") - if isinstance(sp, str) and sp.strip(): - path = sp.strip() - # Some schemas store page under content metadata; support if present. - cm = meta.get("content_metadata") - if isinstance(cm, dict) and page == -1: - h = cm.get("hierarchy") - if isinstance(h, dict) and "page" in h: - try: - page = int(h.get("page")) - except Exception: - pass - - return path, page +from nemo_retriever.ingest_modes.lancedb_utils import ( + build_lancedb_rows, + create_or_append_lancedb_table, + extract_embedding_from_row as _extract_embedding_from_row, + extract_source_path_and_page as _extract_source_path_and_page, + infer_vector_dim, + lancedb_schema, +) def upload_embeddings_to_lancedb_inprocess( @@ -744,112 +691,32 @@ def upload_embeddings_to_lancedb_inprocess( if not isinstance(df, pd.DataFrame): raise TypeError(f"upload_embeddings_to_lancedb_inprocess expects pandas.DataFrame, got {type(df)!r}") - rows: List[Dict[str, Any]] = [] - for r in df.itertuples(index=False): - emb = _extract_embedding_from_row(r, embedding_column=str(embedding_column), embedding_key=str(embedding_key)) - if emb is None: - continue - - path, page_number = _extract_source_path_and_page(r) - p = Path(path) if path else None - filename = p.name if p is not None else "" - pdf_basename = p.stem if p is not None else "" - pdf_page = f"{pdf_basename}_{page_number}" if (pdf_basename and page_number >= 0) else "" - source_id = path or filename or pdf_basename - - # Provide fields compatible with `nemo_retriever.recall.core` which expects LanceDB hits - # to include JSON-encoded `metadata` and `source` strings. - metadata_obj: Dict[str, Any] = {"page_number": int(page_number) if page_number is not None else -1} - if pdf_page: - metadata_obj["pdf_page"] = pdf_page - # Persist per-page detection counters for end-of-run summaries. - # Mirrors batch.py so LanceDB-based summary reads also work. - pe_num = getattr(r, "page_elements_v3_num_detections", None) - if pe_num is not None: - try: - metadata_obj["page_elements_v3_num_detections"] = int(pe_num) - except Exception: - pass - pe_counts = getattr(r, "page_elements_v3_counts_by_label", None) - if isinstance(pe_counts, dict): - metadata_obj["page_elements_v3_counts_by_label"] = { - str(k): int(v) for k, v in pe_counts.items() if isinstance(k, str) and v is not None - } - for ocr_col in ("table", "chart", "infographic"): - entries = getattr(r, ocr_col, None) - if isinstance(entries, list): - metadata_obj[f"ocr_{ocr_col}_detections"] = int(len(entries)) - source_obj: Dict[str, Any] = {"source_id": str(path)} - - row_out: Dict[str, Any] = { - "vector": emb, - "pdf_page": pdf_page, - "filename": filename, - "pdf_basename": pdf_basename, - "page_number": int(page_number) if page_number is not None else -1, - "source_id": str(source_id), - "path": str(path), - "metadata": json.dumps(metadata_obj, ensure_ascii=False), - "source": json.dumps(source_obj, ensure_ascii=False), - } - - if include_text: - t = getattr(r, text_column, None) - row_out["text"] = str(t) if isinstance(t, str) else "" - else: - # Still include the column for compatibility with the recall script's `.select(["text",...])`. - row_out["text"] = "" - - rows.append(row_out) + rows = build_lancedb_rows( + df, + embedding_column=str(embedding_column), + embedding_key=str(embedding_key), + text_column=str(text_column), + include_text=bool(include_text), + ) if not rows: print("No embeddings found to upload to LanceDB (no rows had embeddings).") return df - # Infer vector dim from first row. - dim = 0 - for rr in rows: - v = rr.get("vector") - if isinstance(v, list) and v: - dim = int(len(v)) - break + dim = infer_vector_dim(rows) if dim <= 0: raise ValueError("Failed to infer embedding dimension from DataFrame rows.") try: import lancedb # type: ignore - import pyarrow as pa # type: ignore except Exception as e: raise RuntimeError( "LanceDB upload requested but dependencies are missing. Install `lancedb` and `pyarrow`." ) from e db = lancedb.connect(uri=str(lancedb_uri)) - - fields = [ - pa.field("vector", pa.list_(pa.float32(), dim)), - pa.field("pdf_page", pa.string()), - pa.field("filename", pa.string()), - pa.field("pdf_basename", pa.string()), - pa.field("page_number", pa.int32()), - pa.field("source_id", pa.string()), - pa.field("path", pa.string()), - # Compatibility columns expected by `nemo_retriever.recall.core`: - pa.field("text", pa.string()), - pa.field("metadata", pa.string()), - pa.field("source", pa.string()), - ] - schema = pa.schema(fields) - - # Overwrite vs append. - if overwrite: - table = db.create_table(str(table_name), data=list(rows), schema=schema, mode="overwrite") - else: - try: - table = db.open_table(str(table_name)) - table.add(list(rows)) - except Exception: - table = db.create_table(str(table_name), data=list(rows), schema=schema, mode="create") + schema = lancedb_schema(dim) + table = create_or_append_lancedb_table(db, str(table_name), rows, schema, overwrite=overwrite) if create_index: # LanceDB IVF-based indexes train k-means with K=num_partitions. K must be < N vectors. diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/lancedb_utils.py b/nemo_retriever/src/nemo_retriever/ingest_modes/lancedb_utils.py new file mode 100644 index 000000000..f74f787eb --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/lancedb_utils.py @@ -0,0 +1,226 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared LanceDB row construction, schema, and table helpers. + +Consolidates the duplicated logic that previously lived independently in +``inprocess.py`` (``upload_embeddings_to_lancedb_inprocess``) and +``batch.py`` (``_LanceDBWriteActor._build_rows``). +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +def extract_embedding_from_row( + row: Any, + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", +) -> Optional[List[float]]: + """Extract an embedding vector from a row (namedtuple or pd.Series). + + Supports: + - ``metadata.embedding`` (preferred if present) + - *embedding_column* payloads like ``{"embedding": [...], ...}`` + """ + meta = getattr(row, "metadata", None) + if isinstance(meta, dict): + emb = meta.get("embedding") + if isinstance(emb, list) and emb: + return emb # type: ignore[return-value] + + payload = getattr(row, embedding_column, None) + if isinstance(payload, dict): + emb = payload.get(embedding_key) + if isinstance(emb, list) and emb: + return emb # type: ignore[return-value] + return None + + +def extract_source_path_and_page(row: Any) -> Tuple[str, int]: + """Best-effort extract of source path and page number from a row.""" + path = "" + page = -1 + + v = getattr(row, "path", None) + if isinstance(v, str) and v.strip(): + path = v.strip() + + v = getattr(row, "page_number", None) + try: + if v is not None: + page = int(v) + except Exception: + pass + + meta = getattr(row, "metadata", None) + if isinstance(meta, dict): + sp = meta.get("source_path") + if isinstance(sp, str) and sp.strip(): + path = sp.strip() + cm = meta.get("content_metadata") + if isinstance(cm, dict) and page == -1: + h = cm.get("hierarchy") + if isinstance(h, dict) and "page" in h: + try: + page = int(h.get("page")) + except Exception: + pass + + return path, page + + +def _build_detection_metadata(row: Any) -> Dict[str, Any]: + """Extract per-page detection counters from a row for LanceDB metadata.""" + out: Dict[str, Any] = {} + + pe_num = getattr(row, "page_elements_v3_num_detections", None) + if pe_num is not None: + try: + out["page_elements_v3_num_detections"] = int(pe_num) + except Exception: + pass + + pe_counts = getattr(row, "page_elements_v3_counts_by_label", None) + if isinstance(pe_counts, dict): + out["page_elements_v3_counts_by_label"] = { + str(k): int(v) for k, v in pe_counts.items() if isinstance(k, str) and v is not None + } + + for ocr_col in ("table", "chart", "infographic"): + entries = getattr(row, ocr_col, None) + if isinstance(entries, list): + out[f"ocr_{ocr_col}_detections"] = int(len(entries)) + + return out + + +def build_lancedb_row( + row: Any, + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", + text_column: str = "text", + include_text: bool = True, +) -> Optional[Dict[str, Any]]: + """Build a single LanceDB-ready dict from a DataFrame row. + + Returns ``None`` when no embedding is found in the row. + """ + emb = extract_embedding_from_row(row, embedding_column=embedding_column, embedding_key=embedding_key) + if emb is None: + return None + + path, page_number = extract_source_path_and_page(row) + p = Path(path) if path else None + filename = p.name if p is not None else "" + pdf_basename = p.stem if p is not None else "" + pdf_page = f"{pdf_basename}_{page_number}" if (pdf_basename and page_number >= 0) else "" + source_id = path or filename or pdf_basename + + metadata_obj: Dict[str, Any] = {"page_number": int(page_number) if page_number is not None else -1} + if pdf_page: + metadata_obj["pdf_page"] = pdf_page + metadata_obj.update(_build_detection_metadata(row)) + + source_obj: Dict[str, Any] = {"source_id": str(path)} + + row_out: Dict[str, Any] = { + "vector": emb, + "pdf_page": pdf_page, + "filename": filename, + "pdf_basename": pdf_basename, + "page_number": int(page_number) if page_number is not None else -1, + "source_id": str(source_id), + "path": str(path), + "metadata": json.dumps(metadata_obj, ensure_ascii=False), + "source": json.dumps(source_obj, ensure_ascii=False), + } + + if include_text: + t = getattr(row, text_column, None) + row_out["text"] = str(t) if isinstance(t, str) else "" + else: + row_out["text"] = "" + + return row_out + + +def build_lancedb_rows( + df: Any, + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", + text_column: str = "text", + include_text: bool = True, +) -> List[Dict[str, Any]]: + """Build LanceDB rows from a pandas DataFrame. + + Iterates with ``itertuples`` and delegates to :func:`build_lancedb_row`. + Rows without an embedding are silently skipped. + """ + rows: List[Dict[str, Any]] = [] + for r in df.itertuples(index=False): + row_out = build_lancedb_row( + r, + embedding_column=embedding_column, + embedding_key=embedding_key, + text_column=text_column, + include_text=include_text, + ) + if row_out is not None: + rows.append(row_out) + return rows + + +def lancedb_schema(vector_dim: int) -> Any: + """Return a PyArrow schema for the standard LanceDB table layout.""" + import pyarrow as pa # type: ignore + + return pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), vector_dim)), + pa.field("pdf_page", pa.string()), + pa.field("filename", pa.string()), + pa.field("pdf_basename", pa.string()), + pa.field("page_number", pa.int32()), + pa.field("source_id", pa.string()), + pa.field("path", pa.string()), + pa.field("text", pa.string()), + pa.field("metadata", pa.string()), + pa.field("source", pa.string()), + ] + ) + + +def infer_vector_dim(rows: List[Dict[str, Any]]) -> int: + """Return the embedding dimension from the first row that has a vector.""" + for r in rows: + v = r.get("vector") + if isinstance(v, list) and v: + return len(v) + return 0 + + +def create_or_append_lancedb_table( + db: Any, + table_name: str, + rows: List[Dict[str, Any]], + schema: Any, + overwrite: bool = True, +) -> Any: + """Create or append to a LanceDB table, returning the table object.""" + if overwrite: + return db.create_table(str(table_name), data=list(rows), schema=schema, mode="overwrite") + + try: + table = db.open_table(str(table_name)) + table.add(list(rows)) + return table + except Exception: + return db.create_table(str(table_name), data=list(rows), schema=schema, mode="create") diff --git a/nemo_retriever/tests/test_lancedb_utils.py b/nemo_retriever/tests/test_lancedb_utils.py new file mode 100644 index 000000000..cf4d3cdcc --- /dev/null +++ b/nemo_retriever/tests/test_lancedb_utils.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for nemo_retriever.ingest_modes.lancedb_utils.""" + +import json +from types import SimpleNamespace + +import pytest + +from nemo_retriever.ingest_modes.lancedb_utils import ( + build_lancedb_row, + build_lancedb_rows, + create_or_append_lancedb_table, + extract_embedding_from_row, + extract_source_path_and_page, + infer_vector_dim, + lancedb_schema, +) + + +class TestExtractEmbeddingFromRow: + def test_from_metadata(self): + row = SimpleNamespace(metadata={"embedding": [1.0, 2.0, 3.0]}) + assert extract_embedding_from_row(row) == [1.0, 2.0, 3.0] + + def test_from_embedding_column(self): + row = SimpleNamespace( + metadata=None, + text_embeddings_1b_v2={"embedding": [4.0, 5.0]}, + ) + assert extract_embedding_from_row(row) == [4.0, 5.0] + + def test_custom_column(self): + row = SimpleNamespace(metadata=None, my_col={"vec": [6.0]}) + assert extract_embedding_from_row(row, embedding_column="my_col", embedding_key="vec") == [6.0] + + def test_returns_none_when_missing(self): + row = SimpleNamespace(metadata=None) + assert extract_embedding_from_row(row) is None + + def test_empty_embedding_returns_none(self): + row = SimpleNamespace(metadata={"embedding": []}) + assert extract_embedding_from_row(row) is None + + +class TestExtractSourcePathAndPage: + def test_from_direct_attrs(self): + row = SimpleNamespace(path="/docs/file.pdf", page_number=3, metadata=None) + assert extract_source_path_and_page(row) == ("/docs/file.pdf", 3) + + def test_from_metadata_source_path(self): + row = SimpleNamespace(path="", page_number=None, metadata={"source_path": "/meta/path.pdf"}) + assert extract_source_path_and_page(row) == ("/meta/path.pdf", -1) + + def test_from_content_metadata_hierarchy(self): + row = SimpleNamespace( + path="", + page_number=None, + metadata={"content_metadata": {"hierarchy": {"page": 7}}}, + ) + path, page = extract_source_path_and_page(row) + assert page == 7 + + def test_defaults_when_missing(self): + row = SimpleNamespace() + assert extract_source_path_and_page(row) == ("", -1) + + +class TestBuildLancedbRow: + def _row(self, **kwargs): + defaults = { + "metadata": {"embedding": [0.1, 0.2]}, + "path": "/docs/test.pdf", + "page_number": 1, + "text": "hello world", + } + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + def test_returns_dict_with_expected_keys(self): + result = build_lancedb_row(self._row()) + assert result is not None + assert set(result.keys()) == { + "vector", "pdf_page", "filename", "pdf_basename", + "page_number", "source_id", "path", "metadata", "source", "text", + } + + def test_vector_extracted(self): + result = build_lancedb_row(self._row()) + assert result["vector"] == [0.1, 0.2] + + def test_path_fields(self): + result = build_lancedb_row(self._row()) + assert result["filename"] == "test.pdf" + assert result["pdf_basename"] == "test" + assert result["pdf_page"] == "test_1" + + def test_text_included(self): + result = build_lancedb_row(self._row()) + assert result["text"] == "hello world" + + def test_text_excluded(self): + result = build_lancedb_row(self._row(), include_text=False) + assert result["text"] == "" + + def test_metadata_json(self): + result = build_lancedb_row(self._row()) + meta = json.loads(result["metadata"]) + assert meta["page_number"] == 1 + assert meta["pdf_page"] == "test_1" + + def test_returns_none_when_no_embedding(self): + row = SimpleNamespace(metadata=None, path="/x.pdf", page_number=1, text="hi") + assert build_lancedb_row(row) is None + + def test_detection_metadata_included(self): + row = self._row( + page_elements_v3_num_detections=5, + page_elements_v3_counts_by_label={"text": 3, "figure": 2}, + table=[{}, {}], + ) + result = build_lancedb_row(row) + meta = json.loads(result["metadata"]) + assert meta["page_elements_v3_num_detections"] == 5 + assert meta["page_elements_v3_counts_by_label"] == {"text": 3, "figure": 2} + assert meta["ocr_table_detections"] == 2 + + +class TestBuildLancedbRows: + def test_filters_rows_without_embeddings(self): + import pandas as pd + + df = pd.DataFrame([ + {"metadata": {"embedding": [1.0]}, "path": "/a.pdf", "page_number": 1, "text": "a"}, + {"metadata": {}, "path": "/b.pdf", "page_number": 1, "text": "b"}, + ]) + rows = build_lancedb_rows(df) + assert len(rows) == 1 + assert rows[0]["vector"] == [1.0] + + +class TestLancedbSchema: + def test_returns_schema_with_correct_fields(self): + schema = lancedb_schema(768) + names = [f.name for f in schema] + assert "vector" in names + assert "text" in names + assert "metadata" in names + assert "source" in names + assert len(names) == 10 + + +class TestInferVectorDim: + def test_returns_dim(self): + assert infer_vector_dim([{"vector": [1.0, 2.0, 3.0]}]) == 3 + + def test_returns_zero_when_empty(self): + assert infer_vector_dim([]) == 0 + assert infer_vector_dim([{"vector": []}]) == 0 + + +class TestCreateOrAppendLancedbTable: + def test_overwrite_calls_create(self): + from unittest.mock import MagicMock + + db = MagicMock() + schema = MagicMock() + rows = [{"a": 1}] + create_or_append_lancedb_table(db, "test", rows, schema, overwrite=True) + db.create_table.assert_called_once_with("test", data=[{"a": 1}], schema=schema, mode="overwrite") + + def test_append_opens_then_adds(self): + from unittest.mock import MagicMock + + db = MagicMock() + table = MagicMock() + db.open_table.return_value = table + rows = [{"a": 1}] + result = create_or_append_lancedb_table(db, "t", rows, MagicMock(), overwrite=False) + db.open_table.assert_called_once_with("t") + table.add.assert_called_once() + assert result is table + + def test_append_falls_back_to_create(self): + from unittest.mock import MagicMock + + db = MagicMock() + db.open_table.side_effect = Exception("not found") + schema = MagicMock() + rows = [{"a": 1}] + create_or_append_lancedb_table(db, "t", rows, schema, overwrite=False) + db.create_table.assert_called_once_with("t", data=[{"a": 1}], schema=schema, mode="create") From 13933f7cd9d9cdbed92bafb94a4a97bcaaebebfc Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 19:48:58 -0500 Subject: [PATCH 3/6] Fix lint: remove unused imports, apply black formatting - Remove unused Path import and unused _extract_* aliases from inprocess.py - Remove unused pytest import from test_lancedb_utils.py - Apply black formatting to set literal and DataFrame constructor Made-with: Cursor --- .../nemo_retriever/ingest_modes/inprocess.py | 3 --- nemo_retriever/tests/test_lancedb_utils.py | 24 ++++++++++++------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index 0191abc4d..c07d2f44c 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -23,7 +23,6 @@ from datetime import datetime, timezone from io import BytesIO from collections.abc import Callable, Iterator -from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union @@ -637,8 +636,6 @@ def save_dataframe_to_disk_json(df: Any, *, output_directory: str) -> Any: from nemo_retriever.ingest_modes.lancedb_utils import ( build_lancedb_rows, create_or_append_lancedb_table, - extract_embedding_from_row as _extract_embedding_from_row, - extract_source_path_and_page as _extract_source_path_and_page, infer_vector_dim, lancedb_schema, ) diff --git a/nemo_retriever/tests/test_lancedb_utils.py b/nemo_retriever/tests/test_lancedb_utils.py index cf4d3cdcc..36a5296c3 100644 --- a/nemo_retriever/tests/test_lancedb_utils.py +++ b/nemo_retriever/tests/test_lancedb_utils.py @@ -7,8 +7,6 @@ import json from types import SimpleNamespace -import pytest - from nemo_retriever.ingest_modes.lancedb_utils import ( build_lancedb_row, build_lancedb_rows, @@ -83,8 +81,16 @@ def test_returns_dict_with_expected_keys(self): result = build_lancedb_row(self._row()) assert result is not None assert set(result.keys()) == { - "vector", "pdf_page", "filename", "pdf_basename", - "page_number", "source_id", "path", "metadata", "source", "text", + "vector", + "pdf_page", + "filename", + "pdf_basename", + "page_number", + "source_id", + "path", + "metadata", + "source", + "text", } def test_vector_extracted(self): @@ -132,10 +138,12 @@ class TestBuildLancedbRows: def test_filters_rows_without_embeddings(self): import pandas as pd - df = pd.DataFrame([ - {"metadata": {"embedding": [1.0]}, "path": "/a.pdf", "page_number": 1, "text": "a"}, - {"metadata": {}, "path": "/b.pdf", "page_number": 1, "text": "b"}, - ]) + df = pd.DataFrame( + [ + {"metadata": {"embedding": [1.0]}, "path": "/a.pdf", "page_number": 1, "text": "a"}, + {"metadata": {}, "path": "/b.pdf", "page_number": 1, "text": "b"}, + ] + ) rows = build_lancedb_rows(df) assert len(rows) == 1 assert rows[0]["vector"] == [1.0] From 894434f554bb10a2cab041a1c4c297a7decc08e1 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 19:51:47 -0500 Subject: [PATCH 4/6] Fix test collection: stub heavy sibling modules before lancedb_utils import The ingest_modes __init__.py eagerly imports batch/fused/inprocess/online which pull in ray, torch, etc. Pre-populate sys.modules with MagicMock stubs so lancedb_utils tests can run in lightweight CI without those deps. Made-with: Cursor --- nemo_retriever/tests/test_lancedb_utils.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/nemo_retriever/tests/test_lancedb_utils.py b/nemo_retriever/tests/test_lancedb_utils.py index 36a5296c3..e037083dc 100644 --- a/nemo_retriever/tests/test_lancedb_utils.py +++ b/nemo_retriever/tests/test_lancedb_utils.py @@ -5,9 +5,22 @@ """Unit tests for nemo_retriever.ingest_modes.lancedb_utils.""" import json +import sys from types import SimpleNamespace - -from nemo_retriever.ingest_modes.lancedb_utils import ( +from unittest.mock import MagicMock + +# The ingest_modes __init__.py eagerly imports batch/fused/inprocess/online, +# which pull in ray, torch, etc. Stub them so lancedb_utils can be imported +# in lightweight CI (matching the pattern in test_multimodal_embed.py). +for _mod_name in [ + "nemo_retriever.ingest_modes.batch", + "nemo_retriever.ingest_modes.fused", + "nemo_retriever.ingest_modes.inprocess", + "nemo_retriever.ingest_modes.online", +]: + sys.modules.setdefault(_mod_name, MagicMock()) + +from nemo_retriever.ingest_modes.lancedb_utils import ( # noqa: E402 build_lancedb_row, build_lancedb_rows, create_or_append_lancedb_table, From 7e63f8aedd0e74528463328079f473d337f8a8b4 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 18:03:39 -0500 Subject: [PATCH 5/6] Move duplicated recall helpers to recall/core.py and examples/common.py Centralises gold_to_doc_page, hit_key_and_distance, estimate_processed_pages, and print_pages_per_second that were duplicated across batch, inprocess, online, and fused pipeline examples. Fixes broken imports in fused_pipeline.py that referenced non-existent functions in batch_pipeline.py. Made-with: Cursor --- .../nemo_retriever/examples/batch_pipeline.py | 73 ++------------- .../src/nemo_retriever/examples/common.py | 51 +++++++++++ .../nemo_retriever/examples/fused_pipeline.py | 30 ++++--- .../examples/inprocess_pipeline.py | 90 +++---------------- .../examples/online_pipeline.py | 48 +++------- .../src/nemo_retriever/recall/core.py | 31 +++++++ 6 files changed, 127 insertions(+), 196 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/examples/common.py diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index 36cb58034..137f31bc5 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -257,26 +257,15 @@ def _write_detection_summary(path: Path, summary: Optional[dict]) -> None: target.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") -def _print_pages_per_second(processed_pages: Optional[int], ingest_elapsed_s: float) -> None: - if ingest_elapsed_s <= 0: - print("Pages/sec: unavailable (ingest elapsed time was non-positive).") - return - if processed_pages is None: - print("Pages/sec: unavailable (could not estimate processed pages). " f"Ingest time: {ingest_elapsed_s:.2f}s") - return - - pps = processed_pages / ingest_elapsed_s - print(f"Pages processed: {processed_pages}") - print(f"Pages/sec (ingest only; excludes Ray startup and recall): {pps:.2f}") def _ensure_lancedb_table(uri: str, table_name: str) -> None: - """ - Ensure the local LanceDB URI exists and table can be opened. + """Ensure the local LanceDB URI exists and table can be opened. Creates an empty table with the expected schema if it does not exist yet. """ - # Local path URI in this pipeline. + from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema + Path(uri).mkdir(parents=True, exist_ok=True) db = _lancedb().connect(uri) @@ -288,63 +277,11 @@ def _ensure_lancedb_table(uri: str, table_name: str) -> None: import pyarrow as pa # type: ignore - schema = pa.schema( - [ - pa.field("vector", pa.list_(pa.float32(), 2048)), - pa.field("pdf_page", pa.string()), - pa.field("filename", pa.string()), - pa.field("pdf_basename", pa.string()), - pa.field("page_number", pa.int32()), - pa.field("source_id", pa.string()), - pa.field("path", pa.string()), - pa.field("text", pa.string()), - pa.field("metadata", pa.string()), - pa.field("source", pa.string()), - ] - ) - empty = pa.table( - { - "vector": [], - "pdf_page": [], - "filename": [], - "pdf_basename": [], - "page_number": [], - "source_id": [], - "path": [], - "text": [], - "metadata": [], - "source": [], - }, - schema=schema, - ) + schema = lancedb_schema(2048) + empty = pa.table({f.name: [] for f in schema}, schema=schema) db.create_table(table_name, data=empty, schema=schema, mode="create") -def _gold_to_doc_page(golden_key: str) -> tuple[str, str]: - s = str(golden_key) - if "_" not in s: - return s, "" - doc, page = s.rsplit("_", 1) - return doc, page - - -def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: - try: - res = json.loads(hit.get("metadata", "{}")) - source = json.loads(hit.get("source", "{}")) - except Exception: - return None, None - - source_id = source.get("source_id") - page_number = res.get("page_number") - if not source_id or page_number is None: - return None, float(hit.get("_distance")) if "_distance" in hit else None - - key = f"{Path(str(source_id)).stem}_{page_number}" - dist = float(hit["_distance"]) if "_distance" in hit else float(hit["_score"]) if "_score" in hit else None - return key, dist - - @app.command() def main( ctx: typer.Context, diff --git a/nemo_retriever/src/nemo_retriever/examples/common.py b/nemo_retriever/src/nemo_retriever/examples/common.py new file mode 100644 index 000000000..70a165bf6 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/examples/common.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared helpers used by multiple example pipeline scripts.""" + +from __future__ import annotations + +from typing import Optional + + +def estimate_processed_pages(uri: str, table_name: str) -> Optional[int]: + """Estimate pages processed by counting unique (source_id, page_number) pairs. + + Falls back to table row count if page-level fields are unavailable. + """ + try: + import lancedb # type: ignore + + db = lancedb.connect(uri) + table = db.open_table(table_name) + except Exception: + return None + + try: + df = table.to_pandas()[["source_id", "page_number"]] + return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0]) + except Exception: + try: + return int(table.count_rows()) + except Exception: + return None + + +def print_pages_per_second( + processed_pages: Optional[int], + ingest_elapsed_s: float, + *, + label: str = "ingest only", +) -> None: + """Print a throughput summary line.""" + if ingest_elapsed_s <= 0: + print("Pages/sec: unavailable (ingest elapsed time was non-positive).") + return + if processed_pages is None: + print(f"Pages/sec: unavailable (could not estimate processed pages). Ingest time: {ingest_elapsed_s:.2f}s") + return + + pps = processed_pages / ingest_elapsed_s + print(f"Pages processed: {processed_pages}") + print(f"Pages/sec ({label}): {pps:.2f}") diff --git a/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py index b378ba69f..da327ff1f 100644 --- a/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py @@ -27,18 +27,20 @@ from nemo_retriever.examples.batch_pipeline import ( LANCEDB_TABLE, LANCEDB_URI, + _collect_detection_summary, _configure_logging, _ensure_lancedb_table, - _estimate_processed_pages, - _gold_to_doc_page, - _hit_key_and_distance, - _is_hit_at_k, _print_detection_summary, - _print_pages_per_second, _write_detection_summary, - _collect_detection_summary, ) -from nemo_retriever.recall.core import RecallConfig, retrieve_and_score +from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second +from nemo_retriever.recall.core import ( + RecallConfig, + gold_to_doc_page, + hit_key_and_distance, + is_hit_at_k, + retrieve_and_score, +) app = typer.Typer() @@ -242,7 +244,7 @@ def main( ) ) ingest_elapsed_s = time.perf_counter() - ingest_start - processed_pages = _estimate_processed_pages(lancedb_uri, LANCEDB_TABLE) + processed_pages = estimate_processed_pages(lancedb_uri, LANCEDB_TABLE) detection_summary = _collect_detection_summary(lancedb_uri, LANCEDB_TABLE) print("Extraction complete.") _print_detection_summary(detection_summary) @@ -255,7 +257,7 @@ def main( query_csv = Path(query_csv) if not query_csv.exists(): print(f"Query CSV not found at {query_csv}; skipping recall evaluation.") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) return db = lancedb.connect(lancedb_uri) @@ -277,7 +279,7 @@ def main( try: if int(table.count_rows()) == 0: print(f"LanceDB table {LANCEDB_TABLE!r} exists but is empty; skipping recall evaluation.") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) return except Exception: pass @@ -305,16 +307,16 @@ def main( _raw_hits, ) ): - doc, page = _gold_to_doc_page(g) + doc, page = gold_to_doc_page(g) scored_hits: list[tuple[str, float | None]] = [] for h in hits: - key, dist = _hit_key_and_distance(h) + key, dist = hit_key_and_distance(h) if key: scored_hits.append((key, dist)) top_keys = [k for (k, _d) in scored_hits] - hit = _is_hit_at_k(g, top_keys, cfg.top_k) + hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") if not no_recall_details: print(f"\nQuery {i}: {q}") @@ -345,7 +347,7 @@ def main( print("\nRecall metrics (matching nemo_retriever.recall.core):") for k, v in metrics.items(): print(f" {k}: {v:.4f}") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) finally: # Restore real stdio before closing the mirror file so exception hooks # and late flushes never write to a closed stream wrapper. diff --git a/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py index 238da6f2c..cec6ec3bf 100644 --- a/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py @@ -7,7 +7,6 @@ Run with: uv run python -m nemo_retriever.examples.inprocess_pipeline """ -import json import time from pathlib import Path from typing import Optional @@ -15,12 +14,19 @@ import lancedb import typer from nemo_retriever import create_ingestor +from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second from nemo_retriever.params import EmbedParams from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import TextChunkParams from nemo_retriever.params import VdbUploadParams -from nemo_retriever.recall.core import RecallConfig, retrieve_and_score +from nemo_retriever.recall.core import ( + RecallConfig, + gold_to_doc_page, + hit_key_and_distance, + is_hit_at_k, + retrieve_and_score, +) app = typer.Typer() @@ -28,74 +34,6 @@ LANCEDB_TABLE = "nv-ingest" -def _estimate_processed_pages(uri: str, table_name: str) -> Optional[int]: - """ - Estimate pages processed by counting unique (source_id, page_number) pairs. - - Falls back to table row count if page-level fields are unavailable. - """ - try: - db = lancedb.connect(uri) - table = db.open_table(table_name) - except Exception: - return None - - try: - df = table.to_pandas()[["source_id", "page_number"]] - return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0]) - except Exception: - try: - return int(table.count_rows()) - except Exception: - return None - - -def _print_pages_per_second(processed_pages: Optional[int], ingest_elapsed_s: float) -> None: - if ingest_elapsed_s <= 0: - print("Pages/sec: unavailable (ingest elapsed time was non-positive).") - return - if processed_pages is None: - print("Pages/sec: unavailable (could not estimate processed pages). " f"Ingest time: {ingest_elapsed_s:.2f}s") - return - - pps = processed_pages / ingest_elapsed_s - print(f"Pages processed: {processed_pages}") - print(f"Pages/sec (ingest only): {pps:.2f}") - - -def _gold_to_doc_page(golden_key: str) -> tuple[str, str]: - s = str(golden_key) - if "_" not in s: - return s, "" - doc, page = s.rsplit("_", 1) - return doc, page - - -def _is_hit_at_k(golden_key: str, retrieved_keys: list[str], k: int) -> bool: - doc, page = _gold_to_doc_page(golden_key) - specific_page = f"{doc}_{page}" - entire_document = f"{doc}_-1" - top = (retrieved_keys or [])[: int(k)] - return (specific_page in top) or (entire_document in top) - - -def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: - try: - res = json.loads(hit.get("metadata", "{}")) - source = json.loads(hit.get("source", "{}")) - except Exception: - return None, None - - source_id = source.get("source_id") - page_number = res.get("page_number") - if not source_id or page_number is None: - return None, float(hit.get("_distance")) if "_distance" in hit else None - - key = f"{Path(str(source_id)).stem}_{page_number}" - dist = float(hit.get("_distance")) if "_distance" in hit else None - return key, dist - - @app.command() def main( input_path: Path = typer.Argument( @@ -388,7 +326,7 @@ def main( ) ) ingest_elapsed_s = time.perf_counter() - ingest_start - processed_pages = _estimate_processed_pages(LANCEDB_URI, LANCEDB_TABLE) + processed_pages = estimate_processed_pages(LANCEDB_URI, LANCEDB_TABLE) print("Extraction complete.") # --------------------------------------------------------------------------- @@ -397,7 +335,7 @@ def main( query_csv = Path(query_csv) if not query_csv.exists(): print(f"Query CSV not found at {query_csv}; skipping recall evaluation.") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) return db = lancedb.connect(f"./{LANCEDB_URI}") @@ -432,16 +370,16 @@ def main( _raw_hits, ) ): - doc, page = _gold_to_doc_page(g) + doc, page = gold_to_doc_page(g) scored_hits: list[tuple[str, float | None]] = [] for h in hits: - key, dist = _hit_key_and_distance(h) + key, dist = hit_key_and_distance(h) if key: scored_hits.append((key, dist)) top_keys = [k for (k, _d) in scored_hits] - hit = _is_hit_at_k(g, top_keys, cfg.top_k) + hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") if not no_recall_details: ext = ( @@ -482,7 +420,7 @@ def main( print("\nRecall metrics (matching nemo_retriever.recall.core):") for k, v in metrics.items(): print(f" {k}: {v:.4f}") - _print_pages_per_second(processed_pages, ingest_elapsed_s) + print_pages_per_second(processed_pages, ingest_elapsed_s) if __name__ == "__main__": diff --git a/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py index d72fea7cf..f2da17b5d 100644 --- a/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/online_pipeline.py @@ -15,7 +15,6 @@ --run-mode online --base-url http://localhost:7670 """ -import json from pathlib import Path import lancedb @@ -27,7 +26,13 @@ from nemo_retriever.params import IngestorCreateParams from nemo_retriever.params import TextChunkParams from nemo_retriever.params import VdbUploadParams -from nemo_retriever.recall.core import RecallConfig, retrieve_and_score +from nemo_retriever.recall.core import ( + RecallConfig, + gold_to_doc_page, + hit_key_and_distance, + is_hit_at_k, + retrieve_and_score, +) app = typer.Typer() @@ -35,39 +40,6 @@ LANCEDB_TABLE = "nv-ingest" -def _gold_to_doc_page(golden_key: str) -> tuple[str, str]: - s = str(golden_key) - if "_" not in s: - return s, "" - doc, page = s.rsplit("_", 1) - return doc, page - - -def _is_hit_at_k(golden_key: str, retrieved_keys: list[str], k: int) -> bool: - doc, page = _gold_to_doc_page(golden_key) - specific_page = f"{doc}_{page}" - entire_document = f"{doc}_-1" - top = (retrieved_keys or [])[: int(k)] - return (specific_page in top) or (entire_document in top) - - -def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: - try: - res = json.loads(hit.get("metadata", "{}")) - source = json.loads(hit.get("source", "{}")) - except Exception: - return None, None - - source_id = source.get("source_id") - page_number = res.get("page_number") - if not source_id or page_number is None: - return None, float(hit.get("_distance")) if "_distance" in hit else None - - key = f"{Path(str(source_id)).stem}_{page_number}" - dist = float(hit.get("_distance")) if "_distance" in hit else None - return key, dist - - @app.command() def main( input_path: Path = typer.Argument( @@ -236,14 +208,14 @@ def main( _raw_hits, ) ): - doc, page = _gold_to_doc_page(g) + doc, page = gold_to_doc_page(g) scored_hits: list[tuple[str, float | None]] = [] for h in hits: - key, dist = _hit_key_and_distance(h) + key, dist = hit_key_and_distance(h) if key: scored_hits.append((key, dist)) top_keys = [k for (k, _d) in scored_hits] - hit = _is_hit_at_k(g, top_keys, cfg.top_k) + hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") if not no_recall_details: ext = ".txt" if input_type == "txt" else (".docx" if input_type == "doc" else ".pdf") typer.echo(f"\nQuery {i}: {q}") diff --git a/nemo_retriever/src/nemo_retriever/recall/core.py b/nemo_retriever/src/nemo_retriever/recall/core.py index b684a3f6d..95c540850 100644 --- a/nemo_retriever/src/nemo_retriever/recall/core.py +++ b/nemo_retriever/src/nemo_retriever/recall/core.py @@ -299,6 +299,37 @@ def is_hit_at_k(golden_key: str, retrieved: Sequence[str], k: int, *, match_mode return _is_hit(str(golden_key), list(retrieved), int(k), match_mode=str(match_mode)) +def gold_to_doc_page(golden_key: str) -> tuple[str, str]: + """Split a golden key like ``"docname_page"`` into ``(doc, page)``.""" + s = str(golden_key) + if "_" not in s: + return s, "" + doc, page = s.rsplit("_", 1) + return doc, page + + +def hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: + """Extract ``(pdf_page key, distance)`` from a single LanceDB hit dict. + + Supports both ``_distance`` and ``_score`` fields for compatibility across + LanceDB query types (vector vs hybrid). + """ + try: + res = json.loads(hit.get("metadata", "{}")) + source = json.loads(hit.get("source", "{}")) + except Exception: + return None, None + + source_id = source.get("source_id") + page_number = res.get("page_number") + if not source_id or page_number is None: + return None, float(hit.get("_distance")) if "_distance" in hit else None + + key = f"{Path(str(source_id)).stem}_{page_number}" + dist = float(hit["_distance"]) if "_distance" in hit else float(hit["_score"]) if "_score" in hit else None + return key, dist + + def _recall_at_k(gold: List[str], retrieved: List[List[str]], k: int, *, match_mode: str) -> float: hits = sum(is_hit_at_k(g, r, k, match_mode=match_mode) for g, r in zip(gold, retrieved)) return hits / max(1, len(gold)) From ad43c7146f7f8e27dcd88b78e2138bfeb4893c50 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 6 Mar 2026 20:32:10 -0500 Subject: [PATCH 6/6] linter fixes --- nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index 137f31bc5..7846896ec 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -257,8 +257,6 @@ def _write_detection_summary(path: Path, summary: Optional[dict]) -> None: target.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") - - def _ensure_lancedb_table(uri: str, table_name: str) -> None: """Ensure the local LanceDB URI exists and table can be opened.