diff --git a/examples/entrypoint/experiment.py b/examples/entrypoint/experiment.py index 875d32b6..3e341957 100644 --- a/examples/entrypoint/experiment.py +++ b/examples/entrypoint/experiment.py @@ -104,4 +104,4 @@ def train_models_experiment( if __name__ == "__main__": - run.cli.main(train_models_experiment) + run.cli.main(train_models_experiment, default_executor=local_executor()) diff --git a/nemo_run/cli/api.py b/nemo_run/cli/api.py index 760f29b5..a738cd96 100644 --- a/nemo_run/cli/api.py +++ b/nemo_run/cli/api.py @@ -1232,10 +1232,28 @@ def parse_fn(self, fn: T, args: List[str], **default_kwargs) -> Partial[T]: Returns: Partial[T]: A Partial object representing the parsed function and arguments. """ - output = LazyEntrypoint(fn, factory=self.factory, yaml=self.yaml, overwrites=args) - out = output.resolve() + lazy = LazyEntrypoint( + fn, + factory=self.factory, + yaml=self.yaml, + overwrites=args, + ) - return out + # Resolve exactly once and always pass the current RunContext + # NOTE: `LazyEntrypoint.resolve` calls `parse_factory` if + # `lazy._factory_` is a string. `parse_cli_args` that follows inside + # `resolve` used to see the **same** string and call `parse_factory` + # a second time. We temporarily clear `_factory_` right after the + # first resolution so that it cannot be triggered again. + + _orig_factory = lazy._factory_ + try: + result = lazy.resolve(ctx=self) + finally: + # Restore for potential further use + lazy._factory_ = _orig_factory + + return result def _parse_partial(self, fn: Callable, args: List[str], **default_args) -> Partial[T]: """ diff --git a/nemo_run/cli/lazy.py b/nemo_run/cli/lazy.py index 8bf00c38..ab303a91 100644 --- a/nemo_run/cli/lazy.py +++ b/nemo_run/cli/lazy.py @@ -7,9 +7,10 @@ import shlex import sys from dataclasses import dataclass, field +import inspect from pathlib import Path from types import ModuleType -from typing import Any, Callable, Iterator +from typing import Any, Callable, Iterator, Optional, TYPE_CHECKING from fiddle import Buildable, daglish from fiddle._src import signatures @@ -19,6 +20,9 @@ from nemo_run.config import Partial +if TYPE_CHECKING: + from nemo_run.cli.cli_parser import RunContext + @contextlib.contextmanager def lazy_imports(fallback_to_lazy: bool = False) -> Iterator[None]: @@ -142,7 +146,7 @@ def __init__( if remaining_overwrites: self._add_overwrite(*remaining_overwrites) - def resolve(self) -> Partial: + def resolve(self, ctx: Optional["RunContext"] = None) -> Partial: from nemo_run.cli.cli_parser import parse_cli_args, parse_factory fn = self._target_ @@ -160,12 +164,26 @@ def resolve(self) -> Partial: if isinstance(fn, LazyTarget): fn = fn.target + _fn = fn + if hasattr(fn, "__fn_or_cls__"): + _fn = fn.__fn_or_cls__ + + sig = inspect.signature(_fn) + param_names = sig.parameters.keys() + dotlist = dictconfig_to_dot_list( _args_to_dictconfig(self._args_), has_factory=self._factory_ is not None ) _args = [f"{name}{op}{value}" for name, op, value in dotlist] - return parse_cli_args(fn, _args) + out = parse_cli_args(fn, _args) + + if "ctx" in param_names: + if not ctx: + raise ValueError("ctx is required for this function") + out.ctx = ctx + + return out def __getattr__(self, item: str) -> "LazyEntrypoint": """