diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index 0c40a9e10..1d98ffb29 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -71,12 +71,7 @@ def _debug_log(*, logger: logging.Logger, location: str, message: str, data: dic logger.debug("%s | %s | %r", location, message, data) -def _coerce_params[T](params: T | None, model_cls: type[T], kwargs: dict[str, Any]) -> T: - if params is None: - return model_cls(**kwargs) - if kwargs: - return params.model_copy(update=kwargs) # type: ignore[return-value] - return params +from nemo_retriever.params.utils import coerce_params as _coerce_params class _LanceDBWriteActor: @@ -712,17 +707,10 @@ def embed( "No Ray Dataset to embed. Provide input_dataset or run .files(...) / .extract(...) first." ) - resolved = _coerce_params(params, EmbedParams, kwargs) - kwargs = { - **resolved.model_dump( - mode="python", exclude={"runtime", "batch_tuning", "fused_tuning"}, exclude_none=True - ), - **resolved.runtime.model_dump(mode="python", exclude_none=True), - **resolved.batch_tuning.model_dump(mode="python", exclude_none=True), - } + from nemo_retriever.params.utils import build_embed_kwargs - if "embedding_endpoint" not in kwargs and kwargs.get("embed_invoke_url"): - kwargs["embedding_endpoint"] = kwargs.get("embed_invoke_url") + resolved = _coerce_params(params, EmbedParams, kwargs) + kwargs = build_embed_kwargs(resolved, include_batch_tuning=True) # Remaining kwargs are forwarded to the actor constructor. 
embed_modality = resolved.embed_modality diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index 45d6373e7..dd2854ae7 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -64,12 +64,7 @@ _CONTENT_COLUMNS = ("table", "chart", "infographic") -def _coerce_params[T](params: T | None, model_cls: type[T], kwargs: dict[str, Any]) -> T: - if params is None: - return model_cls(**kwargs) - if kwargs: - return params.model_copy(update=kwargs) # type: ignore[return-value] - return params +from nemo_retriever.params.utils import coerce_params as _coerce_params def _combine_text_with_content(row, text_column, content_columns): @@ -1286,14 +1281,9 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessI ) ) - embed_kwargs = { - **resolved.model_dump( - mode="python", exclude={"runtime", "batch_tuning", "fused_tuning"}, exclude_none=True - ), - **resolved.runtime.model_dump(mode="python", exclude_none=True), - } - if "embedding_endpoint" not in embed_kwargs and embed_kwargs.get("embed_invoke_url"): - embed_kwargs["embedding_endpoint"] = embed_kwargs.get("embed_invoke_url") + from nemo_retriever.params.utils import build_embed_kwargs + + embed_kwargs = build_embed_kwargs(resolved) # Ensure embed_modality is forwarded to the embedding function. embed_kwargs["embed_modality"] = embed_modality diff --git a/nemo_retriever/src/nemo_retriever/params/utils.py b/nemo_retriever/src/nemo_retriever/params/utils.py new file mode 100644 index 000000000..8b076e169 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/params/utils.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Shared parameter coercion and building helpers used by ingest modes.""" + +from __future__ import annotations + +from typing import Any, Dict + + +def coerce_params[T](params: T | None, model_cls: type[T], kwargs: dict[str, Any]) -> T: + """Merge *params* and *kwargs* into an instance of *model_cls*. + + - If *params* is ``None``, construct from *kwargs*. + - If *kwargs* is non-empty, apply them as overrides via ``model_copy``. + - Otherwise return *params* unchanged. + """ + if params is None: + return model_cls(**kwargs) + if kwargs: + return params.model_copy(update=kwargs) # type: ignore[return-value] + return params + + +def build_embed_kwargs(resolved: Any, *, include_batch_tuning: bool = False) -> Dict[str, Any]: + """Flatten an ``EmbedParams`` instance into a dict ready for actor/task kwargs. + + Merges ``runtime`` (always) and optionally ``batch_tuning`` sub-models. + Also normalises ``embed_invoke_url`` → ``embedding_endpoint``. + """ + exclude = {"runtime", "batch_tuning", "fused_tuning"} + kwargs: Dict[str, Any] = { + **resolved.model_dump(mode="python", exclude=exclude, exclude_none=True), + **resolved.runtime.model_dump(mode="python", exclude_none=True), + } + if include_batch_tuning: + kwargs.update(resolved.batch_tuning.model_dump(mode="python", exclude_none=True)) + + if "embedding_endpoint" not in kwargs and kwargs.get("embed_invoke_url"): + kwargs["embedding_endpoint"] = kwargs["embed_invoke_url"] + + return kwargs diff --git a/nemo_retriever/tests/test_params_utils.py b/nemo_retriever/tests/test_params_utils.py new file mode 100644 index 000000000..28e4c80a5 --- /dev/null +++ b/nemo_retriever/tests/test_params_utils.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Unit tests for nemo_retriever.params.utils."""

from nemo_retriever.params.models import EmbedParams
from nemo_retriever.params.utils import build_embed_kwargs, coerce_params


class TestCoerceParams:
    def test_none_params_constructs_from_kwargs(self):
        result = coerce_params(None, EmbedParams, {"embed_modality": "image"})
        assert isinstance(result, EmbedParams)
        assert result.embed_modality == "image"

    def test_params_without_kwargs_returned_unchanged(self):
        original = EmbedParams(embed_modality="text")
        result = coerce_params(original, EmbedParams, {})
        # Identity, not just equality: no needless copy is made.
        assert result is original

    def test_params_with_kwargs_applies_overrides(self):
        original = EmbedParams(embed_modality="text")
        result = coerce_params(original, EmbedParams, {"embed_modality": "image"})
        assert result.embed_modality == "image"
        assert result is not original
        # model_copy must leave the caller's instance untouched.
        assert original.embed_modality == "text"


class TestBuildEmbedKwargs:
    def test_normalises_embed_invoke_url(self):
        params = EmbedParams(embed_invoke_url="http://nim:8000/v1")
        kwargs = build_embed_kwargs(params)
        # The alias must carry the exact URL, not merely be present.
        assert kwargs["embedding_endpoint"] == "http://nim:8000/v1"

    def test_alias_preserves_original_invoke_url(self):
        params = EmbedParams(embed_invoke_url="http://old:8000/v1")
        kwargs = build_embed_kwargs(params)
        assert kwargs["embedding_endpoint"] == "http://old:8000/v1"
        # The source key is kept alongside the alias.
        assert kwargs["embed_invoke_url"] == "http://old:8000/v1"

    def test_includes_batch_tuning_when_requested(self):
        params = EmbedParams()
        with_bt = build_embed_kwargs(params, include_batch_tuning=True)
        without_bt = build_embed_kwargs(params, include_batch_tuning=False)
        # batch_tuning only ever adds keys on top of the base dict.
        assert set(without_bt).issubset(with_bt)

    def test_excludes_nested_sub_models(self):
        params = EmbedParams()
        kwargs = build_embed_kwargs(params)
        assert "runtime" not in kwargs
        assert "batch_tuning" not in kwargs
        assert "fused_tuning" not in kwargs