Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions nemo_run/run/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def __init__(
self._title = title
self._id = id or f"{title}_{int(time.time())}"

base_dir = base_dir or get_nemorun_home()
base_dir = str(base_dir or get_nemorun_home())
self._exp_dir = os.path.join(base_dir, "experiments", title, self._id)

self.log_level = log_level
Expand Down Expand Up @@ -963,7 +963,7 @@ def reset(self) -> "Experiment":
self.console.log(
f"[bold magenta]Experiment {self._id} has not run yet, skipping reset..."
)
return
return self

old_id, old_exp_dir, old_launched = self._id, self._exp_dir, self._launched
self._id = f"{self._title}_{int(time.time())}"
Expand Down Expand Up @@ -1233,18 +1233,19 @@ def maybe_load_external_main(exp_dir: str):
_LOADED_MAINS.add(main_file)

spec = importlib.util.spec_from_file_location("__external_main__", main_file)
new_main_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(new_main_module)
if spec is not None and spec.loader is not None:
new_main_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(new_main_module)

if "__external_main__" not in sys.modules:
sys.modules["__external_main__"] = new_main_module
else:
external = sys.modules["__external_main__"]
if "__external_main__" not in sys.modules:
sys.modules["__external_main__"] = new_main_module
else:
external = sys.modules["__external_main__"]
for attr in dir(new_main_module):
if not attr.startswith("__"):
setattr(external, attr, getattr(new_main_module, attr))

existing_main = sys.modules["__main__"]
for attr in dir(new_main_module):
if not attr.startswith("__"):
setattr(external, attr, getattr(new_main_module, attr))

existing_main = sys.modules["__main__"]
for attr in dir(new_main_module):
if not attr.startswith("__"):
setattr(existing_main, attr, getattr(new_main_module, attr))
setattr(existing_main, attr, getattr(new_main_module, attr))
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ dev = [
"pytest-mock>=3.14.0",
"ipykernel>=6.29.4",
"ipywidgets>=8.1.2",
"jupyter>=1.1.1"
"jupyter>=1.1.1",
"pytest-cov"
]

lint = [
Expand Down
162 changes: 162 additions & 0 deletions test/cli/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,3 +775,165 @@ def test_verbose_logging(self, runner, app):
mock_configure.reset_mock()
runner.invoke(app, ["error-command"])
mock_configure.assert_called_once_with(False)


class TestTorchrunAndConfirmation:
"""Test torchrun detection and confirmation behavior."""

@patch("os.environ", {"WORLD_SIZE": "2"})
def test_is_torchrun_true(self):
"""Test that _is_torchrun returns True when WORLD_SIZE > 1."""
from nemo_run.cli.api import _is_torchrun

assert _is_torchrun() is True

@patch("os.environ", {})
def test_is_torchrun_false_no_env(self):
"""Test that _is_torchrun returns False when WORLD_SIZE not in environment."""
from nemo_run.cli.api import _is_torchrun

assert _is_torchrun() is False

@patch("os.environ", {"WORLD_SIZE": "1"})
def test_is_torchrun_false_size_one(self):
"""Test that _is_torchrun returns False when WORLD_SIZE = 1."""
from nemo_run.cli.api import _is_torchrun

assert _is_torchrun() is False

@patch("nemo_run.cli.api._is_torchrun", return_value=True)
def test_should_continue_torchrun(self, mock_torchrun):
"""Test that _should_continue returns True under torchrun."""
ctx = run.cli.RunContext(name="test")
assert ctx._should_continue(False) is True
mock_torchrun.assert_called_once()

@patch("nemo_run.cli.api._is_torchrun", return_value=False)
@patch("nemo_run.cli.api.NEMORUN_SKIP_CONFIRMATION", True)
def test_should_continue_global_flag_true(self, mock_torchrun):
"""Test that _should_continue respects global NEMORUN_SKIP_CONFIRMATION flag."""
ctx = run.cli.RunContext(name="test")
assert ctx._should_continue(False) is True
mock_torchrun.assert_called_once()

@patch("nemo_run.cli.api._is_torchrun", return_value=False)
@patch("nemo_run.cli.api.NEMORUN_SKIP_CONFIRMATION", False)
def test_should_continue_global_flag_false(self, mock_torchrun):
"""Test that _should_continue respects global NEMORUN_SKIP_CONFIRMATION flag."""
ctx = run.cli.RunContext(name="test")
assert ctx._should_continue(False) is False
mock_torchrun.assert_called_once()

@patch("nemo_run.cli.api._is_torchrun", return_value=False)
@patch("nemo_run.cli.api.NEMORUN_SKIP_CONFIRMATION", None)
def test_should_continue_skip_confirmation(self, mock_torchrun):
"""Test that _should_continue respects skip_confirmation parameter."""
ctx = run.cli.RunContext(name="test")
assert ctx._should_continue(True) is True
mock_torchrun.assert_called_once()


class TestRunContextLaunch:
"""Test RunContext.launch method."""

def test_launch_with_dryrun(self):
"""Test launch with dryrun."""
ctx = run.cli.RunContext(name="test_run", dryrun=True)
mock_experiment = Mock(spec=run.Experiment)

ctx.launch(mock_experiment)

mock_experiment.dryrun.assert_called_once()
mock_experiment.run.assert_not_called()

def test_launch_normal(self):
"""Test launch without dryrun."""
ctx = run.cli.RunContext(name="test_run", direct=True, tail_logs=True)
mock_experiment = Mock(spec=run.Experiment)

ctx.launch(mock_experiment)

mock_experiment.run.assert_called_once_with(
sequential=False, detach=False, direct=True, tail_logs=True
)

def test_launch_with_executor(self):
"""Test launch with executor specified."""
ctx = run.cli.RunContext(name="test_run")
ctx.executor = Mock(spec=run.LocalExecutor)
mock_experiment = Mock(spec=run.Experiment)

ctx.launch(mock_experiment)

mock_experiment.run.assert_called_once_with(
sequential=False, detach=False, direct=False, tail_logs=False
)

def test_launch_sequential(self):
"""Test launch with sequential=True."""
ctx = run.cli.RunContext(name="test_run")
# Initialize executor to None explicitly
ctx.executor = None
mock_experiment = Mock(spec=run.Experiment)

ctx.launch(mock_experiment, sequential=True)

mock_experiment.run.assert_called_once_with(
sequential=True, detach=False, direct=True, tail_logs=False
)


class TestParsePrefixedArgs:
"""Test _parse_prefixed_args function."""

def test_parse_prefixed_args_simple(self):
"""Test parsing simple prefixed arguments."""
from nemo_run.cli.api import _parse_prefixed_args

args = ["executor=local", "other=value"]
prefix_value, prefix_args, other_args = _parse_prefixed_args(args, "executor")

assert prefix_value == "local"
assert prefix_args == []
assert other_args == ["other=value"]

def test_parse_prefixed_args_with_dot_notation(self):
"""Test parsing prefixed arguments with dot notation."""
from nemo_run.cli.api import _parse_prefixed_args

args = ["executor=local", "executor.gpu=2", "other=value"]
prefix_value, prefix_args, other_args = _parse_prefixed_args(args, "executor")

assert prefix_value == "local"
assert prefix_args == ["gpu=2"]
assert other_args == ["other=value"]

def test_parse_prefixed_args_with_brackets(self):
"""Test parsing prefixed arguments with bracket notation."""
from nemo_run.cli.api import _parse_prefixed_args

args = ["plugins=list", "plugins[0].name=test", "other=value"]
prefix_value, prefix_args, other_args = _parse_prefixed_args(args, "plugins")

assert prefix_value == "list"
assert prefix_args == ["[0].name=test"]
assert other_args == ["other=value"]

def test_parse_prefixed_args_invalid_format(self):
"""Test parsing prefixed arguments with invalid format."""
from nemo_run.cli.api import _parse_prefixed_args

args = ["executorblah", "other=value"]
with pytest.raises(ValueError, match="Executor overwrites must start with 'executor.'"):
_parse_prefixed_args(args, "executor")

def test_parse_prefixed_args_no_prefix(self):
"""Test parsing when no prefixed arguments are present."""
from nemo_run.cli.api import _parse_prefixed_args

args = ["arg1=value1", "arg2=value2"]
prefix_value, prefix_args, other_args = _parse_prefixed_args(args, "executor")

assert prefix_value is None
assert prefix_args == []
assert other_args == ["arg1=value1", "arg2=value2"]
93 changes: 92 additions & 1 deletion test/cli/test_cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,30 @@

import sys
from pathlib import Path
from test.dummy_factory import DummyModel
from typing import Any, Dict, List, Literal, Optional, Type, Union

import pytest

from nemo_run.cli.cli_parser import (
ArgumentParsingError,
ArgumentValueError,
CLIException,
CollectionParseError,
DictParseError,
ListParseError,
LiteralParseError,
OperationError,
ParseError,
PythonicParser,
TypeParser,
TypeParsingError,
UndefinedVariableError,
UnknownTypeError,
parse_cli_args,
parse_value,
)
from nemo_run.config import Config, Partial
from test.dummy_factory import DummyModel


class TestSimpleValueParsing:
Expand Down Expand Up @@ -664,3 +667,91 @@ def func(a: List[Dict[str, Union[int, List[str]]]]):

result = parse_cli_args(func, ["a=[{'x': 1, 'y': ['a', 'b']}, {'z': 2}]"])
assert result.a == [{"x": 1, "y": ["a", "b"]}, {"z": 2}]


class TestCLIException:
"""Test the CLIException class hierarchy."""

def test_cli_exception_base(self):
"""Test the base CLIException class."""
ex = CLIException("Test message", "test_arg", {"key": "value"})
assert "Test message" in str(ex)
assert "test_arg" in str(ex)
assert "{'key': 'value'}" in str(ex)
assert ex.arg == "test_arg"
assert ex.context == {"key": "value"}

def test_user_friendly_message(self):
"""Test the user_friendly_message method."""
ex = CLIException("Test message", "test_arg", {"key": "value"})
friendly = ex.user_friendly_message()
assert "Error processing argument 'test_arg'" in friendly
assert "Test message" in friendly

def test_argument_parsing_error(self):
"""Test ArgumentParsingError."""
ex = ArgumentParsingError("Invalid syntax", "bad=arg", {"line": 10})
assert isinstance(ex, CLIException)
assert "Invalid syntax" in str(ex)

def test_type_parsing_error(self):
"""Test TypeParsingError."""
ex = TypeParsingError("Type mismatch", "arg=value", {"expected": "int"})
assert isinstance(ex, CLIException)
assert "Type mismatch" in str(ex)

def test_operation_error(self):
"""Test OperationError."""
ex = OperationError("Invalid operation", "arg+=value", {"op": "+="})
assert isinstance(ex, CLIException)
assert "Invalid operation" in str(ex)

def test_argument_value_error(self):
"""Test ArgumentValueError."""
ex = ArgumentValueError("Invalid value", "arg=value", {"expected": "option"})
assert isinstance(ex, CLIException)
assert "Invalid value" in str(ex)

def test_undefined_variable_error(self):
"""Test UndefinedVariableError."""
ex = UndefinedVariableError("Variable not defined", "undefined+=1", {})
assert isinstance(ex, CLIException)
assert "Variable not defined" in str(ex)

def test_parse_error(self):
"""Test ParseError."""
ex = ParseError("abc", int, "Cannot convert string to int")
assert isinstance(ex, CLIException)
assert "Failed to parse 'abc' as <class 'int'>" in str(ex)
assert ex.value == "abc"
assert ex.reason == "Cannot convert string to int"

def test_literal_parse_error(self):
"""Test LiteralParseError."""
ex = LiteralParseError("red", Literal, "Expected one of ['blue', 'green']")
assert isinstance(ex, ParseError)
assert "Failed to parse 'red'" in str(ex)

def test_collection_parse_error(self):
"""Test CollectionParseError."""
ex = CollectionParseError("[1,2,", list, "Invalid syntax")
assert isinstance(ex, ParseError)
assert "Failed to parse '[1,2,'" in str(ex)

def test_list_parse_error(self):
"""Test ListParseError."""
ex = ListParseError("[1,2,", list, "Invalid syntax")
assert isinstance(ex, CollectionParseError)
assert "Failed to parse '[1,2,'" in str(ex)

def test_dict_parse_error(self):
"""Test DictParseError."""
ex = DictParseError("{1:2,", dict, "Invalid syntax")
assert isinstance(ex, CollectionParseError)
assert "Failed to parse '{1:2,'" in str(ex)

def test_unknown_type_error(self):
"""Test UnknownTypeError."""
ex = UnknownTypeError("value", str, "Unknown type")
assert isinstance(ex, ParseError)
assert "Failed to parse 'value'" in str(ex)
Loading
Loading