diff --git a/nemo_deploy/llm/inference/inference_base.py b/nemo_deploy/llm/inference/inference_base.py index e1b34aa44..cd207965d 100644 --- a/nemo_deploy/llm/inference/inference_base.py +++ b/nemo_deploy/llm/inference/inference_base.py @@ -39,8 +39,6 @@ from megatron.core.transformer.module import MegatronModule from packaging import version -from nemo_export_deploy_common.import_utils import MISSING_NEMO_MSG, UnavailableError - from .tron_utils import ( DistributedInitConfig, RNGConfig, @@ -63,7 +61,6 @@ HAVE_TRITON = False from .nemo_utils import ( - HAVE_NEMO, MCoreTokenizerWrappper, ckpt_to_context_subdir, ckpt_to_weights_subdir, @@ -186,8 +183,6 @@ def load_nemo_checkpoint_to_tron_model(model: List[MegatronModule], path: Path, path (Path): Path to NeMo checkpoint directory legacy_ckpt (bool): Whether to use legacy checkpoint format """ - if not HAVE_NEMO: - raise UnavailableError(MISSING_NEMO_MSG) weights_dir = ckpt_to_weights_subdir(path, is_saving=False) LOGGER.info(f"Loading NeMo checkpoint from {weights_dir}") @@ -309,9 +304,6 @@ def setup_model_and_tokenizer_for_inference( Raises: ValueError: If checkpoint_path is not a valid NeMo-2.0 checkpoint """ - if not HAVE_NEMO: - raise UnavailableError(MISSING_NEMO_MSG) - checkpoint_path = Path(checkpoint_path) # Load model context for config and tokenizer @@ -478,9 +470,6 @@ def create_mcore_engine( - GPTInferenceWrapper: Inference-wrapped model - Union[MCoreTokenizerWrappper, MegatronTokenizer]: Tokenizer instance """ - if not HAVE_NEMO and model_format == "nemo": - raise UnavailableError(MISSING_NEMO_MSG) - # Default to 1 for any parallelism dimension that's None tensor_model_parallel_size = tensor_model_parallel_size if tensor_model_parallel_size is not None else 1 pipeline_model_parallel_size = pipeline_model_parallel_size if pipeline_model_parallel_size is not None else 1 diff --git a/nemo_deploy/llm/inference/nemo_io.py b/nemo_deploy/llm/inference/nemo_io.py new file mode 100644 index 000000000..7d7c3653f 
--- /dev/null +++ b/nemo_deploy/llm/inference/nemo_io.py @@ -0,0 +1,415 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""IO utilities for loading NeMo 2.0 checkpoints without a direct nemo import. + +Copied from the NeMo project (https://github.com/NVIDIA/NeMo). Static +``from nemo import …`` statements are removed; the logic is otherwise +identical to the upstream sources. When a NeMo checkpoint is actually +loaded at runtime, NeMo classes are imported transitively through +``pydoc.locate`` — NeMo must therefore still be installed to read NeMo +checkpoints. 
+ +Sources + - IOProtocol : nemo/lightning/io/capture.py + - IO helpers, load : nemo/lightning/io/mixin.py + - load_context : nemo/lightning/io/api.py + - Torch-dtype fiddle + registration : nemo/lightning/io/fdl_torch.py +""" + +from __future__ import annotations + +import dataclasses +import functools +import inspect +import json +import logging +import threading +import uuid +from pathlib import Path +from pydoc import locate +from typing import Any, Dict, Generic, List, Optional, Protocol, TypeVar, runtime_checkable + +import fiddle as fdl +import fiddle._src.experimental.dataclasses as fdl_dc +import torch +from cloudpickle import dump +from cloudpickle import load as pickle_load +from fiddle._src.experimental import serialization + +_logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Thread-local storage (mirrors nemo.lightning.io.mixin._thread_local) +# --------------------------------------------------------------------------- + +_thread_local = threading.local() + + +def _set_thread_local_output_dir(path: Path) -> None: + """Set output_dir in our thread-local and in NeMo's (if already imported). + + NeMo classes registered before our first load call will use NeMo's + _io_unflatten_object, which reads from NeMo's own _thread_local. We + mirror the value there so that pickle-based artifacts still resolve + correctly even in that edge case. 
+ """ + _thread_local.output_dir = path + try: + import sys + + nemo_mixin = sys.modules.get("nemo.lightning.io.mixin") + if nemo_mixin is not None and hasattr(nemo_mixin, "_thread_local"): + nemo_mixin._thread_local.output_dir = path + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Register torch dtypes as fiddle constants +# (from nemo.lightning.io.fdl_torch.enable — only register_constant calls +# are needed for deserialization; libcst / codegen parts are omitted) +# --------------------------------------------------------------------------- + +_TORCH_DTYPE_NAMES = [ + "bool", + "uint8", + "int8", + "int16", + "int32", + "int64", + "float16", + "bfloat16", + "float32", + "float64", + "complex64", + "complex128", +] + + +def _register_torch_dtypes() -> None: + for name in _TORCH_DTYPE_NAMES: + if hasattr(torch, name): + serialization.register_constant("torch", name, compare_by_identity=True) + + +_register_torch_dtypes() + +# --------------------------------------------------------------------------- +# IOProtocol (from nemo.lightning.io.capture) +# --------------------------------------------------------------------------- + +SelfT = TypeVar("SelfT", covariant=True) + + +@runtime_checkable +class IOProtocol(Protocol, Generic[SelfT]): + @property + def __io__(self) -> fdl.Config[SelfT]: ... 
# noqa: E704 + + +# --------------------------------------------------------------------------- +# IO helper functions (from nemo.lightning.io.mixin) +# --------------------------------------------------------------------------- + + +def _io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]: + """Capture __init__ arguments as a plain dict for fdl.Config creation.""" + sig = inspect.signature(init_fn) + bound_args = sig.bind_partial(self, *args, **kwargs) + config_kwargs = {k: v for k, v in bound_args.arguments.items() if k != "self"} + + to_del: List[str] = [] + for key in config_kwargs: + if isinstance(config_kwargs[key], IOProtocol): + config_kwargs[key] = config_kwargs[key].__io__ + if dataclasses.is_dataclass(config_kwargs[key]): + config_kwargs[key] = fdl_dc.convert_dataclasses_to_configs(config_kwargs[key], allow_post_init=True) + if config_kwargs[key].__class__.__name__ == "_HAS_DEFAULT_FACTORY_CLASS": + to_del.append(key) + + for key in to_del: + del config_kwargs[key] + + return config_kwargs + + +def _io_init(self, **kwargs) -> fdl.Config: + """Create an fdl.Config for *self* from captured init kwargs.""" + try: + return fdl.Config(type(self), **kwargs) + except Exception as e: + raise RuntimeError( + f"Error creating fdl.Config for {type(self).__name__}: {e}\nArguments that caused the error: {kwargs}" + ) from e + + +def _io_wrap_init(cls): + """Wrap cls.__init__ to populate __io__ on every instance.""" + original_init = cls.__init__ + + if getattr(cls, "__wrapped_init__", False): + return cls + + @functools.wraps(original_init) + def wrapped_init(self, *args, **kwargs): + if hasattr(self, "io_transform_args"): + cfg_kwargs = self.io_transform_args(original_init, *args, **kwargs) + else: + cfg_kwargs = _io_transform_args(self, original_init, *args, **kwargs) + if hasattr(self, "io_init"): + self.__io__ = self.io_init(**cfg_kwargs) + else: + self.__io__ = _io_init(self, **cfg_kwargs) + original_init(self, *args, **kwargs) + + cls.__init__ = 
wrapped_init + cls.__wrapped_init__ = True + return cls + + +def _io_flatten_object(instance): + """Flatten an IOMixin object to a form fiddle can serialize.""" + try: + serialization.dump_json(instance.__io__) + except (serialization.UnserializableValueError, AttributeError) as exc: + if not hasattr(_thread_local, "local_artifacts_dir") or not hasattr(_thread_local, "output_path"): + raise exc + + local_artifact_path = Path(_thread_local.local_artifacts_dir) / f"{uuid.uuid4()}" + output_path = _thread_local.output_path + artifact_path = output_path / local_artifact_path + with open(artifact_path, "wb") as f: + dump(getattr(instance, "__io__", instance), f) + return (str(local_artifact_path),), None + + return instance.__io__.__flatten__() + + +def _io_unflatten_object(values, metadata): + """Unflatten an IOMixin object; load from pickle if it was saved that way.""" + if not hasattr(_thread_local, "output_dir"): + return fdl.Config.__unflatten__(values, metadata) + + output_dir = _thread_local.output_dir + if len(values) == 1: + pickle_path = values[0] + with open(Path(output_dir) / pickle_path, "rb") as f: + return pickle_load(f) + + return fdl.Config.__unflatten__(values, metadata) + + +def _io_path_elements_fn(x): + """Return the path elements for fiddle graph traversal.""" + try: + serialization.dump_json(x.__io__) + except (serialization.UnserializableValueError, AttributeError): + return (serialization.IdentityElement(),) + + return x.__io__.__path_elements__() + + +def _io_register_serialization(cls) -> None: + """Register fiddle traversal functions for *cls* using our _thread_local.""" + serialization.register_node_traverser( + cls, + flatten_fn=_io_flatten_object, + unflatten_fn=_io_unflatten_object, + path_elements_fn=_io_path_elements_fn, + ) + + +def track_io(target, artifacts=None): + """Add fiddle IO functionality to a class or all eligible classes in a module. + + Copied from ``nemo.lightning.io.mixin.track_io``. 
+ """ + import types as _types + + def _add_io_to_class(cls): + if inspect.isclass(cls) and hasattr(cls, "__init__") and not hasattr(cls, "__io__"): + if cls in [str, int, float, tuple, list, dict, bool, type(None)]: + return cls + cls = _io_wrap_init(cls) + _io_register_serialization(cls) + cls.__io_artifacts__ = artifacts or [] + return cls + + def _is_in_module(obj, module): + return obj.__module__ == module.__name__ or obj.__module__.startswith(f"{module.__name__}.") + + def _process_module(module): + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and _is_in_module(obj, module): + setattr(module, name, _add_io_to_class(obj)) + return module + + if isinstance(target, _types.ModuleType): + return _process_module(target) + elif inspect.isclass(target): + return _add_io_to_class(target) + else: + raise TypeError("Target must be a module or a class") + + +def drop_unexpected_params(config: fdl.Config) -> bool: + """Remove deprecated / unexpected parameters from a fiddle Config tree. + + Copied from ``nemo.lightning.io.mixin.drop_unexpected_params``. + """ + updated = False + + def analyze(cfg, prefix: str): + nonlocal updated + if not isinstance(cfg, fdl.Config): + return + signature = inspect.signature(cfg.__fn_or_cls__) + accept_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in signature.parameters.values()) + if not accept_kwargs: + to_drop = [p for p in cfg.__arguments__ if p not in signature.parameters] + if to_drop: + updated = True + _logger.warning("Deprecated parameters to drop from %s: %s", prefix, to_drop) + for p in to_drop: + del cfg.__arguments__[p] + for key, value in cfg.__arguments__.items(): + analyze(value, f"{prefix}.{key}") + + analyze(config, "") + return updated + + +def _artifact_transform_load(cfg: fdl.Config, path: Path) -> None: + """Rewrite artifact paths stored in a config to absolute paths. + + Copied from ``nemo.lightning.io.mixin._artifact_transform_load``. 
+ """ + for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []): + current_val = getattr(cfg, artifact.attr) + if isinstance(current_val, fdl.Config): + setattr(cfg, artifact.attr, fdl.build(current_val).attr) + continue + if artifact.skip: + continue + current_val = getattr(cfg, artifact.attr) + if current_val is None: + continue + new_val = str(Path(path) / current_val) + setattr(cfg, artifact.attr, new_val) + + for attr in dir(cfg): + try: + child = getattr(cfg, attr) + if isinstance(child, fdl.Config): + _artifact_transform_load(child, path=path) + except (ValueError, AttributeError): + pass + + +# --------------------------------------------------------------------------- +# load (from nemo.lightning.io.mixin) +# --------------------------------------------------------------------------- + + +def load( + path: Path, + output_type: Any = None, + subpath: Optional[str] = None, + build: bool = True, +) -> Any: + """Load a fiddle-serialised NeMo checkpoint context from an ``io.json`` file. + + Copied from ``nemo.lightning.io.mixin.load``. + """ + _path = Path(path) + _set_thread_local_output_dir(_path) + + if _path.is_dir(): + _path = _path / "io.json" + + if not _path.is_file(): + raise FileNotFoundError(f"No such file: '{_path}'") + + if subpath: + subpath = "." + subpath + + # Register / re-register fiddle traversal for every class in the JSON. + # We always re-register (not just when missing) so that our _thread_local + # is used by _io_unflatten_object for classes that NeMo may have already + # registered before this call. 
+ with open(_path) as f: + j = json.load(f) + + for obj, val in j.get("objects", {}).items(): + clss = ".".join([val["type"]["module"], val["type"]["name"]]) + if subpath and "paths" in val: + if all(subpath not in p for p in val["paths"]): + continue + cls_obj = locate(clss) + if cls_obj is None: + continue + if not serialization.find_node_traverser(cls_obj): + track_io(cls_obj) + else: + # Re-register with our traversal so our _thread_local is active. + _io_register_serialization(cls_obj) + + with open(_path, "rb") as f: + json_config = json.loads(f.read()) + + root_key = None + for obj, val in json_config.get("objects", {}).items(): + if "paths" in val and subpath in val["paths"]: + root_key = obj + break + + if subpath and not root_key: + _logger.warning("Could not find %s for %s in %s", subpath, output_type, _path) + + if root_key: + json_config["root"]["key"] = root_key + + config = serialization.Deserialization(json_config).result + _artifact_transform_load(config, path) + drop_unexpected_params(config) + + if not build: + return config + + return fdl.build(config) + + +# --------------------------------------------------------------------------- +# load_context (from nemo.lightning.io.api) +# --------------------------------------------------------------------------- + + +def load_context(path: Path, subpath: Optional[str] = None, build: bool = True) -> Any: + """Load a NeMo TrainerContext (or a subpath of it) from a checkpoint directory. + + Copied from ``nemo.lightning.io.api.load_context``. + """ + if not isinstance(path, Path): + path = Path(path) + + try: + return load(path, subpath=subpath, build=build) + except FileNotFoundError: + # Backwards compatibility: checkpoints without a ``/context`` sub-dir. 
+ if path.parts[-1] == "context": + path = path.parent + else: + path = path / "context" + return load(path, subpath=subpath, build=build) diff --git a/nemo_deploy/llm/inference/nemo_utils.py b/nemo_deploy/llm/inference/nemo_utils.py index 8a75e1ad4..de1463554 100644 --- a/nemo_deploy/llm/inference/nemo_utils.py +++ b/nemo_deploy/llm/inference/nemo_utils.py @@ -13,13 +13,11 @@ # limitations under the License. """NeMo utility code copied from the NeMo project. -Standalone utilities (MCoreTokenizerWrappper, checkpoint path helpers) are -copied directly and have no dependency on the nemo package. - -Complex types that are tightly coupled to NeMo's class hierarchy and -serialization system (GPTConfig, T5Config, io, set_modelopt_spec_if_exists_in_ckpt) -are re-exported here from the nemo package so that inference_base.py and -tron_utils.py do not need to import from nemo directly. +All utilities here are copied directly from NeMo and have no static +dependency on the nemo package. When a NeMo checkpoint is loaded at +runtime, NeMo classes are imported transitively through pydoc.locate +inside nemo_io.load_context — NeMo must therefore still be installed +to read NeMo checkpoints. 
Sources: - MCoreTokenizerWrappper : nemo/collections/llm/inference/base.py @@ -28,12 +26,37 @@ ckpt_to_context_subdir : nemo/lightning/ckpt_utils.py - ckpt_to_weights_subdir : nemo/lightning/io/pl.py - constants : nemo/lightning/ckpt_utils.py + - set_modelopt_spec_* : nemo/collections/llm/modelopt/model_utils.py + - load_context, io : nemo_io.py (copied from nemo/lightning/io/) """ import inspect +import logging +import types +from functools import partial from pathlib import Path from typing import Any, Union +from .nemo_io import load_context as _load_context + +_logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# io namespace — exposes load_context under the same attribute name that +# inference_base.py uses (io.load_context(...)). +# --------------------------------------------------------------------------- + +io = types.SimpleNamespace(load_context=_load_context) + +# --------------------------------------------------------------------------- +# GPTConfig / T5Config — type stubs used only for annotations. +# The actual runtime objects are NeMo classes deserialized from the +# checkpoint; isinstance() checks use class-name strings instead. +# --------------------------------------------------------------------------- + +GPTConfig = Any +T5Config = Any + # --------------------------------------------------------------------------- # Constants (from nemo.lightning.ckpt_utils) # --------------------------------------------------------------------------- @@ -163,26 +186,62 @@ def pad(self): # --------------------------------------------------------------------------- -# NeMo complex types +# set_modelopt_spec_if_exists_in_ckpt # -# GPTConfig, T5Config, io, and set_modelopt_spec_if_exists_in_ckpt are -# deeply coupled to NeMo's class hierarchy and serialization system. 
-# Checkpoints saved by NeMo contain instances of these exact classes, so -# they must originate from the nemo package to preserve isinstance() -# compatibility. They are re-exported here so that inference_base.py and -# tron_utils.py do not need to import from nemo directly. +# Copied from nemo/collections/llm/modelopt/model_utils.py. +# NeMo model-type isinstance checks are replaced by class-name checks to +# avoid importing nemo at module level. # --------------------------------------------------------------------------- -try: - from nemo.collections.llm.gpt.model.base import GPTConfig - from nemo.collections.llm.modelopt import set_modelopt_spec_if_exists_in_ckpt - from nemo.collections.llm.t5.model.t5 import T5Config - from nemo.lightning import io - - HAVE_NEMO = True -except (ImportError, ModuleNotFoundError): - GPTConfig = Any - T5Config = Any - io = None - set_modelopt_spec_if_exists_in_ckpt = None - HAVE_NEMO = False + +def set_modelopt_spec_if_exists_in_ckpt(model, path: str) -> None: + """Set model.config.transformer_layer_spec to a modelopt spec if the + checkpoint contains a ``modelopt_state`` directory. + + Copied from ``nemo.collections.llm.modelopt.model_utils.set_modelopt_spec_if_exists_in_ckpt`` + with NeMo isinstance checks replaced by class-name comparisons. + """ + path = str(path).removeprefix("nemo://") + modelopt_state_path = ckpt_to_weights_subdir(path, is_saving=False) / "modelopt_state" + if not modelopt_state_path.exists() or hasattr(model, "module"): + return + + model_type_name = type(model).__name__ + if model_type_name not in ("GPTModel", "MambaModel"): + _logger.warning( + "%s is neither a GPTModel nor MambaModel. 
Modelopt state will not be loaded.", + type(model), + ) + return + + config = model.config + config_type_name = type(config).__name__ + + try: + from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec + + _HAVE_GPT_MODELOPT_SPEC = True + except ImportError: + _HAVE_GPT_MODELOPT_SPEC = False + + if config_type_name == "GPTConfig": + if _HAVE_GPT_MODELOPT_SPEC: + config.transformer_layer_spec = partial( + get_gpt_modelopt_spec, + remap_te_layernorm=True, + local_core_attention=getattr(config, "softmax_type", "vanilla") != "vanilla", + ) + else: + _logger.warning("get_gpt_modelopt_spec not available; skipping modelopt layer spec.") + elif config_type_name == "SSMConfig": + try: + from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec + + config.mamba_stack_spec = partial(get_mamba_stack_modelopt_spec, remap_te_layernorm=True) + except ImportError: + _logger.warning("get_mamba_stack_modelopt_spec not available; skipping modelopt layer spec.") + else: + _logger.warning("No modelopt layer spec supported for config type %s.", type(config)) + return + + config.gradient_accumulation_fusion = False diff --git a/nemo_export/__init__.py b/nemo_export/__init__.py index 6c51dfc86..f69aa4d1c 100644 --- a/nemo_export/__init__.py +++ b/nemo_export/__init__.py @@ -12,21 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# WAR for trtllm and lightning conflict from nemo_export_deploy_common.package_info import __package_name__, __version__ -try: - from nemo.lightning import io - - HAVE_IO = True -except (ImportError, ModuleNotFoundError): - HAVE_IO = False - __all__ = ["__version__", "__package_name__"] -if HAVE_IO: - __all__ += ["io"] - # Optional convenience imports for TensorRT-LLM classes try: from nemo_export.tensorrt_llm import TensorRTLLM diff --git a/tests/unit_tests/deploy/test_etp_sequence_parallel.py b/tests/unit_tests/deploy/test_etp_sequence_parallel.py index 0024a68a0..732f2d067 100644 --- a/tests/unit_tests/deploy/test_etp_sequence_parallel.py +++ b/tests/unit_tests/deploy/test_etp_sequence_parallel.py @@ -283,7 +283,6 @@ class TestSetupModelETPSequenceParallel(unittest.TestCase): def _common_patches(self): return [ - patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True), patch("nemo_deploy.llm.inference.inference_base.set_modelopt_spec_if_exists_in_ckpt"), patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init"), patch("nemo_deploy.llm.inference.inference_base.io", new_callable=MagicMock), @@ -315,7 +314,6 @@ def test_expert_tensor_parallel_size_applied(self): patches = self._common_patches() mocks = [p.start() for p in patches] ( - _have_nemo, _set_modelopt, _torch_dist, mock_io, @@ -352,7 +350,6 @@ def test_sequence_parallel_applied(self): patches = self._common_patches() mocks = [p.start() for p in patches] ( - _have_nemo, _set_modelopt, _torch_dist, mock_io, @@ -391,7 +388,6 @@ def test_sequence_parallel_applied(self): class TestCreateMcoreEngineETPSequenceParallel(unittest.TestCase): """Tests that create_mcore_engine handles ETP/SP defaults and passes them down.""" - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngine") 
@patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup") @@ -413,7 +409,6 @@ def test_etp_defaults_to_1_when_none( _, kwargs = mock_setup.call_args assert kwargs["expert_tensor_parallel_size"] == 1 - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngine") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup") @@ -435,7 +430,6 @@ def test_sp_defaults_to_1_when_none( _, kwargs = mock_setup.call_args assert kwargs["sequence_parallel"] == 1 - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngine") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup") @@ -457,7 +451,6 @@ def test_explicit_etp_passed_through( _, kwargs = mock_setup.call_args assert kwargs["expert_tensor_parallel_size"] == 4 - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngine") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup") diff --git a/tests/unit_tests/deploy/test_inference_base.py b/tests/unit_tests/deploy/test_inference_base.py index 36dd072cf..d80c7dd2c 100644 --- a/tests/unit_tests/deploy/test_inference_base.py +++ b/tests/unit_tests/deploy/test_inference_base.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import types import unittest from pathlib import Path from unittest.mock import MagicMock, patch @@ -24,14 +25,6 @@ ) from megatron.core.transformer.module import MegatronModule -try: - from nemo.collections.llm.gpt.model.base import GPTConfig - from nemo.collections.llm.inference.base import MCoreTokenizerWrappper - - HAVE_NEMO = True -except (ImportError, ModuleNotFoundError): - HAVE_NEMO = False - from nemo_deploy.llm.inference.inference_base import ( MCoreEngineWithCleanup, _load_dist_shards_into_model, @@ -43,11 +36,10 @@ setup_megatron_model_and_tokenizer_for_inference, setup_model_and_tokenizer_for_inference, ) +from nemo_deploy.llm.inference.nemo_utils import MCoreTokenizerWrappper from nemo_deploy.llm.inference.tron_utils import DistributedInitConfig, RNGConfig -from nemo_export_deploy_common.import_utils import UnavailableError -@pytest.mark.skipif(not HAVE_NEMO, reason="NeMo is not installed") @pytest.mark.run_only_on("GPU") class TestInferenceBase(unittest.TestCase): def setUp(self): @@ -64,7 +56,7 @@ def setUp(self): self.mock_tokenizer.pad = 50257 # Padding token ID # Setup model config - self.model_config = GPTConfig( + self.model_config = types.SimpleNamespace( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1, @@ -199,7 +191,6 @@ def test_load_nemo_checkpoint_to_tron_model(self, mock_ckpt_to_weights, mock_loa mock_ckpt_to_weights.assert_called_once_with(self.mock_path, is_saving=False) mock_load_shards.assert_called_once_with(self.mock_model_list, self.mock_weights_dir, False) - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.set_modelopt_spec_if_exists_in_ckpt") @patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init") @patch("nemo_deploy.llm.inference.inference_base.io.load_context") @@ -256,7 +247,6 @@ def test_setup_model_and_tokenizer_for_inference( mock_torch_dist_init.assert_called_once() 
mock_set_modelopt.assert_called_once() - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.set_modelopt_spec_if_exists_in_ckpt") @patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init") @patch("nemo_deploy.llm.inference.inference_base.io.load_context") @@ -366,16 +356,10 @@ def test_mcore_engine_with_cleanup_del(self, mock_cleanup): # Verify cleanup was called mock_cleanup.assert_called_once() - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) def test_create_mcore_engine_unknown_format_raises(self): with self.assertRaises(ValueError): create_mcore_engine(path=self.mock_path, model_format="unknown") - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", False) - def test_create_mcore_engine_unavailable_nemo_raises(self): - with self.assertRaises(UnavailableError): - create_mcore_engine(path=self.mock_path) - @patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init") @patch("nemo_deploy.llm.inference.inference_base.load_model_config") @patch("nemo_deploy.llm.inference.inference_base.initialize_megatron_for_inference") @@ -670,19 +654,6 @@ def test_load_dist_shards_into_model_legacy_ckpt(self, mock_sharded_state_dict, call_kwargs.args and StrictHandling.LOG_ALL in call_kwargs.args ) - def test_load_nemo_checkpoint_no_nemo(self): - """Test load_nemo_checkpoint_to_tron_model raises when NeMo is not available.""" - with patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", False): - with self.assertRaises(Exception): - load_nemo_checkpoint_to_tron_model(self.mock_model_list, self.mock_path) - - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", False) - def test_setup_model_and_tokenizer_no_nemo(self): - """Test setup_model_and_tokenizer_for_inference raises UnavailableError when NeMo is absent.""" - with self.assertRaises(UnavailableError): - setup_model_and_tokenizer_for_inference(checkpoint_path=self.mock_path) - - 
@patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.set_modelopt_spec_if_exists_in_ckpt") @patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init") @patch("nemo_deploy.llm.inference.inference_base.io.load_context") @@ -728,7 +699,6 @@ def test_setup_model_and_tokenizer_cuda_graphs( self.assertTrue(self.model_config.use_te_rng_tracker) self.assertTrue(self.model_config.inference_rng_tracker) - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.set_modelopt_spec_if_exists_in_ckpt") @patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init") @patch("nemo_deploy.llm.inference.inference_base.io.load_context") @@ -769,7 +739,6 @@ def test_setup_model_and_tokenizer_with_gradient_accum_fusion( self.assertFalse(self.model_config.gradient_accumulation_fusion) - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.set_modelopt_spec_if_exists_in_ckpt") @patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init") @patch("nemo_deploy.llm.inference.inference_base.io.load_context") @@ -811,7 +780,6 @@ def test_setup_model_and_tokenizer_model_config_kwargs( self.assertEqual(self.model_config.hidden_size, 1024) - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference") @patch("nemo_deploy.llm.inference.inference_base.StaticInferenceContext") @patch("nemo_deploy.llm.inference.inference_base.GPTInferenceWrapper") @@ -843,7 +811,6 @@ def test_create_mcore_engine_nemo_format( mock_mcore_engine.assert_called_once() self.assertIsNotNone(engine) - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.setup_megatron_model_and_tokenizer_for_inference") 
@patch("nemo_deploy.llm.inference.inference_base.StaticInferenceContext") @patch("nemo_deploy.llm.inference.inference_base.GPTInferenceWrapper") diff --git a/tests/unit_tests/deploy/test_nemo_io.py b/tests/unit_tests/deploy/test_nemo_io.py new file mode 100644 index 000000000..bc5078e2e --- /dev/null +++ b/tests/unit_tests/deploy/test_nemo_io.py @@ -0,0 +1,856 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock, patch + +import fiddle as fdl +import pytest + +from nemo_deploy.llm.inference.nemo_io import ( + IOProtocol, + _artifact_transform_load, + _io_flatten_object, + _io_init, + _io_path_elements_fn, + _io_register_serialization, + _io_transform_args, + _io_unflatten_object, + _io_wrap_init, + _set_thread_local_output_dir, + _thread_local, + drop_unexpected_params, + load, + load_context, + track_io, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _SimpleClass: + """A minimal class used across multiple tests.""" + + def __init__(self, x=1, y=2): + self.x = x + self.y = y + + +class _ClassWithIO: + """A class that implements IOProtocol (has __io__ attribute).""" + + def __init__(self): + self.__io__ = MagicMock() + + +# 
--------------------------------------------------------------------------- +# _set_thread_local_output_dir +# --------------------------------------------------------------------------- + + +class TestSetThreadLocalOutputDir: + def test_sets_output_dir_on_thread_local(self): + p = Path("/some/path") + _set_thread_local_output_dir(p) + assert _thread_local.output_dir == p + + def test_mirrors_to_nemo_module_when_present(self): + """When nemo.lightning.io.mixin is already in sys.modules the value should be mirrored.""" + fake_mixin = types.SimpleNamespace(_thread_local=types.SimpleNamespace()) + p = Path("/mirror/path") + with patch.dict(sys.modules, {"nemo.lightning.io.mixin": fake_mixin}): + _set_thread_local_output_dir(p) + assert fake_mixin._thread_local.output_dir == p + + def test_does_not_raise_when_nemo_module_missing(self): + """Should not raise even when nemo is not imported.""" + # Remove only the nemo mixin key if present; don't wipe all of sys.modules + with patch.dict(sys.modules, {"nemo.lightning.io.mixin": None}): + _set_thread_local_output_dir(Path("/no/nemo")) + + def test_suppresses_exception_from_nemo_mirror(self): + """An exception thrown while mirroring must be swallowed.""" + + class _BrokenMixin: + @property + def _thread_local(self): + raise RuntimeError("boom") + + broken_mixin = _BrokenMixin() + with patch.dict(sys.modules, {"nemo.lightning.io.mixin": broken_mixin}): + # Should not raise + _set_thread_local_output_dir(Path("/ok")) + + +# --------------------------------------------------------------------------- +# IOProtocol +# --------------------------------------------------------------------------- + + +class TestIOProtocol: + def test_class_with_io_attr_is_instance(self): + obj = _ClassWithIO() + assert isinstance(obj, IOProtocol) + + def test_class_without_io_attr_is_not_instance(self): + assert not isinstance(_SimpleClass(), IOProtocol) + + +# --------------------------------------------------------------------------- +# 
_io_transform_args +# --------------------------------------------------------------------------- + + +class TestIoTransformArgs: + def test_captures_kwargs(self): + instance = object.__new__(_SimpleClass) + result = _io_transform_args(instance, _SimpleClass.__init__, x=10, y=20) + assert result == {"x": 10, "y": 20} + + def test_captures_positional_args(self): + instance = object.__new__(_SimpleClass) + result = _io_transform_args(instance, _SimpleClass.__init__, 5, 6) + assert result == {"x": 5, "y": 6} + + def test_replaces_io_protocol_with_io_attr(self): + instance = object.__new__(_SimpleClass) + io_obj = _ClassWithIO() + mock_io = MagicMock(spec=fdl.Config) + io_obj.__io__ = mock_io + result = _io_transform_args(instance, _SimpleClass.__init__, x=io_obj, y=2) + assert result["x"] is mock_io + + def test_removes_default_factory_entries(self): + """Arguments whose value has class name _HAS_DEFAULT_FACTORY_CLASS are dropped.""" + + class _HAS_DEFAULT_FACTORY_CLASS: + pass + + factory_sentinel = _HAS_DEFAULT_FACTORY_CLASS() + instance = object.__new__(_SimpleClass) + result = _io_transform_args(instance, _SimpleClass.__init__, x=factory_sentinel, y=3) + assert "x" not in result + assert result["y"] == 3 + + +# --------------------------------------------------------------------------- +# _io_init +# --------------------------------------------------------------------------- + + +class TestIoInit: + def test_creates_fdl_config_for_class(self): + instance = _SimpleClass(x=7, y=8) + cfg = _io_init(instance, x=7, y=8) + assert isinstance(cfg, fdl.Config) + assert cfg.__fn_or_cls__ is _SimpleClass + + def test_raises_runtime_error_on_invalid_kwarg(self): + instance = _SimpleClass() + with pytest.raises(RuntimeError, match="Error creating fdl.Config"): + _io_init(instance, not_a_valid_param=99) + + +# --------------------------------------------------------------------------- +# _io_wrap_init +# 
--------------------------------------------------------------------------- + + +class TestIoWrapInit: + def test_wraps_init_and_sets_io(self): + class _Target: + def __init__(self, val=42): + self.val = val + + _io_wrap_init(_Target) + obj = _Target(val=10) + assert hasattr(obj, "__io__") + assert isinstance(obj.__io__, fdl.Config) + assert obj.val == 10 + + def test_idempotent_double_wrap(self): + class _Target2: + def __init__(self, val=1): + self.val = val + + _io_wrap_init(_Target2) + wrapped_once = _Target2.__init__ + _io_wrap_init(_Target2) + # Second call must be a no-op (guard via __wrapped_init__) + assert _Target2.__init__ is wrapped_once + + def test_uses_custom_io_transform_args_when_available(self): + """If the class defines io_transform_args, it should be called instead of the global fn.""" + + class _CustomTransform: + custom_called = False + + def io_transform_args(self, init_fn, *args, **kwargs): + _CustomTransform.custom_called = True + return {} + + def __init__(self): + pass + + _io_wrap_init(_CustomTransform) + _CustomTransform() + assert _CustomTransform.custom_called + + def test_uses_custom_io_init_when_available(self): + """If the class defines io_init, it should be called instead of the global fn.""" + custom_cfg = MagicMock() + + class _CustomInit: + def io_init(self, **kwargs): + return custom_cfg + + def __init__(self): + pass + + _io_wrap_init(_CustomInit) + obj = _CustomInit() + assert obj.__io__ is custom_cfg + + +# --------------------------------------------------------------------------- +# _io_flatten_object +# --------------------------------------------------------------------------- + + +class TestIoFlattenObject: + def test_returns_flatten_result_on_success(self): + """When dump_json succeeds, __flatten__ is called on __io__.""" + mock_io = MagicMock() + mock_io.__flatten__ = MagicMock(return_value=("values", "metadata")) + + instance = MagicMock() + instance.__io__ = mock_io + + with 
patch("nemo_deploy.llm.inference.nemo_io.serialization.dump_json", return_value=None): + result = _io_flatten_object(instance) + + mock_io.__flatten__.assert_called_once() + assert result == ("values", "metadata") + + def test_falls_back_to_pickle_when_unserializable(self, tmp_path): + """When dump_json raises UnserializableValueError, object is pickled to disk.""" + from fiddle._src.experimental import serialization as _ser + + instance = MagicMock() + instance.__io__ = MagicMock() + + import nemo_deploy.llm.inference.nemo_io as _nemo_io + + _nemo_io._thread_local.local_artifacts_dir = "artifacts" + _nemo_io._thread_local.output_path = tmp_path + + (tmp_path / "artifacts").mkdir(parents=True, exist_ok=True) + + with patch( + "nemo_deploy.llm.inference.nemo_io.serialization.dump_json", + side_effect=_ser.UnserializableValueError("bad"), + ): + result = _io_flatten_object(instance) + + # result should be ((str_path,), None) tuple + assert isinstance(result, tuple) + assert result[1] is None + assert isinstance(result[0], tuple) + assert len(result[0]) == 1 + + # cleanup + del _nemo_io._thread_local.local_artifacts_dir + del _nemo_io._thread_local.output_path + + def test_raises_when_pickle_path_unavailable(self): + """When no thread-local paths are set and dump_json fails, the exception propagates.""" + from fiddle._src.experimental import serialization as _ser + + instance = MagicMock() + instance.__io__ = MagicMock() + + import nemo_deploy.llm.inference.nemo_io as _nemo_io + + # Make sure thread-local does NOT have paths + for attr in ("local_artifacts_dir", "output_path"): + try: + delattr(_nemo_io._thread_local, attr) + except AttributeError: + pass + + with patch( + "nemo_deploy.llm.inference.nemo_io.serialization.dump_json", + side_effect=_ser.UnserializableValueError("bad"), + ): + with pytest.raises(_ser.UnserializableValueError): + _io_flatten_object(instance) + + +# --------------------------------------------------------------------------- +# 
_io_unflatten_object +# --------------------------------------------------------------------------- + + +class TestIoUnflattenObject: + def test_calls_fdl_config_unflatten_when_no_output_dir(self): + import nemo_deploy.llm.inference.nemo_io as _nemo_io + + try: + delattr(_nemo_io._thread_local, "output_dir") + except AttributeError: + pass + + mock_result = MagicMock() + with patch.object(fdl.Config, "__unflatten__", return_value=mock_result) as mock_unflat: + result = _io_unflatten_object(("v1",), "meta") + + mock_unflat.assert_called_once_with(("v1",), "meta") + assert result is mock_result + + def test_loads_pickle_when_single_value_and_output_dir_set(self, tmp_path): + import nemo_deploy.llm.inference.nemo_io as _nemo_io + + _nemo_io._thread_local.output_dir = tmp_path + + # Write a simple pickle + import pickle + + payload = {"key": "value"} + pickle_file = tmp_path / "artifact.pkl" + with open(pickle_file, "wb") as f: + pickle.dump(payload, f) + + result = _io_unflatten_object(("artifact.pkl",), "meta") + assert result == payload + + del _nemo_io._thread_local.output_dir + + def test_calls_fdl_unflatten_when_multiple_values(self): + import nemo_deploy.llm.inference.nemo_io as _nemo_io + + _nemo_io._thread_local.output_dir = Path("/some/dir") + + mock_result = MagicMock() + with patch.object(fdl.Config, "__unflatten__", return_value=mock_result) as mock_unflat: + result = _io_unflatten_object(("v1", "v2"), "meta") + + mock_unflat.assert_called_once_with(("v1", "v2"), "meta") + assert result is mock_result + + del _nemo_io._thread_local.output_dir + + +# --------------------------------------------------------------------------- +# _io_path_elements_fn +# --------------------------------------------------------------------------- + + +class TestIoPathElementsFn: + def test_returns_identity_element_when_unserializable(self): + from fiddle._src.experimental import serialization as _ser + + instance = MagicMock() + instance.__io__ = MagicMock() + + with patch( + 
"nemo_deploy.llm.inference.nemo_io.serialization.dump_json", + side_effect=_ser.UnserializableValueError("bad"), + ): + result = _io_path_elements_fn(instance) + + assert len(result) == 1 + assert isinstance(result[0], _ser.IdentityElement) + + def test_returns_path_elements_on_success(self): + mock_io = MagicMock() + expected = ("elem1", "elem2") + mock_io.__path_elements__ = MagicMock(return_value=expected) + + instance = MagicMock() + instance.__io__ = mock_io + + with patch("nemo_deploy.llm.inference.nemo_io.serialization.dump_json", return_value=None): + result = _io_path_elements_fn(instance) + + assert result == expected + + def test_returns_identity_element_when_attribute_error(self): + from fiddle._src.experimental import serialization as _ser + + instance = MagicMock() + instance.__io__ = MagicMock() + + with patch( + "nemo_deploy.llm.inference.nemo_io.serialization.dump_json", + side_effect=AttributeError("no attr"), + ): + result = _io_path_elements_fn(instance) + + assert isinstance(result[0], _ser.IdentityElement) + + +# --------------------------------------------------------------------------- +# _io_register_serialization +# --------------------------------------------------------------------------- + + +class TestIoRegisterSerialization: + def test_registers_node_traverser(self): + class _Reg: + def __init__(self): + pass + + with patch("nemo_deploy.llm.inference.nemo_io.serialization.register_node_traverser") as mock_reg: + _io_register_serialization(_Reg) + mock_reg.assert_called_once() + args, kwargs = mock_reg.call_args + assert args[0] is _Reg + + +# --------------------------------------------------------------------------- +# track_io +# --------------------------------------------------------------------------- + + +class TestTrackIo: + def test_wraps_class(self): + class _Trackable: + def __init__(self, a=1): + self.a = a + + result = track_io(_Trackable) + assert result is _Trackable + assert getattr(_Trackable, "__wrapped_init__", 
False) + + def test_module_type_processes_members(self): + """track_io on a module should add IO to eligible classes.""" + + class _InModule: + def __init__(self): + pass + + fake_module = types.ModuleType("fake_module") + fake_module.__name__ = "fake_module" + _InModule.__module__ = "fake_module" + fake_module._InModule = _InModule + + result = track_io(fake_module) + assert result is fake_module + + def test_raises_type_error_for_non_class_non_module(self): + with pytest.raises(TypeError, match="module or a class"): + track_io("not_a_class_or_module") + + def test_skips_builtin_types(self): + """Built-in types like str, int should be returned unchanged.""" + + class _BuiltinHolder: + pass + + _BuiltinHolder.__init__ = str.__init__ + # Just verify built-ins are in the exclusion list by calling on str + result = track_io(str) + assert result is str + + def test_sets_io_artifacts(self): + class _WithArtifacts: + def __init__(self): + pass + + artifacts = ["artifact1"] + track_io(_WithArtifacts, artifacts=artifacts) + assert _WithArtifacts.__io_artifacts__ == artifacts + + def test_skips_class_already_having_io(self): + """Classes that already have an __io__ attribute should not be re-wrapped.""" + + class _AlreadyWrapped: + __io__ = MagicMock() + + def __init__(self): + pass + + original_init = _AlreadyWrapped.__init__ + track_io(_AlreadyWrapped) + # __init__ should be unchanged since the class already had __io__ + assert _AlreadyWrapped.__init__ is original_init + + +# --------------------------------------------------------------------------- +# drop_unexpected_params +# --------------------------------------------------------------------------- + + +class TestDropUnexpectedParams: + def test_no_change_when_all_params_valid(self): + cfg = fdl.Config(_SimpleClass, x=1, y=2) + updated = drop_unexpected_params(cfg) + assert not updated + + def test_removes_extra_params(self): + cfg = fdl.Config(_SimpleClass, x=1, y=2) + # Inject a bogus parameter directly + 
cfg.__arguments__["bogus_param"] = 99 + updated = drop_unexpected_params(cfg) + assert updated + assert "bogus_param" not in cfg.__arguments__ + + def test_no_update_when_class_accepts_kwargs(self): + """Classes accepting **kwargs should never be considered for dropping.""" + + class _AcceptsKwargs: + def __init__(self, **kwargs): + pass + + cfg = fdl.Config(_AcceptsKwargs) + cfg.__arguments__["extra"] = "value" + updated = drop_unexpected_params(cfg) + assert not updated + + def test_returns_false_on_non_config_input(self): + """Passing a non-fdl.Config root should short-circuit without raising.""" + updated = drop_unexpected_params(MagicMock(spec=[])) # not an fdl.Config + assert not updated + + def test_recurses_into_nested_configs(self): + """Nested fdl.Config values should also be cleaned up.""" + + class _Outer: + def __init__(self, inner=None): + self.inner = inner + + cfg_outer = fdl.Config(_Outer, inner=fdl.Config(_SimpleClass, x=1, y=2)) + cfg_outer.__arguments__["inner"].__arguments__["bogus"] = 42 + updated = drop_unexpected_params(cfg_outer) + assert updated + + +# --------------------------------------------------------------------------- +# _artifact_transform_load +# --------------------------------------------------------------------------- + + +class TestArtifactTransformLoad: + def test_rewrites_string_artifact_path(self): + """String artifact values should be rewritten to absolute paths.""" + import dataclasses as _dc + + @_dc.dataclass + class _Artifact: + attr: str + skip: bool = False + + class _FnCls: + __io_artifacts__ = [_Artifact(attr="model_path")] + + def __init__(self, model_path=None): + self.model_path = model_path + + cfg = fdl.Config(_FnCls) + cfg.model_path = "relative/path" + _artifact_transform_load(cfg, Path("/base")) + assert cfg.model_path == str(Path("/base") / "relative/path") + + def test_skips_none_artifact(self): + import dataclasses as _dc + + @_dc.dataclass + class _Artifact: + attr: str + skip: bool = False + + class 
_FnCls: + __io_artifacts__ = [_Artifact(attr="model_path")] + + def __init__(self, model_path=None): + self.model_path = model_path + + cfg = fdl.Config(_FnCls) + cfg.model_path = None + _artifact_transform_load(cfg, Path("/base")) + assert cfg.model_path is None + + def test_skips_artifact_with_skip_flag(self): + import dataclasses as _dc + + @_dc.dataclass + class _Artifact: + attr: str + skip: bool = True + + class _FnCls: + __io_artifacts__ = [_Artifact(attr="model_path", skip=True)] + + def __init__(self, model_path=None): + self.model_path = model_path + + cfg = fdl.Config(_FnCls) + cfg.model_path = "should/not/be/changed" + _artifact_transform_load(cfg, Path("/base")) + # skip=True means after reading current_val the function hits `continue` + # before rewriting, so model_path remains as-is. + assert cfg.model_path == "should/not/be/changed" + + def test_handles_fdl_config_artifact_value(self): + """When artifact value is itself a fdl.Config, it calls fdl.build on it.""" + import dataclasses as _dc + + @_dc.dataclass + class _Artifact: + attr: str + skip: bool = False + + class _Inner: + def __init__(self): + self.attr = "built_value" + + class _FnCls: + __io_artifacts__ = [_Artifact(attr="nested")] + + def __init__(self, nested=None): + self.nested = nested + + inner_cfg = fdl.Config(_Inner) + cfg = fdl.Config(_FnCls) + cfg.nested = inner_cfg + + with patch("nemo_deploy.llm.inference.nemo_io.fdl.build") as mock_build: + mock_built = MagicMock() + mock_built.attr = "built_value" + mock_build.return_value = mock_built + _artifact_transform_load(cfg, Path("/base")) + + mock_build.assert_called_once_with(inner_cfg) + + +# --------------------------------------------------------------------------- +# load +# --------------------------------------------------------------------------- + + +class TestLoad: + def _make_minimal_json(self): + """Return a minimal io.json payload that load() will parse.""" + return { + "objects": {}, + "root": {"key": None, "type": 
{"module": "builtins", "name": "dict"}, "args": [], "kwargs": {}}, + } + + def test_raises_file_not_found_for_missing_path(self, tmp_path): + with pytest.raises(FileNotFoundError): + load(tmp_path / "nonexistent.json") + + def test_raises_file_not_found_for_missing_io_json_in_dir(self, tmp_path): + with pytest.raises(FileNotFoundError): + load(tmp_path) + + def test_load_returns_config_when_build_false(self, tmp_path): + io_json = self._make_minimal_json() + io_path = tmp_path / "io.json" + io_path.write_text(json.dumps(io_json)) + + mock_result = MagicMock(spec=fdl.Config) + with patch("nemo_deploy.llm.inference.nemo_io.serialization.Deserialization") as mock_deser: + mock_deser.return_value.result = mock_result + with patch("nemo_deploy.llm.inference.nemo_io._artifact_transform_load"): + with patch("nemo_deploy.llm.inference.nemo_io.drop_unexpected_params"): + result = load(io_path, build=False) + + assert result is mock_result + + def test_load_calls_fdl_build_when_build_true(self, tmp_path): + io_json = self._make_minimal_json() + io_path = tmp_path / "io.json" + io_path.write_text(json.dumps(io_json)) + + mock_config = MagicMock(spec=fdl.Config) + mock_built = MagicMock() + with patch("nemo_deploy.llm.inference.nemo_io.serialization.Deserialization") as mock_deser: + mock_deser.return_value.result = mock_config + with patch("nemo_deploy.llm.inference.nemo_io._artifact_transform_load"): + with patch("nemo_deploy.llm.inference.nemo_io.drop_unexpected_params"): + with patch("nemo_deploy.llm.inference.nemo_io.fdl.build", return_value=mock_built) as mock_build: + result = load(io_path, build=True) + + mock_build.assert_called_once_with(mock_config) + assert result is mock_built + + def test_load_uses_dir_io_json(self, tmp_path): + """When path is a directory, load should look for io.json inside it.""" + io_json = self._make_minimal_json() + (tmp_path / "io.json").write_text(json.dumps(io_json)) + + mock_result = MagicMock(spec=fdl.Config) + with 
patch("nemo_deploy.llm.inference.nemo_io.serialization.Deserialization") as mock_deser: + mock_deser.return_value.result = mock_result + with patch("nemo_deploy.llm.inference.nemo_io._artifact_transform_load"): + with patch("nemo_deploy.llm.inference.nemo_io.drop_unexpected_params"): + result = load(tmp_path, build=False) + + assert result is mock_result + + def test_load_with_objects_and_subpath_filtering(self, tmp_path): + """When subpath is specified only matching objects get re-registered.""" + io_json = { + "objects": { + "obj1": { + "type": {"module": "builtins", "name": "int"}, + "paths": [".model"], + }, + "obj2": { + "type": {"module": "builtins", "name": "str"}, + "paths": [".other"], + }, + }, + "root": { + "key": None, + "type": {"module": "builtins", "name": "dict"}, + "args": [], + "kwargs": {}, + }, + } + io_path = tmp_path / "io.json" + io_path.write_text(json.dumps(io_json)) + + mock_result = MagicMock(spec=fdl.Config) + with patch("nemo_deploy.llm.inference.nemo_io.serialization.Deserialization") as mock_deser: + mock_deser.return_value.result = mock_result + with patch("nemo_deploy.llm.inference.nemo_io._artifact_transform_load"): + with patch("nemo_deploy.llm.inference.nemo_io.drop_unexpected_params"): + with patch("nemo_deploy.llm.inference.nemo_io.locate", return_value=None): + result = load(io_path, subpath="model", build=False) + + assert result is mock_result + + def test_load_track_io_called_when_traverser_missing(self, tmp_path): + """When locate returns a class and no traverser is registered, track_io is called.""" + fake_cls = type("FakeCls", (), {"__init__": lambda self: None}) + io_json = { + "objects": { + "obj1": { + "type": {"module": "some.module", "name": "FakeCls"}, + "paths": [""], + }, + }, + "root": {"key": None, "type": {"module": "builtins", "name": "dict"}, "args": [], "kwargs": {}}, + } + io_path = tmp_path / "io.json" + io_path.write_text(json.dumps(io_json)) + + mock_result = MagicMock(spec=fdl.Config) + with 
patch("nemo_deploy.llm.inference.nemo_io.locate", return_value=fake_cls): + with patch("nemo_deploy.llm.inference.nemo_io.serialization.find_node_traverser", return_value=None): + with patch("nemo_deploy.llm.inference.nemo_io.track_io") as mock_track: + with patch("nemo_deploy.llm.inference.nemo_io.serialization.Deserialization") as mock_deser: + mock_deser.return_value.result = mock_result + with patch("nemo_deploy.llm.inference.nemo_io._artifact_transform_load"): + with patch("nemo_deploy.llm.inference.nemo_io.drop_unexpected_params"): + load(io_path, build=False) + mock_track.assert_called_once_with(fake_cls) + + def test_load_reregisters_when_traverser_already_exists(self, tmp_path): + """When locate returns a class and a traverser exists, _io_register_serialization is called.""" + fake_cls = type("FakeCls2", (), {"__init__": lambda self: None}) + io_json = { + "objects": { + "obj1": { + "type": {"module": "some.module", "name": "FakeCls2"}, + "paths": [""], + }, + }, + "root": {"key": None, "type": {"module": "builtins", "name": "dict"}, "args": [], "kwargs": {}}, + } + io_path = tmp_path / "io.json" + io_path.write_text(json.dumps(io_json)) + + mock_result = MagicMock(spec=fdl.Config) + with patch("nemo_deploy.llm.inference.nemo_io.locate", return_value=fake_cls): + with patch( + "nemo_deploy.llm.inference.nemo_io.serialization.find_node_traverser", + return_value=MagicMock(), + ): + with patch("nemo_deploy.llm.inference.nemo_io._io_register_serialization") as mock_rereg: + with patch("nemo_deploy.llm.inference.nemo_io.serialization.Deserialization") as mock_deser: + mock_deser.return_value.result = mock_result + with patch("nemo_deploy.llm.inference.nemo_io._artifact_transform_load"): + with patch("nemo_deploy.llm.inference.nemo_io.drop_unexpected_params"): + load(io_path, build=False) + mock_rereg.assert_called_once_with(fake_cls) + + +# --------------------------------------------------------------------------- +# load_context +# 
--------------------------------------------------------------------------- + + +class TestLoadContext: + def test_load_context_delegates_to_load(self, tmp_path): + mock_result = MagicMock() + with patch("nemo_deploy.llm.inference.nemo_io.load", return_value=mock_result) as mock_load: + result = load_context(tmp_path, subpath="model") + mock_load.assert_called_once_with(tmp_path, subpath="model", build=True) + assert result is mock_result + + def test_load_context_accepts_string_path(self, tmp_path): + mock_result = MagicMock() + with patch("nemo_deploy.llm.inference.nemo_io.load", return_value=mock_result): + result = load_context(str(tmp_path)) + assert result is mock_result + + def test_load_context_falls_back_to_context_subdir(self, tmp_path): + """When load raises FileNotFoundError, try appending 'context'.""" + call_count = {"n": 0} + + def _mock_load(path, subpath=None, build=True): + call_count["n"] += 1 + if call_count["n"] == 1: + raise FileNotFoundError("not found") + return "loaded_from_context" + + with patch("nemo_deploy.llm.inference.nemo_io.load", side_effect=_mock_load): + result = load_context(tmp_path) + + assert result == "loaded_from_context" + assert call_count["n"] == 2 + + def test_load_context_strips_context_suffix_on_retry(self, tmp_path): + """When path ends with 'context' and load fails, retry with the parent.""" + context_path = tmp_path / "context" + context_path.mkdir() + + call_paths = [] + + def _mock_load(path, subpath=None, build=True): + call_paths.append(path) + if len(call_paths) == 1: + raise FileNotFoundError("not found") + return "loaded" + + with patch("nemo_deploy.llm.inference.nemo_io.load", side_effect=_mock_load): + result = load_context(context_path) + + assert result == "loaded" + # Second call should use parent (tmp_path), not tmp_path / "context" / "context" + assert call_paths[1] == tmp_path + + def test_load_context_passes_build_flag(self, tmp_path): + mock_result = MagicMock() + with 
patch("nemo_deploy.llm.inference.nemo_io.load", return_value=mock_result) as mock_load: + load_context(tmp_path, build=False) + mock_load.assert_called_once_with(tmp_path, subpath=None, build=False) diff --git a/tests/unit_tests/deploy/test_nemo_utils.py b/tests/unit_tests/deploy/test_nemo_utils.py index 9289a44e1..f7b3d37bc 100644 --- a/tests/unit_tests/deploy/test_nemo_utils.py +++ b/tests/unit_tests/deploy/test_nemo_utils.py @@ -24,6 +24,7 @@ ckpt_to_dir, ckpt_to_weights_subdir, idempotent_path_append, + set_modelopt_spec_if_exists_in_ckpt, ) @@ -236,3 +237,203 @@ def test_pad_property(self): mock_tok.pad_id = 0 wrapper = MCoreTokenizerWrappper(mock_tok) assert wrapper.pad == 0 + + +# --------------------------------------------------------------------------- +# set_modelopt_spec_if_exists_in_ckpt +# --------------------------------------------------------------------------- + + +class TestSetModeloptSpecIfExistsInCkpt: + """Tests for set_modelopt_spec_if_exists_in_ckpt — no GPU required.""" + + def _make_model(self, type_name: str, config_type_name: str, *, has_module: bool = False): + """Build a MagicMock that looks like a NeMo model with the given type names.""" + model = MagicMock() + type(model).__name__ = type_name + if has_module: + model.module = MagicMock() + else: + del model.module # ensure hasattr returns False + + config = MagicMock() + type(config).__name__ = config_type_name + model.config = config + return model, config + + @patch("nemo_deploy.llm.inference.nemo_utils.ckpt_to_weights_subdir") + def test_returns_early_when_modelopt_state_path_missing(self, mock_ckpt_to_weights): + """When modelopt_state does not exist, the function should return immediately.""" + mock_weights = MagicMock() + mock_modelopt_path = MagicMock() + mock_modelopt_path.exists.return_value = False + mock_weights.__truediv__ = MagicMock(return_value=mock_modelopt_path) + mock_ckpt_to_weights.return_value = mock_weights + + model, config = self._make_model("GPTModel", 
"GPTConfig") + set_modelopt_spec_if_exists_in_ckpt(model, "/fake/path") + + # config should be untouched + config.transformer_layer_spec = MagicMock() + assert not mock_modelopt_path.exists.return_value + + @patch("nemo_deploy.llm.inference.nemo_utils.ckpt_to_weights_subdir") + def test_returns_early_when_model_has_module_attr(self, mock_ckpt_to_weights): + """When model.module exists (DDP wrapper), the function should skip.""" + mock_weights = MagicMock() + mock_modelopt_path = MagicMock() + mock_modelopt_path.exists.return_value = True + mock_weights.__truediv__ = MagicMock(return_value=mock_modelopt_path) + mock_ckpt_to_weights.return_value = mock_weights + + # has_module=True forces early return + model, config = self._make_model("GPTModel", "GPTConfig", has_module=True) + set_modelopt_spec_if_exists_in_ckpt(model, "/fake/path") + + # gradient_accumulation_fusion should NOT be set (we returned early) + config.gradient_accumulation_fusion.assert_not_called() if callable( + config.gradient_accumulation_fusion + ) else None + + @patch("nemo_deploy.llm.inference.nemo_utils.ckpt_to_weights_subdir") + def test_logs_warning_for_non_gpt_mamba_model(self, mock_ckpt_to_weights): + """Models that are neither GPTModel nor MambaModel should log a warning and return.""" + mock_weights = MagicMock() + mock_modelopt_path = MagicMock() + mock_modelopt_path.exists.return_value = True + mock_weights.__truediv__ = MagicMock(return_value=mock_modelopt_path) + mock_ckpt_to_weights.return_value = mock_weights + + model, config = self._make_model("SomeOtherModel", "GPTConfig") + + with patch("nemo_deploy.llm.inference.nemo_utils._logger") as mock_logger: + set_modelopt_spec_if_exists_in_ckpt(model, "/fake/path") + + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] + assert "neither a GPTModel nor MambaModel" in warning_msg + + @patch("nemo_deploy.llm.inference.nemo_utils.ckpt_to_weights_subdir") + def 
test_gpt_config_with_gpt_modelopt_spec_available(self, mock_ckpt_to_weights): + """GPTModel + GPTConfig: when get_gpt_modelopt_spec is importable, spec is set.""" + mock_weights = MagicMock() + mock_modelopt_path = MagicMock() + mock_modelopt_path.exists.return_value = True + mock_weights.__truediv__ = MagicMock(return_value=mock_modelopt_path) + mock_ckpt_to_weights.return_value = mock_weights + + model, config = self._make_model("GPTModel", "GPTConfig") + config.softmax_type = "vanilla" + + mock_spec_fn = MagicMock() + fake_module = MagicMock() + fake_module.get_gpt_modelopt_spec = mock_spec_fn + + with patch.dict( + "sys.modules", + {"megatron.core.post_training.modelopt.gpt.model_specs": fake_module}, + ): + set_modelopt_spec_if_exists_in_ckpt(model, "/fake/path") + + assert config.gradient_accumulation_fusion is False + # transformer_layer_spec should be a partial wrapping get_gpt_modelopt_spec + assert config.transformer_layer_spec is not None + + @patch("nemo_deploy.llm.inference.nemo_utils.ckpt_to_weights_subdir") + def test_gpt_config_without_gpt_modelopt_spec(self, mock_ckpt_to_weights): + """GPTModel + GPTConfig: when get_gpt_modelopt_spec is NOT importable, log warning.""" + mock_weights = MagicMock() + mock_modelopt_path = MagicMock() + mock_modelopt_path.exists.return_value = True + mock_weights.__truediv__ = MagicMock(return_value=mock_modelopt_path) + mock_ckpt_to_weights.return_value = mock_weights + + model, config = self._make_model("GPTModel", "GPTConfig") + + # Ensure the import fails + with patch.dict("sys.modules", {"megatron.core.post_training.modelopt.gpt.model_specs": None}): + with patch("nemo_deploy.llm.inference.nemo_utils._logger") as mock_logger: + set_modelopt_spec_if_exists_in_ckpt(model, "/fake/path") + + mock_logger.warning.assert_called() + warning_args = mock_logger.warning.call_args[0] + assert "get_gpt_modelopt_spec not available" in warning_args[0] + + @patch("nemo_deploy.llm.inference.nemo_utils.ckpt_to_weights_subdir") + 
def test_ssm_config_with_mamba_spec_available(self, mock_ckpt_to_weights): + """MambaModel + SSMConfig: when get_mamba_stack_modelopt_spec is importable, spec is set.""" + mock_weights = MagicMock() + mock_modelopt_path = MagicMock() + mock_modelopt_path.exists.return_value = True + mock_weights.__truediv__ = MagicMock(return_value=mock_modelopt_path) + mock_ckpt_to_weights.return_value = mock_weights + + model, config = self._make_model("MambaModel", "SSMConfig") + + mock_mamba_fn = MagicMock() + fake_mamba_module = MagicMock() + fake_mamba_module.get_mamba_stack_modelopt_spec = mock_mamba_fn + + with patch.dict( + "sys.modules", + {"megatron.core.post_training.modelopt.mamba.model_specs": fake_mamba_module}, + ): + set_modelopt_spec_if_exists_in_ckpt(model, "/fake/path") + + assert config.gradient_accumulation_fusion is False + assert config.mamba_stack_spec is not None + + @patch("nemo_deploy.llm.inference.nemo_utils.ckpt_to_weights_subdir") + def test_ssm_config_without_mamba_spec(self, mock_ckpt_to_weights): + """MambaModel + SSMConfig: when get_mamba_stack_modelopt_spec is NOT importable, log warning.""" + mock_weights = MagicMock() + mock_modelopt_path = MagicMock() + mock_modelopt_path.exists.return_value = True + mock_weights.__truediv__ = MagicMock(return_value=mock_modelopt_path) + mock_ckpt_to_weights.return_value = mock_weights + + model, config = self._make_model("MambaModel", "SSMConfig") + + with patch.dict("sys.modules", {"megatron.core.post_training.modelopt.mamba.model_specs": None}): + with patch("nemo_deploy.llm.inference.nemo_utils._logger") as mock_logger: + set_modelopt_spec_if_exists_in_ckpt(model, "/fake/path") + + mock_logger.warning.assert_called() + warning_args = mock_logger.warning.call_args[0] + assert "get_mamba_stack_modelopt_spec not available" in warning_args[0] + + @patch("nemo_deploy.llm.inference.nemo_utils.ckpt_to_weights_subdir") + def test_other_config_type_logs_warning_and_returns(self, mock_ckpt_to_weights): + """GPTModel 
with an unrecognised config type should log a warning and not touch gradient_accumulation_fusion.""" + mock_weights = MagicMock() + mock_modelopt_path = MagicMock() + mock_modelopt_path.exists.return_value = True + mock_weights.__truediv__ = MagicMock(return_value=mock_modelopt_path) + mock_ckpt_to_weights.return_value = mock_weights + + model, config = self._make_model("GPTModel", "UnknownConfig") + + with patch("nemo_deploy.llm.inference.nemo_utils._logger") as mock_logger: + set_modelopt_spec_if_exists_in_ckpt(model, "/fake/path") + + mock_logger.warning.assert_called() + warning_args = mock_logger.warning.call_args[0] + assert "No modelopt layer spec supported" in warning_args[0] + # gradient_accumulation_fusion should NOT have been set to False (early return) + # config is a MagicMock, so we just verify the warning was the only side-effect logged + assert mock_logger.warning.call_count == 1 + + @patch("nemo_deploy.llm.inference.nemo_utils.ckpt_to_weights_subdir") + def test_strips_nemo_prefix_from_path(self, mock_ckpt_to_weights): + """Path starting with 'nemo://' should have that prefix stripped before processing.""" + mock_weights = MagicMock() + mock_modelopt_path = MagicMock() + mock_modelopt_path.exists.return_value = False + mock_weights.__truediv__ = MagicMock(return_value=mock_modelopt_path) + mock_ckpt_to_weights.return_value = mock_weights + + model, config = self._make_model("GPTModel", "GPTConfig") + set_modelopt_spec_if_exists_in_ckpt(model, "nemo:///real/path") + + # Verify ckpt_to_weights_subdir was called with the stripped path + mock_ckpt_to_weights.assert_called_once_with("/real/path", is_saving=False) diff --git a/tests/unit_tests/export/test_nemo_export_init.py b/tests/unit_tests/export/test_nemo_export_init.py new file mode 100644 index 000000000..30843c2fe --- /dev/null +++ b/tests/unit_tests/export/test_nemo_export_init.py @@ -0,0 +1,134 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for ``nemo_export/__init__.py``.

The module body is almost entirely one conditional import block, which is why
its line coverage is so low.  These tests exercise:
    1. The always-present ``__all__`` entries (``__version__`` /
       ``__package_name__``).
    2. The failure path -- TensorRT-LLM not available (normal CPU-only env).
    3. The success path -- TensorRT-LLM available (simulated via
       ``sys.modules``).
"""

import importlib
import sys
from unittest.mock import MagicMock, patch

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _reload_nemo_export():
    """Force a fresh import of nemo_export so that module-level code re-runs."""
    sys.modules.pop("nemo_export", None)
    return importlib.import_module("nemo_export")


def _stash_trtllm_submodules():
    """Remove cached nemo_export.tensorrt_llm* modules; return them for restore."""
    doomed = [name for name in sys.modules if name.startswith("nemo_export.tensorrt_llm")]
    return {name: sys.modules.pop(name) for name in doomed}


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestNemoExportInit:
    def test_version_in_all(self):
        module = importlib.import_module("nemo_export")
        assert "__version__" in module.__all__

    def test_package_name_in_all(self):
        module = importlib.import_module("nemo_export")
        assert "__package_name__" in module.__all__

    def test_version_is_string(self):
        module = importlib.import_module("nemo_export")
        assert isinstance(module.__version__, str)

    def test_package_name_is_string(self):
        module = importlib.import_module("nemo_export")
        assert isinstance(module.__package_name__, str)

    def test_tensorrt_llm_not_in_all_when_unavailable(self):
        """When TensorRT-LLM is not installed, TensorRTLLM should not appear in __all__."""
        # Drop any cached submodules before the reload so the blocked import
        # is actually exercised.
        stashed = _stash_trtllm_submodules()

        blocked = {
            "nemo_export.tensorrt_llm": None,
            "nemo_export.tensorrt_llm_hf": None,
        }
        try:
            with patch.dict("sys.modules", blocked):
                fresh = _reload_nemo_export()
                # Assert inside the context so the module object hasn't been
                # reloaded yet.
                assert "TensorRTLLM" not in fresh.__all__
                assert "TensorRTLLMHF" not in fresh.__all__
        finally:
            # Restore modules so we don't break subsequent tests.
            sys.modules.update(stashed)
            _reload_nemo_export()

    def test_tensorrt_llm_in_all_when_available(self):
        """When TensorRT-LLM is importable, TensorRTLLM and TensorRTLLMHF appear in __all__."""
        trtllm_cls = MagicMock()
        trtllm_hf_cls = MagicMock()
        fake_trtllm = MagicMock(TensorRTLLM=trtllm_cls)
        fake_trtllm_hf = MagicMock(TensorRTLLMHF=trtllm_hf_cls)

        injected = {
            "nemo_export.tensorrt_llm": fake_trtllm,
            "nemo_export.tensorrt_llm_hf": fake_trtllm_hf,
        }

        try:
            with patch.dict("sys.modules", injected):
                fresh = _reload_nemo_export()
                # Assertions must stay inside the patch.dict context because
                # the finally-block reload mutates the same module object in
                # place.
                assert "TensorRTLLM" in fresh.__all__
                assert "TensorRTLLMHF" in fresh.__all__
                assert fresh.TensorRTLLM is trtllm_cls
                assert fresh.TensorRTLLMHF is trtllm_hf_cls
        finally:
            _reload_nemo_export()

    def test_module_not_found_is_silently_handled(self):
        """ModuleNotFoundError during optional import must not propagate."""
        stashed = _stash_trtllm_submodules()

        try:
            # A None entry in sys.modules causes ImportError on import.
            with patch.dict("sys.modules", {"nemo_export.tensorrt_llm": None}):
                fresh = _reload_nemo_export()
                # Should be importable without raising.
                assert fresh is not None
        finally:
            sys.modules.update(stashed)
            _reload_nemo_export()

    def test_all_is_a_list(self):
        module = importlib.import_module("nemo_export")
        assert isinstance(module.__all__, list)

    def test_all_contains_at_least_two_entries(self):
        """__all__ must always have at least __version__ and __package_name__."""
        module = importlib.import_module("nemo_export")
        assert len(module.__all__) >= 2