Skip to content
Merged
20 changes: 4 additions & 16 deletions nemo_retriever/src/nemo_retriever/ingest_modes/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,7 @@ def _debug_log(*, logger: logging.Logger, location: str, message: str, data: dic
logger.debug("%s | %s | %r", location, message, data)


def _coerce_params[T](params: T | None, model_cls: type[T], kwargs: dict[str, Any]) -> T:
if params is None:
return model_cls(**kwargs)
if kwargs:
return params.model_copy(update=kwargs) # type: ignore[return-value]
return params
from nemo_retriever.params.utils import coerce_params as _coerce_params


class _LanceDBWriteActor:
Expand Down Expand Up @@ -712,17 +707,10 @@ def embed(
"No Ray Dataset to embed. Provide input_dataset or run .files(...) / .extract(...) first."
)

resolved = _coerce_params(params, EmbedParams, kwargs)
kwargs = {
**resolved.model_dump(
mode="python", exclude={"runtime", "batch_tuning", "fused_tuning"}, exclude_none=True
),
**resolved.runtime.model_dump(mode="python", exclude_none=True),
**resolved.batch_tuning.model_dump(mode="python", exclude_none=True),
}
from nemo_retriever.params.utils import build_embed_kwargs

if "embedding_endpoint" not in kwargs and kwargs.get("embed_invoke_url"):
kwargs["embedding_endpoint"] = kwargs.get("embed_invoke_url")
resolved = _coerce_params(params, EmbedParams, kwargs)
kwargs = build_embed_kwargs(resolved, include_batch_tuning=True)

# Remaining kwargs are forwarded to the actor constructor.
embed_modality = resolved.embed_modality
Expand Down
18 changes: 4 additions & 14 deletions nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,7 @@
_CONTENT_COLUMNS = ("table", "chart", "infographic")


def _coerce_params[T](params: T | None, model_cls: type[T], kwargs: dict[str, Any]) -> T:
if params is None:
return model_cls(**kwargs)
if kwargs:
return params.model_copy(update=kwargs) # type: ignore[return-value]
return params
from nemo_retriever.params.utils import coerce_params as _coerce_params


def _combine_text_with_content(row, text_column, content_columns):
Expand Down Expand Up @@ -1286,14 +1281,9 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessI
)
)

embed_kwargs = {
**resolved.model_dump(
mode="python", exclude={"runtime", "batch_tuning", "fused_tuning"}, exclude_none=True
),
**resolved.runtime.model_dump(mode="python", exclude_none=True),
}
if "embedding_endpoint" not in embed_kwargs and embed_kwargs.get("embed_invoke_url"):
embed_kwargs["embedding_endpoint"] = embed_kwargs.get("embed_invoke_url")
from nemo_retriever.params.utils import build_embed_kwargs

embed_kwargs = build_embed_kwargs(resolved)

# Ensure embed_modality is forwarded to the embedding function.
embed_kwargs["embed_modality"] = embed_modality
Expand Down
43 changes: 43 additions & 0 deletions nemo_retriever/src/nemo_retriever/params/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Shared parameter coercion and building helpers used by ingest modes."""

from __future__ import annotations

from typing import Any, Dict


def coerce_params[T](params: T | None, model_cls: type[T], kwargs: dict[str, Any]) -> T:
"""Merge *params* and *kwargs* into an instance of *model_cls*.

- If *params* is ``None``, construct from *kwargs*.
- If *kwargs* is non-empty, apply them as overrides via ``model_copy``.
- Otherwise return *params* unchanged.
"""
if params is None:
return model_cls(**kwargs)
if kwargs:
return params.model_copy(update=kwargs) # type: ignore[return-value]
return params


def build_embed_kwargs(resolved: Any, *, include_batch_tuning: bool = False) -> Dict[str, Any]:
"""Flatten an ``EmbedParams`` instance into a dict ready for actor/task kwargs.

Merges ``runtime`` (always) and optionally ``batch_tuning`` sub-models.
Also normalises ``embed_invoke_url`` → ``embedding_endpoint``.
"""
exclude = {"runtime", "batch_tuning", "fused_tuning"}
kwargs: Dict[str, Any] = {
**resolved.model_dump(mode="python", exclude=exclude, exclude_none=True),
**resolved.runtime.model_dump(mode="python", exclude_none=True),
}
if include_batch_tuning:
kwargs.update(resolved.batch_tuning.model_dump(mode="python", exclude_none=True))

if "embedding_endpoint" not in kwargs and kwargs.get("embed_invoke_url"):
kwargs["embedding_endpoint"] = kwargs["embed_invoke_url"]

return kwargs
55 changes: 55 additions & 0 deletions nemo_retriever/tests/test_params_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Unit tests for nemo_retriever.params.utils."""

from nemo_retriever.params.models import EmbedParams
from nemo_retriever.params.utils import build_embed_kwargs, coerce_params


class TestCoerceParams:
def test_none_params_constructs_from_kwargs(self):
result = coerce_params(None, EmbedParams, {"embed_modality": "image"})
assert isinstance(result, EmbedParams)
assert result.embed_modality == "image"

def test_params_without_kwargs_returned_unchanged(self):
original = EmbedParams(embed_modality="text")
result = coerce_params(original, EmbedParams, {})
assert result is original

def test_params_with_kwargs_applies_overrides(self):
original = EmbedParams(embed_modality="text")
result = coerce_params(original, EmbedParams, {"embed_modality": "image"})
assert result.embed_modality == "image"
assert result is not original


class TestBuildEmbedKwargs:
def test_normalises_embed_invoke_url(self):
params = EmbedParams(embed_invoke_url="http://nim:8000/v1")
kwargs = build_embed_kwargs(params)
assert kwargs["embedding_endpoint"] == "http://nim:8000/v1"

def test_does_not_overwrite_existing_embedding_endpoint(self):
params = EmbedParams(
embed_invoke_url="http://old:8000/v1",
)
kwargs = build_embed_kwargs(params)
assert "embedding_endpoint" in kwargs

def test_includes_batch_tuning_when_requested(self):
params = EmbedParams()
with_bt = build_embed_kwargs(params, include_batch_tuning=True)
without_bt = build_embed_kwargs(params, include_batch_tuning=False)
# batch_tuning keys should be present when included
assert isinstance(with_bt, dict)
assert isinstance(without_bt, dict)

def test_excludes_nested_sub_models(self):
params = EmbedParams()
kwargs = build_embed_kwargs(params)
assert "runtime" not in kwargs
assert "batch_tuning" not in kwargs
assert "fused_tuning" not in kwargs
Loading