Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions nemo_retriever/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,24 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv python install 3.12 \
&& uv venv --python 3.12 /opt/retriever_runtime

ENV VIRTUAL_ENV=/opt/retriever_runtime
ENV PATH=/opt/retriever_runtime/bin:/root/.local/bin:$PATH
ENV LD_LIBRARY_PATH=/opt/retriever_runtime/lib:${LD_LIBRARY_PATH}
RUN --mount=type=cache,target=/root/.cache/uv \
. /opt/retriever_runtime/bin/activate \
&& wget -qO- https://bootstrap.pypa.io/get-pip.py | python -

RUN --mount=type=cache,target=/root/.cache/uv \
. /opt/retriever_runtime/bin/activate \
&& pip install --no-cache-dir openai

RUN --mount=type=cache,target=/root/.cache/uv \
. /opt/retriever_runtime/bin/activate \
&& wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb \
&& dpkg -i cuda-keyring_1.1-1_all.deb \
&& apt update && apt-get --fix-broken install -y && apt-get -y install cuda-toolkit-13-0


# ENV VIRTUAL_ENV=/opt/retriever_runtime
# ENV PATH=/opt/retriever_runtime/bin:/root/.local/bin:$PATH
# ENV LD_LIBRARY_PATH=/opt/retriever_runtime/lib:${LD_LIBRARY_PATH}

# ---------------------------------------------------------------------------
# Install nemo_retriever and path deps (build context = repo root)
Expand All @@ -84,18 +99,17 @@ WORKDIR /workspace
# Unbuffered stdout/stderr so CLI output appears when run without a TTY (e.g. docker run without -it)
ENV PYTHONUNBUFFERED=1

COPY nemo_retriever nemo_retriever
COPY src src
COPY api api
COPY client client
# COPY nemo_retriever nemo_retriever
# COPY src src
# COPY api api
# COPY client client

# Use base stage's venv at /opt/retriever_runtime; install nemo_retriever in editable mode (path deps: ../src, ../api, ../client)
SHELL ["/bin/bash", "-c"]
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=cache,target=/root/.cache/uv \
. /opt/retriever_runtime/bin/activate \
&& uv pip install -e ./nemo_retriever
# SHELL ["/bin/bash", "-c"]
# RUN --mount=type=cache,target=/root/.cache/pip \
# --mount=type=cache,target=/root/.cache/uv \
# . /opt/retriever_runtime/bin/activate \
# && uv pip install -e ./nemo_retriever

# Default: run in-process pipeline (help if no args)
ENTRYPOINT ["/opt/retriever_runtime/bin/python", "-m", "nemo_retriever.examples.inprocess_pipeline"]
CMD ["--help"]
CMD ["/bin/bash"]
103 changes: 52 additions & 51 deletions nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from nemo_retriever.params import IngestExecuteParams
from nemo_retriever.params import IngestorCreateParams
from nemo_retriever.params import TextChunkParams
from nemo_retriever.params import VdbUploadParams
from nemo_retriever.recall.core import RecallConfig, retrieve_and_score
from nemo_retriever.utils.input_files import resolve_input_patterns

Expand Down Expand Up @@ -192,8 +191,6 @@ def _ensure_lancedb_table(uri: str, table_name: str) -> None:

Creates an empty table with the expected schema if it does not exist yet.
"""
from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema

Path(uri).mkdir(parents=True, exist_ok=True)

db = _lancedb().connect(uri)
Expand All @@ -205,11 +202,56 @@ def _ensure_lancedb_table(uri: str, table_name: str) -> None:

import pyarrow as pa # type: ignore

schema = lancedb_schema(2048)
empty = pa.table({f.name: [] for f in schema}, schema=schema)
schema = pa.schema(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is something different about this than what is in the new central lancedb_schema function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We reference LanceDB in a ton of places and need to consolidate. I will change this so that the central `lancedb_schema` call is used in all of those functions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, agreed. Some of the functions that had been centralized seem to have been re-introduced. Let's make some time to tidy this up soon.

[
pa.field("vector", pa.list_(pa.float32(), 2048)),
pa.field("pdf_page", pa.string()),
pa.field("filename", pa.string()),
pa.field("pdf_basename", pa.string()),
pa.field("page_number", pa.int32()),
pa.field("source", pa.string()),
pa.field("path", pa.string()),
pa.field("text", pa.string()),
pa.field("metadata", pa.string()),
]
)
empty = pa.table(
{
"vector": [],
"pdf_page": [],
"filename": [],
"pdf_basename": [],
"page_number": [],
"source": [],
"path": [],
"text": [],
"metadata": [],
},
schema=schema,
)
db.create_table(table_name, data=empty, schema=schema, mode="create")


def _gold_to_doc_page(golden_key: str) -> tuple[str, str]:
s = str(golden_key)
if "_" not in s:
return s, ""
doc, page = s.rsplit("_", 1)
return doc, page


def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]:

source_id = hit.get("source_id")
page_number = hit.get("page_number")
if not source_id or page_number is None:
return None, float(hit.get("_distance")) if "_distance" in hit else None

key = f"{Path(str(source_id)).stem}_{page_number}"
dist = float(hit["_distance"]) if "_distance" in hit else float(hit["_score"]) if "_score" in hit else None
return key, dist


@app.command()
def main(
ctx: typer.Context,
Expand Down Expand Up @@ -541,17 +583,6 @@ def main(
embed_granularity=embed_granularity,
)
)
.vdb_upload(
VdbUploadParams(
lancedb={
"lancedb_uri": lancedb_uri,
"table_name": LANCEDB_TABLE,
"overwrite": True,
"create_index": True,
"hybrid": hybrid,
}
)
)
)
elif input_type == "html":
ingestor = (
Expand All @@ -567,17 +598,6 @@ def main(
embed_granularity=embed_granularity,
)
)
.vdb_upload(
VdbUploadParams(
lancedb={
"lancedb_uri": lancedb_uri,
"table_name": LANCEDB_TABLE,
"overwrite": True,
"create_index": True,
"hybrid": hybrid,
}
)
)
)
elif input_type == "doc":
ingestor = (
Expand Down Expand Up @@ -632,17 +652,6 @@ def main(
},
)
)
.vdb_upload(
VdbUploadParams(
lancedb={
"lancedb_uri": lancedb_uri,
"table_name": LANCEDB_TABLE,
"overwrite": True,
"create_index": True,
"hybrid": hybrid,
}
)
)
)
else:
ingestor = (
Expand Down Expand Up @@ -698,17 +707,6 @@ def main(
},
)
)
.vdb_upload(
VdbUploadParams(
lancedb={
"lancedb_uri": lancedb_uri,
"table_name": LANCEDB_TABLE,
"overwrite": True,
"create_index": True,
"hybrid": hybrid,
}
)
)
)

logger.info("Running extraction...")
Expand All @@ -725,8 +723,11 @@ def main(
.materialize()
)

if hasattr(ingestor, "_create_lancedb_index"):
ingestor._create_lancedb_index()
# if hasattr(ingestor, "_create_lancedb_index"):
# ingestor._create_lancedb_index()
from nemo_retriever.vector_store.lancedb_store import handle_lancedb

handle_lancedb(ingest_results.take_all(), lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite")

ingest_elapsed_s = time.perf_counter() - ingest_start
rows_processed = _count_materialized_rows(ingest_results)
Expand Down
17 changes: 9 additions & 8 deletions nemo_retriever/src/nemo_retriever/recall/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from __future__ import annotations

import json
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
import json

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -226,7 +226,7 @@ def _search_lancedb(
.text(text)
.nprobes(effective_nprobes)
.refine_factor(refine_factor)
.select(["text", "metadata", "source"])
.select(["text", "metadata", "source", "page_number"])
.limit(top_k)
.rerank(RRFReranker())
.to_list()
Expand All @@ -236,7 +236,7 @@ def _search_lancedb(
table.search(q, vector_column_name=vector_column_name)
.nprobes(effective_nprobes)
.refine_factor(refine_factor)
.select(["text", "metadata", "source", "_distance"])
.select(["text", "metadata", "source", "page_number", "_distance"])
.limit(top_k)
.to_list()
)
Expand All @@ -250,12 +250,13 @@ def _hits_to_keys(raw_hits: List[List[Dict[str, Any]]]) -> List[List[str]]:
for hits in raw_hits:
keys: List[str] = []
for h in hits:
res = json.loads(h["metadata"])
source = json.loads(h["source"])
page_number = h["page_number"]
source = h["source"]
# Prefer explicit `pdf_page` column; fall back to derived form.
if res.get("page_number") is not None and source.get("source_id"):
filename = Path(source["source_id"]).stem
keys.append(filename + "_" + str(res["page_number"]))
# if res.get("page_number") is not None and source.get("source_id"):
if page_number is not None and source:
filename = Path(source).stem
keys.append(f"{filename}_{str(page_number)}")
else:
logger.warning(
"Skipping hit with missing page_number or source_id: metadata=%s source=%s",
Expand Down
51 changes: 44 additions & 7 deletions nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple # noqa: F401
from datetime import timedelta

from nv_ingest_client.util.vdb.lancedb import LanceDB
import pandas as pd
import lancedb

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -117,7 +119,7 @@ def _extract_page_number(meta: Dict[str, Any]) -> int:
return -1


def _build_lancedb_rows_from_df(df: pd.DataFrame) -> List[Dict[str, Any]]:
def _build_lancedb_rows_from_df(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Transform an embeddings-enriched primitives DataFrame into LanceDB rows.

Expand All @@ -131,7 +133,7 @@ def _build_lancedb_rows_from_df(df: pd.DataFrame) -> List[Dict[str, Any]]:
"""
out: List[Dict[str, Any]] = []

for _, row in df.iterrows():
for row in rows:
meta = row.get("metadata")
if not isinstance(meta, dict):
continue
Expand All @@ -146,9 +148,12 @@ def _build_lancedb_rows_from_df(df: pd.DataFrame) -> List[Dict[str, Any]]:
embedding = list(embedding) # type: ignore[arg-type]
except Exception:
continue

path, source_id = _extract_source_path_and_id(meta)
page_number = _extract_page_number(meta)
meta.pop("embedding", None) # Remove embedding from metadata to save space in LanceDB.
# path, source_id = _extract_source_path_and_id(meta)
path = row.get("path", "")
source_id = meta.get("source_path", path)
# page_number = _extract_page_number(meta)
page_number = row.get("page_number", -1)
p = Path(path) if path else None
filename = p.name if p is not None else ""
pdf_basename = p.stem if p is not None else ""
Expand All @@ -164,8 +169,10 @@ def _build_lancedb_rows_from_df(df: pd.DataFrame) -> List[Dict[str, Any]]:
"filename": filename,
"pdf_basename": pdf_basename,
"page_number": int(page_number),
"source_id": source_id,
"source": source_id,
"path": path,
"text": row.get("text", ""),
"metadata": str(meta),
}
)

Expand Down Expand Up @@ -203,6 +210,9 @@ def create_lancedb_index(table: Any, *, cfg: LanceDBConfig, text_column: str = "
exc_info=True,
)

for index_stub in table.list_indices():
table.wait_for_index([index_stub.name], timeout=timedelta(seconds=600))


def _write_rows_to_lancedb(rows: Sequence[Dict[str, Any]], *, cfg: LanceDBConfig) -> None:
if not rows:
Expand Down Expand Up @@ -231,8 +241,10 @@ def _write_rows_to_lancedb(rows: Sequence[Dict[str, Any]], *, cfg: LanceDBConfig
pa.field("filename", pa.string()),
pa.field("pdf_basename", pa.string()),
pa.field("page_number", pa.int32()),
pa.field("source_id", pa.string()),
pa.field("source", pa.string()),
pa.field("path", pa.string()),
pa.field("text", pa.string()),
pa.field("metadata", pa.string()),
]
)

Expand Down Expand Up @@ -323,3 +335,28 @@ def write_text_embeddings_dir_to_lancedb(
# "rows_written": len(all_rows),
"lancedb": {"uri": cfg.uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite},
}


def handle_lancedb(
    rows: List[Dict[str, Any]],
    uri: str,
    table_name: str,
    hybrid: bool = False,
    mode: str = "overwrite",
) -> Dict[str, Any]:
    """
    Write in-memory pipeline rows to LanceDB and build the table's indices.

    The rows are converted with `_build_lancedb_rows_from_df`, written with
    `_write_rows_to_lancedb`, and the resulting table is then reopened and
    passed to `create_lancedb_index`.

    Args:
        rows: Row dicts produced by the pipeline (passed unchanged to
            `_build_lancedb_rows_from_df`).
        uri: LanceDB URI to connect to.
        table_name: Name of the table to write and index.
        hybrid: Forwarded to `LanceDBConfig`.
        mode: Accepted for caller compatibility; this function does not
            currently consult it.
            NOTE(review): overwrite behaviour presumably comes from the
            `LanceDBConfig` defaults — confirm.

    Returns:
        A summary dict with the number of rows handed to the writer and the
        LanceDB target that was used.
    """
    # Use the same LanceDB config for writing and for index creation.
    lancedb_config = LanceDBConfig(uri=uri, table_name=table_name, hybrid=hybrid)
    db = lancedb.connect(uri=lancedb_config.uri)

    cleaned_rows = _build_lancedb_rows_from_df(rows)
    _write_rows_to_lancedb(cleaned_rows, cfg=lancedb_config)

    # Reopen the table after the write so indexing sees the freshly written data.
    table = db.open_table(lancedb_config.table_name)
    create_lancedb_index(table, cfg=lancedb_config)

    return {
        "rows_written": len(cleaned_rows),
        "lancedb": {"uri": uri, "table_name": table_name, "hybrid": hybrid},
    }
Loading