diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py index 294cdf1a..5998b6ab 100644 --- a/nemo_run/core/execution/slurm.py +++ b/nemo_run/core/execution/slurm.py @@ -1116,9 +1116,14 @@ def on_interval(self): ) self.forward_port_context.__enter__() - self.ssh_config.add_entry( - metadata.user, "localhost", int(metadata.port), self.tunnel_name - ) + try: + self.ssh_config.add_entry( + metadata.user, "localhost", int(metadata.port), self.tunnel_name + ) + except Exception as e: + self.console.print(f"[bold red]Error adding SSH config entry: {e}") + raise e + self.ssh_entry_added = True with self.console.status("Setting up port forwarding", spinner="dots"): @@ -1141,4 +1146,4 @@ def on_stop(self): def tunnel_name(self) -> str: workspace_name = self.space.name - return ".".join([workspace_name, self.space.name]) + return workspace_name diff --git a/nemo_run/core/tunnel/client.py b/nemo_run/core/tunnel/client.py index 316dbc78..db85c14c 100644 --- a/nemo_run/core/tunnel/client.py +++ b/nemo_run/core/tunnel/client.py @@ -311,28 +311,76 @@ def __init__(self, config_path: Optional[str] = None): def _get_default_config_path(self) -> str: config_path = os.path.expanduser("~/.ssh/config") + return config_path + + def _get_host_config_path(self) -> Optional[str]: + """ + Get the path to the ssh config file for the host if we are running in WSL. + + Returns: + Optional[str]: Path to the host's SSH config file or None if not in WSL. + + Raises: + RuntimeError: If running in WSL but unable to determine the Windows path + due to missing utilities or other errors. + """ + config_path = None # If running in WSL environment, update host's ssh config file instead if os.name == "posix" and "WSL" in os.uname().release: - user_profile = subprocess.run( - ["wslvar", "USERPROFILE"], capture_output=True, text=True, check=False - ).stdout.strip("\n") - home_dir = subprocess.run( - ["wslpath", user_profile], capture_output=True, text=True, check=False - ).stdout.strip("\n") - config_path = (Path(home_dir) / ".ssh/config").as_posix() + # Check if wslvar and wslpath are available + wslvar_exists = shutil.which("wslvar") is not None + wslpath_exists = shutil.which("wslpath") is not None + + if wslvar_exists and wslpath_exists: + # Use WSL utilities to get the host ssh config path + user_profile = subprocess.run( + ["wslvar", "USERPROFILE"], capture_output=True, text=True, check=False + ).stdout.strip("\n") + + if not user_profile: + raise RuntimeError("Failed to get USERPROFILE from wslvar") + + home_dir = subprocess.run( + ["wslpath", user_profile], capture_output=True, text=True, check=False + ).stdout.strip("\n") + + if not home_dir: + raise RuntimeError("Failed to convert USERPROFILE path with wslpath") + + config_path = (Path(home_dir) / ".ssh/config").as_posix() + logger.debug(f"Using Windows SSH config at: {config_path}") + else: + # wslu package not installed, raise error + missing_cmds = [] + if not wslvar_exists: + missing_cmds.append("wslvar") + if not wslpath_exists: + missing_cmds.append("wslpath") + + raise RuntimeError( + f"WSL detected but required utilities ({', '.join(missing_cmds)}) not found. " + "These utilities are part of the wslu package. " + "Example of installation: sudo apt install wslu" + ) return config_path def add_entry(self, user: str, hostname: str, port: int, name: str): + host_config_path = self._get_host_config_path() + if host_config_path: + self._add_entry(user, hostname, port, name, host_config_path) + self._add_entry(user, hostname, port, name, self.config_path) + + def _add_entry(self, user: str, hostname: str, port: int, name: str, config_path: str): host = f"tunnel.{name}" new_config_entry = f"""Host {host} User {user} HostName {hostname} Port {port}""" - if os.path.exists(self.config_path): - with open(self.config_path, "r") as file: + if os.path.exists(config_path): + with open(config_path, "r") as file: lines = file.readlines() # Check if the host is already defined in the config @@ -351,16 +399,22 @@ def add_entry(self, user: str, hostname: str, port: int, name: str): else: # Add new entry lines.append(new_config_entry + "\n") - with open(self.config_path, "w") as file: + with open(config_path, "w") as file: file.writelines(lines) else: - with open(self.config_path, "w") as file: + with open(config_path, "w") as file: file.write(new_config_entry + "\n") def remove_entry(self, name: str): + host_config_path = self._get_host_config_path() + if host_config_path: + self._remove_entry(name, host_config_path) + self._remove_entry(name, self.config_path) + + def _remove_entry(self, name: str, config_path: str): host = f"tunnel.{name}" - if os.path.exists(self.config_path): - with open(self.config_path, "r") as file: + if os.path.exists(config_path): + with open(config_path, "r") as file: lines = file.readlines() start_index = None @@ -376,7 +430,7 @@ def remove_entry(self, name: str): del lines[start_index:end_index] - with open(self.config_path, "w") as file: + with open(config_path, "w") as file: file.writelines(lines) print(f"Removed SSH config entry for {host}.") diff --git a/nemo_run/core/tunnel/server.py b/nemo_run/core/tunnel/server.py index 51fced9b..349f7ce2 100644 --- a/nemo_run/core/tunnel/server.py +++ b/nemo_run/core/tunnel/server.py @@ -106,7 +106,8 @@ def restore(cls, path: Path, tunnel=None) -> "TunnelMetadata": tunnel_file = path / "metadata.json" if tunnel: - data = json.loads(tunnel.run(f"cat {tunnel_file}", hide="out").stdout.strip()) + tunnel_path = tunnel_file.as_posix() + data = json.loads(tunnel.run(f"cat {tunnel_path}", hide="out").stdout.strip()) else: with tunnel_file.open("r") as f: data = json.load(f) diff --git a/nemo_run/devspace/editor.py b/nemo_run/devspace/editor.py index 31ef6ea5..4ca4f201 100644 --- a/nemo_run/devspace/editor.py +++ b/nemo_run/devspace/editor.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import platform import shutil from pathlib import Path @@ -22,11 +23,92 @@ from nemo_run.core.frontend.console.api import CONSOLE +def find_editor_executable(base_executable_name): + """Find the proper executable path for an editor, especially in WSL environments. + + Args: + base_executable_name (str): The base name of the executable (e.g., 'code', 'cursor') + + Returns: + str: The path to the executable + + Raises: + ValueError: If the editor is not supported + EnvironmentError: If the editor is not installed or Windows executable not found in WSL + """ + # Define supported editors + SUPPORTED_EDITORS = { + "code": { + "display_name": "VS Code", + "download_url": "https://code.visualstudio.com/", + "exe_name": "Code.exe", + }, + "cursor": { + "display_name": "Cursor", + "download_url": "https://www.cursor.com/", + "exe_name": "Cursor.exe", + }, + # Add new editors here + } + + # Check if the editor is supported + if base_executable_name not in SUPPORTED_EDITORS: + supported_list = ", ".join(SUPPORTED_EDITORS.keys()) + raise ValueError( + f"Editor '{base_executable_name}' is not supported. " + f"Supported editors are: {supported_list}" + ) + + editor_config = SUPPORTED_EDITORS[base_executable_name] + + # Check if the editor is installed + executable_path = shutil.which(base_executable_name) + if not executable_path: + raise EnvironmentError( + f"{editor_config['display_name']} is not installed. " + f"Please install it from {editor_config['download_url']}" + ) + + # Default editor command is the base executable + editor_cmd = base_executable_name + + # If we're running in WSL, find the Windows executable + if os.name == "posix" and "WSL" in os.uname().release: + # Start from the executable directory + current_path = Path(executable_path).parent + exe_found = False + + # Walk up to 5 levels to find the Windows .exe + for _ in range(5): + potential_exe = current_path / editor_config["exe_name"] + if potential_exe.exists(): + editor_cmd = potential_exe.as_posix().replace(" ", "\\ ") + exe_found = True + break + # Move up one directory + parent_path = current_path.parent + if parent_path == current_path: # Reached root + break + current_path = parent_path + + # Raise an error if we couldn't find the Windows executable in WSL + if not exe_found: + raise EnvironmentError( + f"Running in WSL but couldn't find {editor_config['exe_name']} in the " + f"directory structure. For proper WSL integration, ensure {editor_config['display_name']} " + f"is installed in Windows and properly configured for WSL. " + f"See the documentation for {editor_config['display_name']} WSL integration." + ) + + return editor_cmd + + def launch_editor(tunnel: str, path: str): """Launch a code editor for the specified SSH tunnel. Args: tunnel (str): The name of the SSH tunnel. + path (str): The path to open in the editor. Raises: EnvironmentError: If the specified editor is not installed. @@ -42,26 +124,15 @@ def launch_editor(tunnel: str, path: str): if editor != "none": CONSOLE.rule(f"[bold green]Launching {editor}", characters="*") - if editor == "code": - if not shutil.which("code"): - raise EnvironmentError( - "VS Code is not installed. Please install it from https://code.visualstudio.com/" - ) - - code_cli = "code" - - # If we're running in WSL. Launch code from the executable directly. - # This avoids the code launch script activating the WSL remote extension - # which enables us to specify the ssh tunnel as the remote - if os.name == "posix" and "WSL" in os.uname().release: - code_cli = ( - (Path(shutil.which("code")).parent.parent / "Code.exe") - .as_posix() - .replace(" ", "\\ ") - ) - - cmd = f"{code_cli} --new-window --remote ssh-remote+tunnel.{tunnel} {path}" - CONSOLE.print(cmd) - local.run(f"NEMO_EDITOR=vscode {cmd}") - elif editor == "cursor": - local.run(f"NEMO_EDITOR=cursor cursor --remote ssh-remote+tunnel.{tunnel} {path}") + + # Find the proper executable + editor_cmd = find_editor_executable(editor) + + # Execute the editor command + cmd = f"{editor_cmd} --new-window --remote ssh-remote+tunnel.{tunnel} {path}" + CONSOLE.print(cmd) + + if platform.system() == "Windows": + local.run(f"set NEMO_EDITOR={editor} && {cmd}") + else: + local.run(f"NEMO_EDITOR={editor} {cmd}") diff --git a/test/core/execution/test_slurm.py b/test/core/execution/test_slurm.py index 0a2063bb..8a6ed651 100644 --- a/test/core/execution/test_slurm.py +++ b/test/core/execution/test_slurm.py @@ -277,7 +277,7 @@ def test_init(self, mock_executor, mock_space, mock_srun): assert callback.srun == mock_srun assert callback.space == mock_space assert callback.editor_started is False - assert callback.tunnel_name == "test_space.test_space" + assert callback.tunnel_name == "test_space" def test_on_start_with_srun(self, mock_executor, mock_space, mock_srun): """Test on_start method with srun.""" diff --git a/test/core/tunnel/test_client.py b/test/core/tunnel/test_client.py index b14dfeef..26d5cd7d 100644 --- a/test/core/tunnel/test_client.py +++ b/test/core/tunnel/test_client.py @@ -262,18 +262,156 @@ def test_init_custom_path(self): config_file = SSHConfigFile(config_path="/custom/path") assert config_file.config_path == "/custom/path" - @patch("os.uname") - @patch("subprocess.run") - def test_init_wsl(self, mock_run, mock_uname): + def test_init_wsl(self, tmp_path, monkeypatch): + """Test that both WSL and Windows config files are modified in WSL environments.""" + # Create temporary directories for both WSL and Windows configs + wsl_home = tmp_path / "wsl_home" + win_home = tmp_path / "win_home" + wsl_home.mkdir() + win_home.mkdir() + + # Create .ssh dirs + wsl_ssh_dir = wsl_home / ".ssh" + win_ssh_dir = win_home / ".ssh" + wsl_ssh_dir.mkdir() + win_ssh_dir.mkdir() + + # Set HOME environment variable to control expanduser + monkeypatch.setenv("HOME", str(wsl_home)) + + # Simulate WSL environment + monkeypatch.setattr("os.uname", lambda: MagicMock(release="Microsoft-WSL2")) + + # Mock shutil.which to simulate the WSL utilities being available + def mock_which(cmd): + return "/usr/bin/" + cmd if cmd in ["wslvar", "wslpath"] else None + + monkeypatch.setattr("shutil.which", mock_which) + + # Mock subprocess calls to get Windows paths + def mock_subprocess_run(args, **kwargs): + if args[0] == "wslvar" and args[1] == "USERPROFILE": + return MagicMock(stdout="C:\\Users\\test\n") + elif args[0] == "wslpath": + return MagicMock(stdout=str(win_home) + "\n") + return MagicMock(stdout="", returncode=1) # Default fail case + + monkeypatch.setattr("subprocess.run", mock_subprocess_run) + + # Create config file instance and modify files + config_file = SSHConfigFile() + config_file.add_entry("user", "host", 22, "test") + + # Expected paths for configs + wsl_config_path = wsl_ssh_dir / "config" + win_config_path = win_ssh_dir / "config" + + # Verify both files were created + assert wsl_config_path.exists() + assert win_config_path.exists() + + # Check contents of both files + wsl_content = wsl_config_path.read_text() + win_content = win_config_path.read_text() + + # Both files should have the same entry + expected_entry = """Host tunnel.test + User user + HostName host + Port 22 +""" + assert expected_entry in wsl_content + assert expected_entry in win_content + + # Test removing entry from both files + config_file.remove_entry("test") + + # Check contents after removal + wsl_content = wsl_config_path.read_text() + win_content = win_config_path.read_text() + + assert "Host tunnel.test" not in wsl_content + assert "Host tunnel.test" not in win_content + + def test_init_wsl_missing_utilities(self, tmp_path, monkeypatch): + """Test error handling when WSL utilities are missing.""" + # Create temporary directory + wsl_home = tmp_path / "wsl_home" + wsl_home.mkdir() + ssh_dir = wsl_home / ".ssh" + ssh_dir.mkdir() + + # Set HOME environment variable + monkeypatch.setenv("HOME", str(wsl_home)) + # Simulate WSL environment - mock_uname.return_value.release = "WSL" - mock_run.side_effect = [ - MagicMock(stdout="C:\\Users\\test\n"), - MagicMock(stdout="/mnt/c/Users/test\n"), - ] + monkeypatch.setattr("os.uname", lambda: MagicMock(release="Microsoft-WSL2")) + # Mock shutil.which to simulate the WSL utilities being unavailable + def mock_which(cmd): + return None # No commands available + + monkeypatch.setattr("shutil.which", mock_which) + + # Test that the appropriate error is raised config_file = SSHConfigFile() - assert config_file.config_path == "/mnt/c/Users/test/.ssh/config" + with pytest.raises(RuntimeError) as excinfo: + config_file.add_entry("user", "host", 22, "test") + + # Verify error message contains helpful information + error_msg = str(excinfo.value) + assert "WSL detected" in error_msg + assert "required utilities" in error_msg + assert "wslvar" in error_msg + assert "wslpath" in error_msg + assert "apt install wslu" in error_msg + + def test_init_non_wsl(self, tmp_path, monkeypatch): + """Test that only one config file is modified in non-WSL environments.""" + # Create temporary directory structure + home_dir = tmp_path / "home" + home_dir.mkdir() + ssh_dir = home_dir / ".ssh" + ssh_dir.mkdir() + + # Set HOME environment variable + monkeypatch.setenv("HOME", str(home_dir)) + + # Ensure we're not in WSL mode + monkeypatch.setattr("os.uname", lambda: MagicMock(release="Linux 5.15.0")) + + # Create a spy for subprocess.run to verify it's not called + mock_run = MagicMock(side_effect=Exception("subprocess.run should not be called")) + monkeypatch.setattr("subprocess.run", mock_run) + + # Create config instance and add entry + config_file = SSHConfigFile() + config_file.add_entry("user", "host", 22, "test") + + # Verify local file was created + config_path = ssh_dir / "config" + assert config_path.exists() + + # Check content of local file + content = config_path.read_text() + + expected_entry = """Host tunnel.test + User user + HostName host + Port 22 +""" + assert expected_entry in content + + # Test removing entry + config_file.remove_entry("test") + + # Check content after removal + content = config_path.read_text() + + assert "Host tunnel.test" not in content + + # Verify subprocess.run was not called + assert not mock_run.called @patch("builtins.open", new_callable=mock_open) @patch("os.path.exists", return_value=False) diff --git a/test/devspace/test_editor.py b/test/devspace/test_editor.py new file mode 100644 index 00000000..0aea6508 --- /dev/null +++ b/test/devspace/test_editor.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import platform +import pytest +from pathlib import Path +import shutil +from unittest.mock import MagicMock + +from nemo_run.devspace.editor import find_editor_executable + + +class TestFindEditorExecutable: + def test_unsupported_editor(self): + """Test that unsupported editors raise ValueError.""" + with pytest.raises(ValueError, match="not supported"): + find_editor_executable("unsupported_editor") + + def test_editor_not_installed(self, monkeypatch): + """Test that missing editors raise EnvironmentError.""" + # Monkeypatch shutil.which to return None (simulate editor not found) + monkeypatch.setattr(shutil, "which", lambda x: None) + + with pytest.raises(EnvironmentError, match="is not installed"): + find_editor_executable("code") + + with pytest.raises(EnvironmentError, match="is not installed"): + find_editor_executable("cursor") + + def test_non_wsl_environment(self, tmp_path, monkeypatch): + """Test editor detection in non-WSL environment using real file.""" + # Create a fake editor executable in a temp directory + bin_dir = tmp_path / "bin" + bin_dir.mkdir() + + code_exec = bin_dir / "code" + code_exec.touch(mode=0o755) # Make it executable + + cursor_exec = bin_dir / "cursor" + cursor_exec.touch(mode=0o755) + + # Add our temp directory to PATH + old_path = os.environ.get("PATH", "") + os.environ["PATH"] = f"{bin_dir}:{old_path}" + + try: + # Monkeypatch os.uname to return a non-WSL environment + if hasattr(os, "uname"): # Skip on Windows + monkeypatch.setattr(os, "uname", lambda: MagicMock(release="Linux 5.10.0")) + + # Test with actual executables in path + assert find_editor_executable("code") == "code" + assert find_editor_executable("cursor") == "cursor" + finally: + # Restore PATH + os.environ["PATH"] = old_path + + @pytest.mark.skipif( + platform.system() == "Windows", reason="WSL tests only relevant on Unix systems" + ) + def test_wsl_environment(self, tmp_path, monkeypatch): + """Test editor detection in WSL environment.""" + # Create directory structure with both Linux and "Windows" executables + bin_dir = tmp_path / "bin" + bin_dir.mkdir() + + # Linux executables + code_exec = bin_dir / "code" + code_exec.touch(mode=0o755) + + cursor_exec = bin_dir / "cursor" + cursor_exec.touch(mode=0o755) + + # Windows .exe files at various levels + exe_dir = tmp_path / "winbin" + exe_dir.mkdir() + code_exe = exe_dir / "Code.exe" + code_exe.touch(mode=0o755) + + cursor_exe_dir = tmp_path / "apps" / "cursor" + cursor_exe_dir.mkdir(parents=True) + cursor_exe = cursor_exe_dir / "Cursor.exe" + cursor_exe.touch(mode=0o755) + + # Add our temp directory to PATH + old_path = os.environ.get("PATH", "") + os.environ["PATH"] = f"{bin_dir}:{old_path}" + + try: + # Mock WSL environment + monkeypatch.setattr(os, "name", "posix") + monkeypatch.setattr(os, "uname", lambda: MagicMock(release="Microsoft-WSL")) + + # Test cases with different configurations + + # 1. Case where we find the .exe file at a specific location + with monkeypatch.context() as m: + + def mock_which(cmd): + if cmd == "code": + return str(code_exec) + elif cmd == "cursor": + return str(cursor_exec) + return None + + # Only need to mock exists for the specific paths we want to test + original_exists = Path.exists + + def mock_exists(self): + if ( + self == code_exe + or self.name == "Code.exe" + and str(self).startswith(str(tmp_path)) + ): + return True + if ( + self == cursor_exe + or self.name == "Cursor.exe" + and str(self).startswith(str(tmp_path)) + ): + return True + return original_exists(self) + + m.setattr(shutil, "which", mock_which) + m.setattr(Path, "exists", mock_exists) + + # Test code with .exe available + result = find_editor_executable("code") + assert "Code.exe" in result + + # Test cursor with .exe available + result = find_editor_executable("cursor") + assert "Cursor.exe" in result + + # 2. Case where we don't find the .exe file (should now raise error) + with monkeypatch.context() as m: + + def mock_which(cmd): + if cmd == "code": + return str(code_exec) + elif cmd == "cursor": + return str(cursor_exec) + return None + + # Make exists always return False for .exe files + def mock_exists(self): + if ".exe" in str(self).lower(): + return False + return original_exists(self) + + m.setattr(shutil, "which", mock_which) + m.setattr(Path, "exists", mock_exists) + + # Test code with no .exe available (should now raise error) + with pytest.raises(EnvironmentError, match="Running in WSL but couldn't find"): + find_editor_executable("code") + + # Test cursor with no .exe available (should now raise error) + with pytest.raises(EnvironmentError, match="Running in WSL but couldn't find"): + find_editor_executable("cursor") + + finally: + # Restore PATH + os.environ["PATH"] = old_path