Skip to content

Commit 58eb486

Browse files
authored
Merge pull request #77 from apoorvkh/worker-exception
AgentKilledError
2 parents eb14b72 + bebd2ea commit 58eb486

File tree

3 files changed

+139
-128
lines changed

3 files changed

+139
-128
lines changed

docs/source/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ API
55

66
.. autoclass:: torchrunx.LaunchResult
77
:members:
8+
9+
.. autoclass:: torchrunx.AgentKilledError

src/torchrunx/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from .launcher import Launcher, LaunchResult, launch
1+
from .launcher import AgentKilledError, Launcher, LaunchResult, launch
22
from .logging_utils import add_filter_to_handler, file_handler, stream_handler
33

44
__all__ = [
5+
"AgentKilledError",
56
"Launcher",
67
"launch",
78
"LaunchResult",

src/torchrunx/launcher.py

Lines changed: 135 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -25,130 +25,8 @@
2525
from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port
2626

2727

28-
def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
29-
if hostnames == "auto":
30-
return auto_hosts()
31-
if hostnames == "slurm":
32-
return slurm_hosts()
33-
return hostnames
34-
35-
36-
def resolve_workers_per_host(
37-
workers_per_host: int | list[int] | Literal["auto", "slurm"],
38-
num_hosts: int,
39-
) -> list[int]:
40-
if workers_per_host == "auto":
41-
workers_per_host = auto_workers()
42-
elif workers_per_host == "slurm":
43-
workers_per_host = slurm_workers()
44-
45-
if isinstance(workers_per_host, int):
46-
workers_per_host = [workers_per_host] * num_hosts
47-
elif len(workers_per_host) != num_hosts:
48-
msg = "len(workers_per_host) != len(hostnames)"
49-
raise ValueError(msg)
50-
51-
return workers_per_host
52-
53-
54-
def build_logging_server(
55-
log_handlers: list[Handler] | Literal["auto"] | None,
56-
launcher_hostname: str,
57-
hostnames: list[str],
58-
workers_per_host: list[int],
59-
log_dir: str | os.PathLike,
60-
log_level: int,
61-
) -> LogRecordSocketReceiver:
62-
if log_handlers is None:
63-
log_handlers = []
64-
elif log_handlers == "auto":
65-
log_handlers = default_handlers(
66-
hostnames=hostnames,
67-
workers_per_host=workers_per_host,
68-
log_dir=log_dir,
69-
log_level=log_level,
70-
)
71-
72-
return LogRecordSocketReceiver(
73-
host=launcher_hostname,
74-
port=get_open_port(),
75-
handlers=log_handlers,
76-
)
77-
78-
79-
def build_launch_command(
80-
launcher_hostname: str,
81-
launcher_port: int,
82-
logger_port: int,
83-
world_size: int,
84-
rank: int,
85-
env_vars: tuple[str, ...],
86-
env_file: str | os.PathLike | None,
87-
) -> str:
88-
# shlex.quote prevents shell injection here (resolves S602 in execute_command)
89-
90-
commands = []
91-
92-
current_dir = shlex.quote(str(Path.cwd()))
93-
commands.append("cd " + current_dir)
94-
95-
env_exports = []
96-
for k, v in os.environ.items():
97-
if any(fnmatch.fnmatch(k, e) for e in env_vars):
98-
env_exports.append(shlex.quote(f"{k}={v}"))
99-
100-
if len(env_exports) > 0:
101-
commands.append("export " + " ".join(env_exports))
102-
103-
if env_file is not None:
104-
commands.append("source " + shlex.quote(str(env_file)))
105-
106-
python = shlex.quote(sys.executable)
107-
launcher_hostname = shlex.quote(launcher_hostname)
108-
109-
commands.append(
110-
f"{python} -u -m torchrunx "
111-
f"--launcher-hostname {launcher_hostname} "
112-
f"--launcher-port {launcher_port} "
113-
f"--logger-port {logger_port} "
114-
f"--world-size {world_size} "
115-
f"--rank {rank}",
116-
)
117-
118-
return " && ".join(commands)
119-
120-
121-
def execute_command(
122-
command: str,
123-
hostname: str,
124-
ssh_config_file: str | os.PathLike | None = None,
125-
) -> None:
126-
is_localhost = True
127-
_hostname_or_ip = hostname
128-
try:
129-
_ip = ipaddress.ip_address(_hostname_or_ip)
130-
except ValueError:
131-
_ip = ipaddress.ip_address(socket.gethostbyname(_hostname_or_ip))
132-
if not _ip.is_loopback:
133-
# compare local interface addresses between host and localhost
134-
_host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(_ip), None)]
135-
_localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)]
136-
is_localhost = len(set(_host_addrs) & set(_localhost_addrs)) > 0
137-
138-
if is_localhost:
139-
# S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations)
140-
# Made sure to shlex.quote arguments in build_command to prevent shell injection
141-
subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # noqa: S602
142-
else:
143-
runtime_ssh_path = ssh_config_file
144-
if isinstance(ssh_config_file, os.PathLike):
145-
runtime_ssh_path = str(ssh_config_file)
146-
147-
with fabric.Connection(
148-
host=hostname,
149-
config=fabric.Config(runtime_ssh_path=runtime_ssh_path),
150-
) as conn:
151-
conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True)
28+
class AgentKilledError(Exception):
29+
pass
15230

15331

15432
@dataclass
@@ -263,8 +141,11 @@ def run( # noqa: C901, PLR0912
263141
# loop to monitor agent statuses (until failed or done)
264142

265143
while True:
266-
# raises RuntimeError if communication timeout due to death of any agent
267-
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
144+
try:
145+
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
146+
except RuntimeError as e:
147+
# occurs if any agent dies and communication times out
148+
raise AgentKilledError from e
268149

269150
# raises specific exception if any agent fails
270151
for s in agent_statuses:
@@ -334,7 +215,8 @@ def launch(
334215
:param default_env_vars: A list of environmental variables to be copied from the launcher process to workers. Allows for bash pattern matching syntax.
335216
:param extra_env_vars: Additional, user-specified variables to copy.
336217
:param env_file: A file (like ``.env``) with additional environment variables to copy.
337-
:raises RuntimeError: May fail if ``torch.distributed`` not available or communication timeout between nodes
218+
:raises RuntimeError: If ``torch.distributed`` not available
219+
:raises AgentKilledError: If any agent is killed
338220
:raises Exception: Propagates exceptions raised in worker processes
339221
""" # noqa: E501
340222
return Launcher(
@@ -409,3 +291,129 @@ def value(self, rank: int) -> Any:
409291

410292
msg = f"Rank {rank} larger than world_size"
411293
raise ValueError(msg)
294+
295+
296+
def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
297+
if hostnames == "auto":
298+
return auto_hosts()
299+
if hostnames == "slurm":
300+
return slurm_hosts()
301+
return hostnames
302+
303+
304+
def resolve_workers_per_host(
305+
workers_per_host: int | list[int] | Literal["auto", "slurm"],
306+
num_hosts: int,
307+
) -> list[int]:
308+
if workers_per_host == "auto":
309+
workers_per_host = auto_workers()
310+
elif workers_per_host == "slurm":
311+
workers_per_host = slurm_workers()
312+
313+
if isinstance(workers_per_host, int):
314+
workers_per_host = [workers_per_host] * num_hosts
315+
elif len(workers_per_host) != num_hosts:
316+
msg = "len(workers_per_host) != len(hostnames)"
317+
raise ValueError(msg)
318+
319+
return workers_per_host
320+
321+
322+
def build_logging_server(
323+
log_handlers: list[Handler] | Literal["auto"] | None,
324+
launcher_hostname: str,
325+
hostnames: list[str],
326+
workers_per_host: list[int],
327+
log_dir: str | os.PathLike,
328+
log_level: int,
329+
) -> LogRecordSocketReceiver:
330+
if log_handlers is None:
331+
log_handlers = []
332+
elif log_handlers == "auto":
333+
log_handlers = default_handlers(
334+
hostnames=hostnames,
335+
workers_per_host=workers_per_host,
336+
log_dir=log_dir,
337+
log_level=log_level,
338+
)
339+
340+
return LogRecordSocketReceiver(
341+
host=launcher_hostname,
342+
port=get_open_port(),
343+
handlers=log_handlers,
344+
)
345+
346+
347+
def build_launch_command(
348+
launcher_hostname: str,
349+
launcher_port: int,
350+
logger_port: int,
351+
world_size: int,
352+
rank: int,
353+
env_vars: tuple[str, ...],
354+
env_file: str | os.PathLike | None,
355+
) -> str:
356+
# shlex.quote prevents shell injection here (resolves S602 in execute_command)
357+
358+
commands = []
359+
360+
current_dir = shlex.quote(str(Path.cwd()))
361+
commands.append("cd " + current_dir)
362+
363+
env_exports = []
364+
for k, v in os.environ.items():
365+
if any(fnmatch.fnmatch(k, e) for e in env_vars):
366+
env_exports.append(shlex.quote(f"{k}={v}"))
367+
368+
if len(env_exports) > 0:
369+
commands.append("export " + " ".join(env_exports))
370+
371+
if env_file is not None:
372+
commands.append("source " + shlex.quote(str(env_file)))
373+
374+
python = shlex.quote(sys.executable)
375+
launcher_hostname = shlex.quote(launcher_hostname)
376+
377+
commands.append(
378+
f"{python} -u -m torchrunx "
379+
f"--launcher-hostname {launcher_hostname} "
380+
f"--launcher-port {launcher_port} "
381+
f"--logger-port {logger_port} "
382+
f"--world-size {world_size} "
383+
f"--rank {rank}",
384+
)
385+
386+
return " && ".join(commands)
387+
388+
389+
def execute_command(
390+
command: str,
391+
hostname: str,
392+
ssh_config_file: str | os.PathLike | None = None,
393+
) -> None:
394+
is_localhost = True
395+
_hostname_or_ip = hostname
396+
try:
397+
_ip = ipaddress.ip_address(_hostname_or_ip)
398+
except ValueError:
399+
_ip = ipaddress.ip_address(socket.gethostbyname(_hostname_or_ip))
400+
if not _ip.is_loopback:
401+
# compare local interface addresses between host and localhost
402+
_host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(_ip), None)]
403+
_localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)]
404+
is_localhost = len(set(_host_addrs) & set(_localhost_addrs)) > 0
405+
406+
if is_localhost:
407+
# S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations)
408+
# Made sure to shlex.quote arguments in build_command to prevent shell injection
409+
subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # noqa: S602
410+
else:
411+
runtime_ssh_path = ssh_config_file
412+
if isinstance(ssh_config_file, os.PathLike):
413+
runtime_ssh_path = str(ssh_config_file)
414+
415+
with fabric.Connection(
416+
host=hostname,
417+
config=fabric.Config(runtime_ssh_path=runtime_ssh_path),
418+
) as conn:
419+
conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True)

0 commit comments

Comments
 (0)