Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions nemo_run/core/execution/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,9 +1116,14 @@ def on_interval(self):
)
self.forward_port_context.__enter__()

self.ssh_config.add_entry(
metadata.user, "localhost", int(metadata.port), self.tunnel_name
)
try:
self.ssh_config.add_entry(
metadata.user, "localhost", int(metadata.port), self.tunnel_name
)
except Exception as e:
self.console.print(f"[bold red]Error adding SSH config entry: {e}")
raise e

self.ssh_entry_added = True

with self.console.status("Setting up port forwarding", spinner="dots"):
Expand All @@ -1141,4 +1146,4 @@ def on_stop(self):
def tunnel_name(self) -> str:
workspace_name = self.space.name

return ".".join([workspace_name, self.space.name])
return workspace_name
82 changes: 68 additions & 14 deletions nemo_run/core/tunnel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,28 +311,76 @@ def __init__(self, config_path: Optional[str] = None):

def _get_default_config_path(self) -> str:
config_path = os.path.expanduser("~/.ssh/config")
return config_path

def _get_host_config_path(self) -> Optional[str]:
"""
Get the path to the ssh config file for the host if we are running in WSL.

Returns:
Optional[str]: Path to the host's SSH config file or None if not in WSL.

Raises:
RuntimeError: If running in WSL but unable to determine the Windows path
due to missing utilities or other errors.
"""
config_path = None

# If running in WSL environment, update host's ssh config file instead
if os.name == "posix" and "WSL" in os.uname().release:
user_profile = subprocess.run(
["wslvar", "USERPROFILE"], capture_output=True, text=True, check=False
).stdout.strip("\n")
home_dir = subprocess.run(
["wslpath", user_profile], capture_output=True, text=True, check=False
).stdout.strip("\n")
config_path = (Path(home_dir) / ".ssh/config").as_posix()
# Check if wslvar and wslpath are available
wslvar_exists = shutil.which("wslvar") is not None
wslpath_exists = shutil.which("wslpath") is not None

if wslvar_exists and wslpath_exists:
# Use WSL utilities to get the host ssh config path
user_profile = subprocess.run(
["wslvar", "USERPROFILE"], capture_output=True, text=True, check=False
).stdout.strip("\n")

if not user_profile:
raise RuntimeError("Failed to get USERPROFILE from wslvar")

home_dir = subprocess.run(
["wslpath", user_profile], capture_output=True, text=True, check=False
).stdout.strip("\n")

if not home_dir:
raise RuntimeError("Failed to convert USERPROFILE path with wslpath")

config_path = (Path(home_dir) / ".ssh/config").as_posix()
logger.debug(f"Using Windows SSH config at: {config_path}")
else:
# wslu package not installed, raise error
missing_cmds = []
if not wslvar_exists:
missing_cmds.append("wslvar")
if not wslpath_exists:
missing_cmds.append("wslpath")

raise RuntimeError(
f"WSL detected but required utilities ({', '.join(missing_cmds)}) not found. "
"These utilities are part of the wslu package. "
"Example of installation: sudo apt install wslu"
)

return config_path

def add_entry(self, user: str, hostname: str, port: int, name: str):
host_config_path = self._get_host_config_path()
if host_config_path:
self._add_entry(user, hostname, port, name, host_config_path)
self._add_entry(user, hostname, port, name, self.config_path)

def _add_entry(self, user: str, hostname: str, port: int, name: str, config_path: str):
host = f"tunnel.{name}"
new_config_entry = f"""Host {host}
User {user}
HostName {hostname}
Port {port}"""

if os.path.exists(self.config_path):
with open(self.config_path, "r") as file:
if os.path.exists(config_path):
with open(config_path, "r") as file:
lines = file.readlines()

# Check if the host is already defined in the config
Expand All @@ -351,16 +399,22 @@ def add_entry(self, user: str, hostname: str, port: int, name: str):
else: # Add new entry
lines.append(new_config_entry + "\n")

with open(self.config_path, "w") as file:
with open(config_path, "w") as file:
file.writelines(lines)
else:
with open(self.config_path, "w") as file:
with open(config_path, "w") as file:
file.write(new_config_entry + "\n")

def remove_entry(self, name: str):
host_config_path = self._get_host_config_path()
if host_config_path:
self._remove_entry(name, host_config_path)
self._remove_entry(name, self.config_path)

def _remove_entry(self, name: str, config_path: str):
host = f"tunnel.{name}"
if os.path.exists(self.config_path):
with open(self.config_path, "r") as file:
if os.path.exists(config_path):
with open(config_path, "r") as file:
lines = file.readlines()

start_index = None
Expand All @@ -376,7 +430,7 @@ def remove_entry(self, name: str):

del lines[start_index:end_index]

with open(self.config_path, "w") as file:
with open(config_path, "w") as file:
file.writelines(lines)

print(f"Removed SSH config entry for {host}.")
Expand Down
3 changes: 2 additions & 1 deletion nemo_run/core/tunnel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def restore(cls, path: Path, tunnel=None) -> "TunnelMetadata":
tunnel_file = path / "metadata.json"

if tunnel:
data = json.loads(tunnel.run(f"cat {tunnel_file}", hide="out").stdout.strip())
tunnel_path = tunnel_file.as_posix()
data = json.loads(tunnel.run(f"cat {tunnel_path}", hide="out").stdout.strip())
else:
with tunnel_file.open("r") as f:
data = json.load(f)
Expand Down
117 changes: 94 additions & 23 deletions nemo_run/devspace/editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import os
import platform
import shutil
from pathlib import Path

Expand All @@ -22,11 +23,92 @@
from nemo_run.core.frontend.console.api import CONSOLE


def find_editor_executable(base_executable_name):
"""Find the proper executable path for an editor, especially in WSL environments.

Args:
base_executable_name (str): The base name of the executable (e.g., 'code', 'cursor')

Returns:
str: The path to the executable

Raises:
ValueError: If the editor is not supported
EnvironmentError: If the editor is not installed or Windows executable not found in WSL
"""
# Define supported editors
SUPPORTED_EDITORS = {
"code": {
"display_name": "VS Code",
"download_url": "https://code.visualstudio.com/",
"exe_name": "Code.exe",
},
"cursor": {
"display_name": "Cursor",
"download_url": "https://www.cursor.com/",
"exe_name": "Cursor.exe",
},
# Add new editors here
}

# Check if the editor is supported
if base_executable_name not in SUPPORTED_EDITORS:
supported_list = ", ".join(SUPPORTED_EDITORS.keys())
raise ValueError(
f"Editor '{base_executable_name}' is not supported. "
f"Supported editors are: {supported_list}"
)

editor_config = SUPPORTED_EDITORS[base_executable_name]

# Check if the editor is installed
executable_path = shutil.which(base_executable_name)
if not executable_path:
raise EnvironmentError(
f"{editor_config['display_name']} is not installed. "
f"Please install it from {editor_config['download_url']}"
)

# Default editor command is the base executable
editor_cmd = base_executable_name

# If we're running in WSL, find the Windows executable
if os.name == "posix" and "WSL" in os.uname().release:
# Start from the executable directory
current_path = Path(executable_path).parent
exe_found = False

# Walk up to 5 levels to find the Windows .exe
for _ in range(5):
potential_exe = current_path / editor_config["exe_name"]
if potential_exe.exists():
editor_cmd = potential_exe.as_posix().replace(" ", "\\ ")
exe_found = True
break
# Move up one directory
parent_path = current_path.parent
if parent_path == current_path: # Reached root
break
current_path = parent_path

# Raise an error if we couldn't find the Windows executable in WSL
if not exe_found:
raise EnvironmentError(
f"Running in WSL but couldn't find {editor_config['exe_name']} in the "
f"directory structure. For proper WSL integration, ensure {editor_config['display_name']} "
f"is installed in Windows and properly configured for WSL. "
f"See the documentation for {editor_config['display_name']} WSL integration."
)

return editor_cmd


def launch_editor(tunnel: str, path: str):
"""Launch a code editor for the specified SSH tunnel.

Args:
tunnel (str): The name of the SSH tunnel.
path (str): The path to open in the editor.

Raises:
EnvironmentError: If the specified editor is not installed.
Expand All @@ -42,26 +124,15 @@ def launch_editor(tunnel: str, path: str):

if editor != "none":
CONSOLE.rule(f"[bold green]Launching {editor}", characters="*")
if editor == "code":
if not shutil.which("code"):
raise EnvironmentError(
"VS Code is not installed. Please install it from https://code.visualstudio.com/"
)

code_cli = "code"

# If we're running in WSL. Launch code from the executable directly.
# This avoids the code launch script activating the WSL remote extension
# which enables us to specify the ssh tunnel as the remote
if os.name == "posix" and "WSL" in os.uname().release:
code_cli = (
(Path(shutil.which("code")).parent.parent / "Code.exe")
.as_posix()
.replace(" ", "\\ ")
)

cmd = f"{code_cli} --new-window --remote ssh-remote+tunnel.{tunnel} {path}"
CONSOLE.print(cmd)
local.run(f"NEMO_EDITOR=vscode {cmd}")
elif editor == "cursor":
local.run(f"NEMO_EDITOR=cursor cursor --remote ssh-remote+tunnel.{tunnel} {path}")

# Find the proper executable
editor_cmd = find_editor_executable(editor)

# Execute the editor command
cmd = f"{editor_cmd} --new-window --remote ssh-remote+tunnel.{tunnel} {path}"
CONSOLE.print(cmd)

if platform.system() == "Windows":
local.run(f"set NEMO_EDITOR={editor} && {cmd}")
else:
local.run(f"NEMO_EDITOR={editor} {cmd}")
2 changes: 1 addition & 1 deletion test/core/execution/test_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def test_init(self, mock_executor, mock_space, mock_srun):
assert callback.srun == mock_srun
assert callback.space == mock_space
assert callback.editor_started is False
assert callback.tunnel_name == "test_space.test_space"
assert callback.tunnel_name == "test_space"

def test_on_start_with_srun(self, mock_executor, mock_space, mock_srun):
"""Test on_start method with srun."""
Expand Down
Loading