diff --git a/nemo_retriever/Dockerfile b/nemo_retriever/Dockerfile index 30c8dc38d..a3b536abc 100644 --- a/nemo_retriever/Dockerfile +++ b/nemo_retriever/Dockerfile @@ -70,9 +70,24 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv python install 3.12 \ && uv venv --python 3.12 /opt/retriever_runtime -ENV VIRTUAL_ENV=/opt/retriever_runtime -ENV PATH=/opt/retriever_runtime/bin:/root/.local/bin:$PATH -ENV LD_LIBRARY_PATH=/opt/retriever_runtime/lib:${LD_LIBRARY_PATH} +RUN --mount=type=cache,target=/root/.cache/uv \ + . /opt/retriever_runtime/bin/activate \ + && wget -qO- https://bootstrap.pypa.io/get-pip.py | python - + +RUN --mount=type=cache,target=/root/.cache/uv \ + . /opt/retriever_runtime/bin/activate \ + && pip install --no-cache-dir openai + +RUN --mount=type=cache,target=/root/.cache/uv \ + . /opt/retriever_runtime/bin/activate \ + && wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb \ + && dpkg -i cuda-keyring_1.1-1_all.deb \ + && apt update && apt-get --fix-broken install -y && apt-get -y install cuda-toolkit-13-0 + + +# ENV VIRTUAL_ENV=/opt/retriever_runtime +# ENV PATH=/opt/retriever_runtime/bin:/root/.local/bin:$PATH +# ENV LD_LIBRARY_PATH=/opt/retriever_runtime/lib:${LD_LIBRARY_PATH} # --------------------------------------------------------------------------- # Install nemo_retriever and path deps (build context = repo root) @@ -84,18 +99,17 @@ WORKDIR /workspace # Unbuffered stdout/stderr so CLI output appears when run without a TTY (e.g. 
docker run without -it) ENV PYTHONUNBUFFERED=1 -COPY nemo_retriever nemo_retriever -COPY src src -COPY api api -COPY client client +# COPY nemo_retriever nemo_retriever +# COPY src src +# COPY api api +# COPY client client # Use base stage's venv at /opt/retriever_runtime; install nemo_retriever in editable mode (path deps: ../src, ../api, ../client) -SHELL ["/bin/bash", "-c"] -RUN --mount=type=cache,target=/root/.cache/pip \ - --mount=type=cache,target=/root/.cache/uv \ - . /opt/retriever_runtime/bin/activate \ - && uv pip install -e ./nemo_retriever +# SHELL ["/bin/bash", "-c"] +# RUN --mount=type=cache,target=/root/.cache/pip \ +# --mount=type=cache,target=/root/.cache/uv \ +# . /opt/retriever_runtime/bin/activate \ +# && uv pip install -e ./nemo_retriever # Default: run in-process pipeline (help if no args) -ENTRYPOINT ["/opt/retriever_runtime/bin/python", "-m", "nemo_retriever.examples.inprocess_pipeline"] -CMD ["--help"] +CMD ["/bin/bash"] diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index b7faf73d5..e50e3c636 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -25,7 +25,6 @@ from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import IngestorCreateParams from nemo_retriever.params import TextChunkParams -from nemo_retriever.params import VdbUploadParams from nemo_retriever.recall.core import RecallConfig, retrieve_and_score from nemo_retriever.utils.input_files import resolve_input_patterns @@ -192,8 +191,6 @@ def _ensure_lancedb_table(uri: str, table_name: str) -> None: Creates an empty table with the expected schema if it does not exist yet. 
""" - from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema - Path(uri).mkdir(parents=True, exist_ok=True) db = _lancedb().connect(uri) @@ -205,11 +202,56 @@ def _ensure_lancedb_table(uri: str, table_name: str) -> None: import pyarrow as pa # type: ignore - schema = lancedb_schema(2048) - empty = pa.table({f.name: [] for f in schema}, schema=schema) + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2048)), + pa.field("pdf_page", pa.string()), + pa.field("filename", pa.string()), + pa.field("pdf_basename", pa.string()), + pa.field("page_number", pa.int32()), + pa.field("source", pa.string()), + pa.field("path", pa.string()), + pa.field("text", pa.string()), + pa.field("metadata", pa.string()), + ] + ) + empty = pa.table( + { + "vector": [], + "pdf_page": [], + "filename": [], + "pdf_basename": [], + "page_number": [], + "source": [], + "path": [], + "text": [], + "metadata": [], + }, + schema=schema, + ) db.create_table(table_name, data=empty, schema=schema, mode="create") +def _gold_to_doc_page(golden_key: str) -> tuple[str, str]: + s = str(golden_key) + if "_" not in s: + return s, "" + doc, page = s.rsplit("_", 1) + return doc, page + + +def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]: + + source_id = hit.get("source_id") + page_number = hit.get("page_number") + if not source_id or page_number is None: + return None, float(hit.get("_distance")) if "_distance" in hit else None + + key = f"{Path(str(source_id)).stem}_{page_number}" + dist = float(hit["_distance"]) if "_distance" in hit else float(hit["_score"]) if "_score" in hit else None + return key, dist + + @app.command() def main( ctx: typer.Context, @@ -541,17 +583,6 @@ def main( embed_granularity=embed_granularity, ) ) - .vdb_upload( - VdbUploadParams( - lancedb={ - "lancedb_uri": lancedb_uri, - "table_name": LANCEDB_TABLE, - "overwrite": True, - "create_index": True, - "hybrid": hybrid, - } - ) - ) ) elif input_type == "html": ingestor = ( @@ 
-567,17 +598,6 @@ def main( embed_granularity=embed_granularity, ) ) - .vdb_upload( - VdbUploadParams( - lancedb={ - "lancedb_uri": lancedb_uri, - "table_name": LANCEDB_TABLE, - "overwrite": True, - "create_index": True, - "hybrid": hybrid, - } - ) - ) ) elif input_type == "doc": ingestor = ( @@ -632,17 +652,6 @@ def main( }, ) ) - .vdb_upload( - VdbUploadParams( - lancedb={ - "lancedb_uri": lancedb_uri, - "table_name": LANCEDB_TABLE, - "overwrite": True, - "create_index": True, - "hybrid": hybrid, - } - ) - ) ) else: ingestor = ( @@ -698,17 +707,6 @@ def main( }, ) ) - .vdb_upload( - VdbUploadParams( - lancedb={ - "lancedb_uri": lancedb_uri, - "table_name": LANCEDB_TABLE, - "overwrite": True, - "create_index": True, - "hybrid": hybrid, - } - ) - ) ) logger.info("Running extraction...") @@ -725,8 +723,11 @@ def main( .materialize() ) - if hasattr(ingestor, "_create_lancedb_index"): - ingestor._create_lancedb_index() + # if hasattr(ingestor, "_create_lancedb_index"): + # ingestor._create_lancedb_index() + from nemo_retriever.vector_store.lancedb_store import handle_lancedb + + handle_lancedb(ingest_results.take_all(), lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite") ingest_elapsed_s = time.perf_counter() - ingest_start rows_processed = _count_materialized_rows(ingest_results) diff --git a/nemo_retriever/src/nemo_retriever/recall/core.py b/nemo_retriever/src/nemo_retriever/recall/core.py index 95c540850..d5174b968 100644 --- a/nemo_retriever/src/nemo_retriever/recall/core.py +++ b/nemo_retriever/src/nemo_retriever/recall/core.py @@ -4,12 +4,12 @@ from __future__ import annotations -import json import logging import time from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple +import json logger = logging.getLogger(__name__) @@ -226,7 +226,7 @@ def _search_lancedb( .text(text) .nprobes(effective_nprobes) .refine_factor(refine_factor) - .select(["text", "metadata", "source"]) + 
.select(["text", "metadata", "source", "page_number"]) .limit(top_k) .rerank(RRFReranker()) .to_list() @@ -236,7 +236,7 @@ table.search(q, vector_column_name=vector_column_name) .nprobes(effective_nprobes) .refine_factor(refine_factor) - .select(["text", "metadata", "source", "_distance"]) + .select(["text", "metadata", "source", "page_number", "_distance"]) .limit(top_k) .to_list() ) @@ -250,12 +250,13 @@ def _hits_to_keys(raw_hits: List[List[Dict[str, Any]]]) -> List[List[str]]: for hits in raw_hits: keys: List[str] = [] for h in hits: - res = json.loads(h["metadata"]) - source = json.loads(h["source"]) + page_number = h["page_number"] + source = h["source"] # Prefer explicit `pdf_page` column; fall back to derived form. - if res.get("page_number") is not None and source.get("source_id"): - filename = Path(source["source_id"]).stem - keys.append(filename + "_" + str(res["page_number"])) + # if res.get("page_number") is not None and source.get("source_id"): + if page_number is not None and source: + filename = Path(source).stem + keys.append(f"{filename}_{page_number}") else: logger.warning( "Skipping hit with missing page_number or source_id: metadata=%s source=%s", diff --git a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py index ad392bf15..737e096ba 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py @@ -9,9 +9,11 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple # noqa: F401 +from datetime import timedelta from nv_ingest_client.util.vdb.lancedb import LanceDB import pandas as pd +import lancedb logger = logging.getLogger(__name__) @@ -117,7 +119,7 @@ def _extract_page_number(meta: Dict[str, Any]) -> int: return -1 -def _build_lancedb_rows_from_df(df: pd.DataFrame) ->
List[Dict[str, Any]]: +def _build_lancedb_rows_from_df(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Transform an embeddings-enriched primitives DataFrame into LanceDB rows. @@ -131,7 +133,7 @@ def _build_lancedb_rows_from_df(df: pd.DataFrame) -> List[Dict[str, Any]]: """ out: List[Dict[str, Any]] = [] - for _, row in df.iterrows(): + for row in rows: meta = row.get("metadata") if not isinstance(meta, dict): continue @@ -146,9 +148,12 @@ def _build_lancedb_rows_from_df(df: pd.DataFrame) -> List[Dict[str, Any]]: embedding = list(embedding) # type: ignore[arg-type] except Exception: continue - - path, source_id = _extract_source_path_and_id(meta) - page_number = _extract_page_number(meta) + meta.pop("embedding", None) # Remove embedding from metadata to save space in LanceDB. + # path, source_id = _extract_source_path_and_id(meta) + path = row.get("path", "") + source_id = meta.get("source_path", path) + # page_number = _extract_page_number(meta) + page_number = row.get("page_number", -1) p = Path(path) if path else None filename = p.name if p is not None else "" pdf_basename = p.stem if p is not None else "" @@ -164,8 +169,10 @@ def _build_lancedb_rows_from_df(df: pd.DataFrame) -> List[Dict[str, Any]]: "filename": filename, "pdf_basename": pdf_basename, "page_number": int(page_number), - "source_id": source_id, + "source": source_id, "path": path, + "text": row.get("text", ""), + "metadata": str(meta), } ) @@ -203,6 +210,9 @@ def create_lancedb_index(table: Any, *, cfg: LanceDBConfig, text_column: str = " exc_info=True, ) + for index_stub in table.list_indices(): + table.wait_for_index([index_stub.name], timeout=timedelta(seconds=600)) + def _write_rows_to_lancedb(rows: Sequence[Dict[str, Any]], *, cfg: LanceDBConfig) -> None: if not rows: @@ -231,8 +241,10 @@ def _write_rows_to_lancedb(rows: Sequence[Dict[str, Any]], *, cfg: LanceDBConfig pa.field("filename", pa.string()), pa.field("pdf_basename", pa.string()), pa.field("page_number", pa.int32()), - 
pa.field("source_id", pa.string()), + pa.field("source", pa.string()), pa.field("path", pa.string()), + pa.field("text", pa.string()), + pa.field("metadata", pa.string()), ] ) @@ -323,3 +335,28 @@ def write_text_embeddings_dir_to_lancedb( # "rows_written": len(all_rows), "lancedb": {"uri": cfg.uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, } + + +def handle_lancedb( + rows: List[Dict[str, Any]], + uri: str, + table_name: str, + hybrid: bool = False, + mode: str = "overwrite", +) -> None: + """ + Handle LanceDB writing for a batch pipeline run. + + This is used by `nemo_retriever.examples.batch_pipeline.run(...)` after the embedding stage. + + Builds LanceDB rows from the materialized ingest `rows`, writes them to the + `table_name` table at `uri`, then creates the LanceDB indices on it. + """ + lancedb_config = LanceDBConfig( + uri=uri, table_name=table_name, hybrid=hybrid + ) # Use the same LanceDB config for writing and recall. + db = lancedb.connect(uri=lancedb_config.uri) + cleaned_rows = _build_lancedb_rows_from_df(rows) + _write_rows_to_lancedb(cleaned_rows, cfg=lancedb_config) + table = db.open_table(lancedb_config.table_name) # Ensure table is open and metadata is updated before proceeding. + create_lancedb_index(table, cfg=lancedb_config)