Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion nemo_run/core/execution/lepton.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class LeptonExecutor(Executor):
mounts: list[dict[str, Any]] = field(default_factory=list)
lepton_job_dir: str = field(init=False, default="")
custom_spec: dict[str, Any] = field(default_factory=dict)
pre_launch_commands: list[str] = field(default_factory=list) # Custom commands before launch

def stop_job(self, job_id: str):
"""
Expand Down Expand Up @@ -244,8 +245,14 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
if len(name) > 35:
logger.warning("length of name exceeds 35 characters. Shortening...")
name = name[:34]

# Build pre-launch commands section
pre_launch_section = ""
if self.pre_launch_commands:
pre_launch_section = "\n".join(self.pre_launch_commands) + "\n"

launch_script = f"""
wget -O init.sh https://raw.githubusercontent.com/leptonai/scripts/main/lepton_env_to_pytorch.sh
{pre_launch_section}wget -O init.sh https://raw.githubusercontent.com/leptonai/scripts/main/lepton_env_to_pytorch.sh
chmod +x init.sh
source init.sh
ln -s {self.lepton_job_dir}/ /nemo_run
Expand Down
214 changes: 214 additions & 0 deletions test/core/execution/test_lepton.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,3 +641,217 @@ def test_macro_values(self):
result = executor.macro_values()

assert result is None

def test_pre_launch_commands_initialization(self):
    """Check the default and an explicit value of ``pre_launch_commands``."""
    shared_kwargs = {"container_image": "test-image", "nemo_run_dir": "/test/path"}

    # When the argument is omitted the field compares equal to an empty list.
    default_executor = LeptonExecutor(**shared_kwargs)
    assert default_executor.pre_launch_commands == []

    # When a list is supplied the field compares equal to that list.
    expected = ["echo 'Setting up environment'", "export TEST_VAR=value"]
    configured_executor = LeptonExecutor(pre_launch_commands=expected, **shared_kwargs)
    assert configured_executor.pre_launch_commands == expected

def test_launch_script_with_pre_launch_commands(self):
    """Check the script text derived from ``pre_launch_commands``.

    The section is rebuilt inside this test (newline-join plus a trailing
    newline); the executor's ``launch`` method is not called here.
    """

    def build_section(executor):
        # Empty string when there are no commands; otherwise one command
        # per line with a terminating newline.
        if executor.pre_launch_commands:
            return "\n".join(executor.pre_launch_commands) + "\n"
        return ""

    # No commands configured -> nothing is emitted.
    bare_executor = LeptonExecutor(
        container_image="test-image",
        nemo_run_dir="/test/path",
    )
    assert build_section(bare_executor) == ""

    # Two commands configured -> both appear, each newline-terminated.
    configured_executor = LeptonExecutor(
        container_image="test-image",
        nemo_run_dir="/test/path",
        pre_launch_commands=["echo 'Custom setup'", "export MY_VAR=test"],
    )
    rendered = build_section(configured_executor)
    assert rendered == "echo 'Custom setup'\nexport MY_VAR=test\n"

@patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.status")
@patch("builtins.open", new_callable=mock_open)
@patch("os.path.join")
@patch("nemo_run.core.execution.lepton.logger")
def test_launch_method_comprehensive(
    self,
    mock_logger,
    mock_join,
    mock_file,
    mock_status,
    mock_create_job,
    mock_move_data,
    mock_validate_mounts,
):
    """Run ``launch`` end to end against mocks.

    Covers the returned job id, the presence of the pre-launch command and
    the joined user command in the written script, and the warning logged
    for an over-long job name.
    """
    # Stub out job creation and status so launch() reaches its success path.
    fake_job = MagicMock()
    fake_job.metadata.id_ = "job-id"
    mock_create_job.return_value = fake_job
    mock_status.return_value = LeptonJobState.Running
    mock_join.return_value = "/fake/script.sh"

    executor = LeptonExecutor(
        container_image="test-image", nemo_run_dir="/test", pre_launch_commands=["echo setup"]
    )
    executor.job_dir = "/fake"
    executor.lepton_job_dir = "/fake"

    # First launch: the id reported by the mocked job is returned.
    job_id, _status = executor.launch("Test_Job.Name", ["python", "script.py"])
    assert job_id == "job-id"

    # The script is captured from the most recent write() on the mocked file handle.
    file_handle = mock_file.return_value.__enter__.return_value
    script_text = file_handle.write.call_args.args[0]
    assert "echo setup\n" in script_text
    assert "python script.py" in script_text

    # Second launch: a 50-character name must trigger the shortening warning.
    executor.launch("a" * 50, ["cmd"])
    mock_logger.warning.assert_called_with(
        "length of name exceeds 35 characters. Shortening..."
    )

@patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.status")
@patch("builtins.open", new_callable=mock_open)
@patch("os.path.join")
@patch("nemo_run.core.execution.lepton.logger")
def test_launch_error_paths(
    self,
    mock_logger,
    mock_join,
    mock_file,
    mock_status,
    mock_create_job,
    mock_move_data,
    mock_validate_mounts,
):
    """Walk ``launch`` through three failure modes, then a successful run."""
    mock_join.return_value = "/fake/launch_script.sh"
    executor = LeptonExecutor(container_image="test-image", nemo_run_dir="/test/path")
    executor.job_dir = "/fake/dir"
    executor.lepton_job_dir = "/fake/dir"

    def expect_launch_failure(message):
        # launch() must raise RuntimeError whose text matches `message`.
        with pytest.raises(RuntimeError, match=message):
            executor.launch("test", ["cmd"])

    # 1. Job creation yields nothing.
    mock_create_job.return_value = None
    expect_launch_failure("Failed to create Lepton job")
    mock_logger.info.assert_any_call("Creating distributed workload")

    # 2. A job object comes back, but without an id.
    fake_job = MagicMock()
    fake_job.metadata.id_ = None
    mock_create_job.return_value = fake_job
    expect_launch_failure("Failed to retrieve job information")

    # 3. The job has an id, but no status can be read.
    fake_job.metadata.id_ = "job-id"
    mock_status.return_value = None
    expect_launch_failure("Failed to retrieve job status")

    # 4. Everything is available: the id is returned and the copy step is logged.
    mock_status.return_value = LeptonJobState.Running
    job_id, _status = executor.launch("test", ["cmd"])
    assert job_id == "job-id"
    mock_logger.info.assert_any_call("Copying experiment directory to remote filesystem")

@patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.status")
@patch("builtins.open", new_callable=mock_open)
@patch("os.path.join")
@patch("nemo_run.core.execution.lepton.logger")
def test_launch_long_name_truncation(
    self,
    mock_logger,
    mock_join,
    mock_file,
    mock_status,
    mock_create_job,
    mock_move_data,
    mock_validate_mounts,
):
    """A job name longer than 35 characters makes ``launch`` log a warning."""
    # Mocked collaborators configured for a successful launch.
    fake_job = MagicMock()
    fake_job.metadata.id_ = "job-id"
    mock_create_job.return_value = fake_job
    mock_status.return_value = LeptonJobState.Running
    mock_join.return_value = "/fake/launch_script.sh"

    executor = LeptonExecutor(container_image="test-image", nemo_run_dir="/test/path")
    executor.job_dir = "/fake/dir"
    executor.lepton_job_dir = "/fake/dir"

    # 50 characters is past the 35-character limit named in the warning.
    executor.launch("a" * 50, ["cmd"])

    mock_logger.warning.assert_called_with(
        "length of name exceeds 35 characters. Shortening..."
    )

@patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job")
@patch("nemo_run.core.execution.lepton.LeptonExecutor.status")
@patch("builtins.open", new_callable=mock_open)
@patch("os.path.join")
def test_launch_prelaunch_commands_join(
    self,
    mock_join,
    mock_file,
    mock_status,
    mock_create_job,
    mock_move_data,
    mock_validate_mounts,
):
    """Multiple pre-launch commands appear newline-joined in the written script."""
    # Mocked collaborators configured for a successful launch.
    fake_job = MagicMock()
    fake_job.metadata.id_ = "job-id"
    mock_create_job.return_value = fake_job
    mock_status.return_value = LeptonJobState.Running
    mock_join.return_value = "/fake/launch_script.sh"

    executor = LeptonExecutor(
        container_image="test-image",
        nemo_run_dir="/test/path",
        pre_launch_commands=["echo setup", "export VAR=1"],
    )
    executor.job_dir = "/fake/dir"
    executor.lepton_job_dir = "/fake/dir"

    executor.launch("test", ["cmd"])

    # Inspect the most recent write() on the mocked file handle.
    file_handle = mock_file.return_value.__enter__.return_value
    script_text = file_handle.write.call_args.args[0]
    assert "echo setup\nexport VAR=1\n" in script_text
Loading