Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/entrypoint/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,4 @@ def train_models_experiment(


if __name__ == "__main__":
run.cli.main(train_models_experiment)
run.cli.main(train_models_experiment, default_executor=local_executor())
24 changes: 21 additions & 3 deletions nemo_run/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,10 +1232,28 @@ def parse_fn(self, fn: T, args: List[str], **default_kwargs) -> Partial[T]:
Returns:
Partial[T]: A Partial object representing the parsed function and arguments.
"""
output = LazyEntrypoint(fn, factory=self.factory, yaml=self.yaml, overwrites=args)
out = output.resolve()
lazy = LazyEntrypoint(
fn,
factory=self.factory,
yaml=self.yaml,
overwrites=args,
)

return out
# Resolve exactly once and always pass the current RunContext
# NOTE: `LazyEntrypoint.resolve` calls `parse_factory` if
# `lazy._factory_` is a string. `parse_cli_args` that follows inside
# `resolve` used to see the **same** string and call `parse_factory`
# a second time. We temporarily clear `_factory_` right after the
# first resolution so that it cannot be triggered again.

_orig_factory = lazy._factory_
try:
result = lazy.resolve(ctx=self)
finally:
# Restore for potential further use
lazy._factory_ = _orig_factory

return result

def _parse_partial(self, fn: Callable, args: List[str], **default_args) -> Partial[T]:
"""
Expand Down
24 changes: 21 additions & 3 deletions nemo_run/cli/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import shlex
import sys
from dataclasses import dataclass, field
import inspect
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Iterator
from typing import Any, Callable, Iterator, Optional, TYPE_CHECKING

from fiddle import Buildable, daglish
from fiddle._src import signatures
Expand All @@ -19,6 +20,9 @@

from nemo_run.config import Partial

if TYPE_CHECKING:
from nemo_run.cli.cli_parser import RunContext


@contextlib.contextmanager
def lazy_imports(fallback_to_lazy: bool = False) -> Iterator[None]:
Expand Down Expand Up @@ -142,7 +146,7 @@ def __init__(
if remaining_overwrites:
self._add_overwrite(*remaining_overwrites)

def resolve(self) -> Partial:
def resolve(self, ctx: Optional["RunContext"] = None) -> Partial:
from nemo_run.cli.cli_parser import parse_cli_args, parse_factory

fn = self._target_
Expand All @@ -160,12 +164,26 @@ def resolve(self) -> Partial:
if isinstance(fn, LazyTarget):
fn = fn.target

_fn = fn
if hasattr(fn, "__fn_or_cls__"):
_fn = fn.__fn_or_cls__

sig = inspect.signature(_fn)
param_names = sig.parameters.keys()

dotlist = dictconfig_to_dot_list(
_args_to_dictconfig(self._args_), has_factory=self._factory_ is not None
)
_args = [f"{name}{op}{value}" for name, op, value in dotlist]

return parse_cli_args(fn, _args)
out = parse_cli_args(fn, _args)

if "ctx" in param_names:
if not ctx:
raise ValueError("ctx is required for this function")
out.ctx = ctx

return out

def __getattr__(self, item: str) -> "LazyEntrypoint":
"""
Expand Down
Loading