diff --git a/nemo_run/core/execution/lepton.py b/nemo_run/core/execution/lepton.py index f3cd2c92..e3aec777 100644 --- a/nemo_run/core/execution/lepton.py +++ b/nemo_run/core/execution/lepton.py @@ -54,6 +54,7 @@ class LeptonExecutor(Executor): mounts: list[dict[str, Any]] = field(default_factory=list) lepton_job_dir: str = field(init=False, default="") custom_spec: dict[str, Any] = field(default_factory=dict) + pre_launch_commands: list[str] = field(default_factory=list) # Custom commands before launch def stop_job(self, job_id: str): """ @@ -244,8 +245,14 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]: if len(name) > 35: logger.warning("length of name exceeds 35 characters. Shortening...") name = name[:34] + + # Build pre-launch commands section + pre_launch_section = "" + if self.pre_launch_commands: + pre_launch_section = "\n".join(self.pre_launch_commands) + "\n" + launch_script = f""" -wget -O init.sh https://raw.githubusercontent.com/leptonai/scripts/main/lepton_env_to_pytorch.sh +{pre_launch_section}wget -O init.sh https://raw.githubusercontent.com/leptonai/scripts/main/lepton_env_to_pytorch.sh chmod +x init.sh source init.sh ln -s {self.lepton_job_dir}/ /nemo_run diff --git a/test/core/execution/test_lepton.py b/test/core/execution/test_lepton.py index 821c8d91..0ce503f0 100644 --- a/test/core/execution/test_lepton.py +++ b/test/core/execution/test_lepton.py @@ -641,3 +641,217 @@ def test_macro_values(self): result = executor.macro_values() assert result is None + + def test_pre_launch_commands_initialization(self): + """Test that pre_launch_commands can be initialized and defaults to empty list.""" + # Test default initialization + executor = LeptonExecutor( + container_image="test-image", + nemo_run_dir="/test/path", + ) + assert executor.pre_launch_commands == [] + + # Test initialization with commands + commands = ["echo 'Setting up environment'", "export TEST_VAR=value"] + executor_with_commands = LeptonExecutor( + container_image="test-image", + nemo_run_dir="/test/path", + pre_launch_commands=commands, + ) + assert executor_with_commands.pre_launch_commands == commands + + def test_launch_script_with_pre_launch_commands(self): + """Test that pre_launch_commands are correctly included in the launch script.""" + + # Test without pre_launch_commands + executor = LeptonExecutor( + container_image="test-image", + nemo_run_dir="/test/path", + ) + + # Test script section generation - empty case + pre_launch_section = "" + if executor.pre_launch_commands: + pre_launch_section = "\n".join(executor.pre_launch_commands) + "\n" + assert pre_launch_section == "" + + # Test with pre_launch_commands + commands = ["echo 'Custom setup'", "export MY_VAR=test"] + executor_with_commands = LeptonExecutor( + container_image="test-image", + nemo_run_dir="/test/path", + pre_launch_commands=commands, + ) + + # Test script section generation - with commands + pre_launch_section_with_commands = "" + if executor_with_commands.pre_launch_commands: + pre_launch_section_with_commands = ( + "\n".join(executor_with_commands.pre_launch_commands) + "\n" + ) + + expected_pre_launch = "echo 'Custom setup'\nexport MY_VAR=test\n" + assert pre_launch_section_with_commands == expected_pre_launch + + @patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.status") + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.join") + @patch("nemo_run.core.execution.lepton.logger") + def test_launch_method_comprehensive( + self, + mock_logger, + mock_join, + mock_file, + mock_status, + mock_create_job, + mock_move_data, + mock_validate_mounts, + ): + """Test launch method name validation, pre_launch_commands, and script generation.""" + # Setup + executor = LeptonExecutor( + container_image="test-image", nemo_run_dir="/test", pre_launch_commands=["echo setup"] + ) + executor.job_dir = executor.lepton_job_dir = "/fake" + mock_join.return_value = "/fake/script.sh" + mock_job = MagicMock() + mock_job.metadata.id_ = "job-id" + mock_create_job.return_value = mock_job + mock_status.return_value = LeptonJobState.Running + + # Test name transformation and pre_launch_commands + job_id, status = executor.launch("Test_Job.Name", ["python", "script.py"]) + assert job_id == "job-id" + + # Verify script content includes pre_launch_commands + handle = mock_file.return_value.__enter__.return_value + written_content = handle.write.call_args[0][0] + assert "echo setup\n" in written_content + assert "python script.py" in written_content + + # Test long name truncation + long_name = "a" * 50 + executor.launch(long_name, ["cmd"]) + mock_logger.warning.assert_called_with( + "length of name exceeds 35 characters. Shortening..." + ) + + @patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.status") + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.join") + @patch("nemo_run.core.execution.lepton.logger") + def test_launch_error_paths( + self, + mock_logger, + mock_join, + mock_file, + mock_status, + mock_create_job, + mock_move_data, + mock_validate_mounts, + ): + """Test launch method error handling and logging.""" + executor = LeptonExecutor(container_image="test-image", nemo_run_dir="/test/path") + executor.job_dir = executor.lepton_job_dir = "/fake/dir" + mock_join.return_value = "/fake/launch_script.sh" + + # Test job creation failure + mock_create_job.return_value = None + with pytest.raises(RuntimeError, match="Failed to create Lepton job"): + executor.launch("test", ["cmd"]) + mock_logger.info.assert_any_call("Creating distributed workload") + + # Test missing job ID + mock_job = MagicMock() + mock_job.metadata.id_ = None + mock_create_job.return_value = mock_job + with pytest.raises(RuntimeError, match="Failed to retrieve job information"): + executor.launch("test", ["cmd"]) + + # Test status failure + mock_job.metadata.id_ = "job-id" + mock_status.return_value = None + with pytest.raises(RuntimeError, match="Failed to retrieve job status"): + executor.launch("test", ["cmd"]) + + # Test success path with logging + mock_status.return_value = LeptonJobState.Running + job_id, status = executor.launch("test", ["cmd"]) + assert job_id == "job-id" + mock_logger.info.assert_any_call("Copying experiment directory to remote filesystem") + + @patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.status") + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.join") + @patch("nemo_run.core.execution.lepton.logger") + def test_launch_long_name_truncation( + self, + mock_logger, + mock_join, + mock_file, + mock_status, + mock_create_job, + mock_move_data, + mock_validate_mounts, + ): + """Test name truncation warning and logic (lines 246-247).""" + executor = LeptonExecutor(container_image="test-image", nemo_run_dir="/test/path") + executor.job_dir = executor.lepton_job_dir = "/fake/dir" + mock_join.return_value = "/fake/launch_script.sh" + + mock_job = MagicMock() + mock_job.metadata.id_ = "job-id" + mock_create_job.return_value = mock_job + mock_status.return_value = LeptonJobState.Running + + # Test long name triggers warning and truncation + long_name = "a" * 50 # 50 characters, exceeds 35 + executor.launch(long_name, ["cmd"]) + mock_logger.warning.assert_called_with( + "length of name exceeds 35 characters. Shortening..." + ) + + @patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job") + @patch("nemo_run.core.execution.lepton.LeptonExecutor.status") + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.join") + def test_launch_prelaunch_commands_join( + self, + mock_join, + mock_file, + mock_status, + mock_create_job, + mock_move_data, + mock_validate_mounts, + ): + """Test pre_launch_commands joining logic (line 252).""" + executor = LeptonExecutor( + container_image="test-image", + nemo_run_dir="/test/path", + pre_launch_commands=["echo setup", "export VAR=1"], + ) + executor.job_dir = executor.lepton_job_dir = "/fake/dir" + mock_join.return_value = "/fake/launch_script.sh" + + mock_job = MagicMock() + mock_job.metadata.id_ = "job-id" + mock_create_job.return_value = mock_job + mock_status.return_value = LeptonJobState.Running + + executor.launch("test", ["cmd"]) + + # Verify script contains joined pre_launch_commands + handle = mock_file.return_value.__enter__.return_value + written_content = handle.write.call_args[0][0] + assert "echo setup\nexport VAR=1\n" in written_content