Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions nemo_run/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,8 @@ def list_factories(type_or_namespace: Type | str) -> list[Callable]:


def create_cli(
add_verbose_callback: bool = False, nested_entrypoints_creation: bool = True
add_verbose_callback: bool = False,
nested_entrypoints_creation: bool = True,
) -> Typer:
app: Typer = Typer(pretty_exceptions_enable=False)
entrypoints = metadata.entry_points().select(group="nemo_run.cli")
Expand Down Expand Up @@ -960,14 +961,14 @@ def command(
if default_plugins:
self.plugins = default_plugins

_load_workspace()
if isinstance(fn, LazyEntrypoint):
self.execute_lazy(fn, sys.argv, name)
return

try:
if not is_main:
_load_entrypoints()
_load_workspace()
self.cli_execute(fn, ctx.args, type)
except RunContextError as e:
if not verbose:
Expand Down
61 changes: 54 additions & 7 deletions nemo_run/cli/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,11 @@
self.script = script_path.read_text()
if len(cmd) > 1:
self.import_path = " ".join(cmd[1:])
if cmd[0] in ("nemo", "nemo_run"):
if (
cmd[0] in ("nemo", "nemo_run")
or cmd[0].endswith("/nemo")
or cmd[0].endswith("/nemo_run")
):
self.import_path = " ".join(cmd[1:])

def __call__(self, *args, **kwargs):
Expand All @@ -442,18 +446,61 @@
else:
parts = self.import_path.split(" ")
if parts[0] not in entrypoints:
available_cmds = ", ".join(sorted(entrypoints.keys()))
raise ValueError(
f"Entrypoint {parts[0]} not found. Available entrypoints: {list(entrypoints.keys())}"
f"Entrypoint '{parts[0]}' not found. Available top-level entrypoints: {available_cmds}"
)
output = entrypoints[parts[0]]

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
nemo_run.cli.api
begins an import cycle.

# Re-key the nested entrypoint dict to include 'name' attribute as keys
def rekey_entrypoints(entries):
if not isinstance(entries, dict):
return entries

result = {}
for key, value in entries.items():
result[key] = value
if hasattr(value, "name") and value.name != key:
result[value.name] = value
elif isinstance(value, dict):
result[key] = rekey_entrypoints(value)
return result

# Only rekey if we're dealing with a dictionary
if isinstance(output, dict):
output = rekey_entrypoints(output)

if len(parts) > 1:
for part in parts[1:]:
if part in output:
output = output[part]
# Skip args with - or -- prefix or containing = as they're parameters, not subcommands
if part.startswith("-") or "=" in part:
continue

if isinstance(output, dict):
if part in output:
output = output[part]
else:
# Collect available commands for error message
available_cmds = sorted(output.keys())
raise ValueError(
f"Subcommand '{part}' not found for entrypoint '{parts[0]}'. "
f"Available subcommands: {', '.join(available_cmds)}"
)
else:
# We've reached an entrypoint object but tried to access a subcommand
entrypoint_name = getattr(output, "name", parts[0])
raise ValueError(
f"Entrypoint {self.import_path} not found. Available entrypoints: {list(entrypoints.keys())}"
f"'{entrypoint_name}' is a terminal entrypoint and does not have subcommand '{part}'. "
f"You may have provided an incorrect command structure."
)

# If output is a dict, we need to get the default entrypoint
if isinstance(output, dict):
raise ValueError(
f"Incomplete command: '{self.import_path}'. Please specify a subcommand. "
f"Available subcommands: {', '.join(sorted(output.keys()))}"
)

self._target_fn = output.fn

@property
Expand Down Expand Up @@ -813,8 +860,8 @@
Examples:
# Nested config (model.yaml):
model:
_target_: Model
hidden_size: 256
_target_: Model
hidden_size: 256

# Flat config (model.yaml):
_target_: Model
Expand Down
3 changes: 3 additions & 0 deletions nemo_run/run/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def run(
if plugins:
plugins = [plugins] if not isinstance(plugins, list) else plugins

if getattr(fn_or_script, "is_lazy", False):
fn_or_script = fn_or_script.resolve()

default_name = (
fn_or_script.get_name()
if isinstance(fn_or_script, Script)
Expand Down
Loading