|
25 | 25 | from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port
|
26 | 26 |
|
27 | 27 |
|
28 |
| -def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: |
29 |
| - if hostnames == "auto": |
30 |
| - return auto_hosts() |
31 |
| - if hostnames == "slurm": |
32 |
| - return slurm_hosts() |
33 |
| - return hostnames |
34 |
| - |
35 |
| - |
36 |
| -def resolve_workers_per_host( |
37 |
| - workers_per_host: int | list[int] | Literal["auto", "slurm"], |
38 |
| - num_hosts: int, |
39 |
| -) -> list[int]: |
40 |
| - if workers_per_host == "auto": |
41 |
| - workers_per_host = auto_workers() |
42 |
| - elif workers_per_host == "slurm": |
43 |
| - workers_per_host = slurm_workers() |
44 |
| - |
45 |
| - if isinstance(workers_per_host, int): |
46 |
| - workers_per_host = [workers_per_host] * num_hosts |
47 |
| - elif len(workers_per_host) != num_hosts: |
48 |
| - msg = "len(workers_per_host) != len(hostnames)" |
49 |
| - raise ValueError(msg) |
50 |
| - |
51 |
| - return workers_per_host |
52 |
| - |
53 |
| - |
54 |
| -def build_logging_server( |
55 |
| - log_handlers: list[Handler] | Literal["auto"] | None, |
56 |
| - launcher_hostname: str, |
57 |
| - hostnames: list[str], |
58 |
| - workers_per_host: list[int], |
59 |
| - log_dir: str | os.PathLike, |
60 |
| - log_level: int, |
61 |
| -) -> LogRecordSocketReceiver: |
62 |
| - if log_handlers is None: |
63 |
| - log_handlers = [] |
64 |
| - elif log_handlers == "auto": |
65 |
| - log_handlers = default_handlers( |
66 |
| - hostnames=hostnames, |
67 |
| - workers_per_host=workers_per_host, |
68 |
| - log_dir=log_dir, |
69 |
| - log_level=log_level, |
70 |
| - ) |
71 |
| - |
72 |
| - return LogRecordSocketReceiver( |
73 |
| - host=launcher_hostname, |
74 |
| - port=get_open_port(), |
75 |
| - handlers=log_handlers, |
76 |
| - ) |
77 |
| - |
78 |
| - |
79 |
| -def build_launch_command( |
80 |
| - launcher_hostname: str, |
81 |
| - launcher_port: int, |
82 |
| - logger_port: int, |
83 |
| - world_size: int, |
84 |
| - rank: int, |
85 |
| - env_vars: tuple[str, ...], |
86 |
| - env_file: str | os.PathLike | None, |
87 |
| -) -> str: |
88 |
| - # shlex.quote prevents shell injection here (resolves S602 in execute_command) |
89 |
| - |
90 |
| - commands = [] |
91 |
| - |
92 |
| - current_dir = shlex.quote(str(Path.cwd())) |
93 |
| - commands.append("cd " + current_dir) |
94 |
| - |
95 |
| - env_exports = [] |
96 |
| - for k, v in os.environ.items(): |
97 |
| - if any(fnmatch.fnmatch(k, e) for e in env_vars): |
98 |
| - env_exports.append(shlex.quote(f"{k}={v}")) |
99 |
| - |
100 |
| - if len(env_exports) > 0: |
101 |
| - commands.append("export " + " ".join(env_exports)) |
102 |
| - |
103 |
| - if env_file is not None: |
104 |
| - commands.append("source " + shlex.quote(str(env_file))) |
105 |
| - |
106 |
| - python = shlex.quote(sys.executable) |
107 |
| - launcher_hostname = shlex.quote(launcher_hostname) |
108 |
| - |
109 |
| - commands.append( |
110 |
| - f"{python} -u -m torchrunx " |
111 |
| - f"--launcher-hostname {launcher_hostname} " |
112 |
| - f"--launcher-port {launcher_port} " |
113 |
| - f"--logger-port {logger_port} " |
114 |
| - f"--world-size {world_size} " |
115 |
| - f"--rank {rank}", |
116 |
| - ) |
117 |
| - |
118 |
| - return " && ".join(commands) |
119 |
| - |
120 |
| - |
121 |
| -def execute_command( |
122 |
| - command: str, |
123 |
| - hostname: str, |
124 |
| - ssh_config_file: str | os.PathLike | None = None, |
125 |
| -) -> None: |
126 |
| - is_localhost = True |
127 |
| - _hostname_or_ip = hostname |
128 |
| - try: |
129 |
| - _ip = ipaddress.ip_address(_hostname_or_ip) |
130 |
| - except ValueError: |
131 |
| - _ip = ipaddress.ip_address(socket.gethostbyname(_hostname_or_ip)) |
132 |
| - if not _ip.is_loopback: |
133 |
| - # compare local interface addresses between host and localhost |
134 |
| - _host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(_ip), None)] |
135 |
| - _localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)] |
136 |
| - is_localhost = len(set(_host_addrs) & set(_localhost_addrs)) > 0 |
137 |
| - |
138 |
| - if is_localhost: |
139 |
| - # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations) |
140 |
| - # Made sure to shlex.quote arguments in build_command to prevent shell injection |
141 |
| - subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # noqa: S602 |
142 |
| - else: |
143 |
| - runtime_ssh_path = ssh_config_file |
144 |
| - if isinstance(ssh_config_file, os.PathLike): |
145 |
| - runtime_ssh_path = str(ssh_config_file) |
146 |
| - |
147 |
| - with fabric.Connection( |
148 |
| - host=hostname, |
149 |
| - config=fabric.Config(runtime_ssh_path=runtime_ssh_path), |
150 |
| - ) as conn: |
151 |
| - conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True) |
| 28 | +class AgentKilledError(Exception): |
| 29 | + pass |
152 | 30 |
|
153 | 31 |
|
154 | 32 | @dataclass
|
@@ -263,8 +141,11 @@ def run( # noqa: C901, PLR0912
|
263 | 141 | # loop to monitor agent statuses (until failed or done)
|
264 | 142 |
|
265 | 143 | while True:
|
266 |
| - # raises RuntimeError if communication timeout due to death of any agent |
267 |
| - agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) |
| 144 | + try: |
| 145 | + agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) |
| 146 | + except RuntimeError as e: |
| 147 | + # occurs if any agent dies and communication times out |
| 148 | + raise AgentKilledError from e |
268 | 149 |
|
269 | 150 | # raises specific exception if any agent fails
|
270 | 151 | for s in agent_statuses:
|
@@ -334,7 +215,8 @@ def launch(
|
334 | 215 | :param default_env_vars: A list of environmental variables to be copied from the launcher process to workers. Allows for bash pattern matching syntax.
|
335 | 216 | :param extra_env_vars: Additional, user-specified variables to copy.
|
336 | 217 | :param env_file: A file (like ``.env``) with additional environment variables to copy.
|
337 |
| - :raises RuntimeError: May fail if ``torch.distributed`` not available or communication timeout between nodes |
| 218 | + :raises RuntimeError: If ``torch.distributed`` not available |
| 219 | + :raises AgentKilledError: If any agent is killed |
338 | 220 | :raises Exception: Propagates exceptions raised in worker processes
|
339 | 221 | """ # noqa: E501
|
340 | 222 | return Launcher(
|
@@ -409,3 +291,129 @@ def value(self, rank: int) -> Any:
|
409 | 291 |
|
410 | 292 | msg = f"Rank {rank} larger than world_size"
|
411 | 293 | raise ValueError(msg)
|
| 294 | + |
| 295 | + |
| 296 | +def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: |
| 297 | + if hostnames == "auto": |
| 298 | + return auto_hosts() |
| 299 | + if hostnames == "slurm": |
| 300 | + return slurm_hosts() |
| 301 | + return hostnames |
| 302 | + |
| 303 | + |
| 304 | +def resolve_workers_per_host( |
| 305 | + workers_per_host: int | list[int] | Literal["auto", "slurm"], |
| 306 | + num_hosts: int, |
| 307 | +) -> list[int]: |
| 308 | + if workers_per_host == "auto": |
| 309 | + workers_per_host = auto_workers() |
| 310 | + elif workers_per_host == "slurm": |
| 311 | + workers_per_host = slurm_workers() |
| 312 | + |
| 313 | + if isinstance(workers_per_host, int): |
| 314 | + workers_per_host = [workers_per_host] * num_hosts |
| 315 | + elif len(workers_per_host) != num_hosts: |
| 316 | + msg = "len(workers_per_host) != len(hostnames)" |
| 317 | + raise ValueError(msg) |
| 318 | + |
| 319 | + return workers_per_host |
| 320 | + |
| 321 | + |
| 322 | +def build_logging_server( |
| 323 | + log_handlers: list[Handler] | Literal["auto"] | None, |
| 324 | + launcher_hostname: str, |
| 325 | + hostnames: list[str], |
| 326 | + workers_per_host: list[int], |
| 327 | + log_dir: str | os.PathLike, |
| 328 | + log_level: int, |
| 329 | +) -> LogRecordSocketReceiver: |
| 330 | + if log_handlers is None: |
| 331 | + log_handlers = [] |
| 332 | + elif log_handlers == "auto": |
| 333 | + log_handlers = default_handlers( |
| 334 | + hostnames=hostnames, |
| 335 | + workers_per_host=workers_per_host, |
| 336 | + log_dir=log_dir, |
| 337 | + log_level=log_level, |
| 338 | + ) |
| 339 | + |
| 340 | + return LogRecordSocketReceiver( |
| 341 | + host=launcher_hostname, |
| 342 | + port=get_open_port(), |
| 343 | + handlers=log_handlers, |
| 344 | + ) |
| 345 | + |
| 346 | + |
| 347 | +def build_launch_command( |
| 348 | + launcher_hostname: str, |
| 349 | + launcher_port: int, |
| 350 | + logger_port: int, |
| 351 | + world_size: int, |
| 352 | + rank: int, |
| 353 | + env_vars: tuple[str, ...], |
| 354 | + env_file: str | os.PathLike | None, |
| 355 | +) -> str: |
| 356 | + # shlex.quote prevents shell injection here (resolves S602 in execute_command) |
| 357 | + |
| 358 | + commands = [] |
| 359 | + |
| 360 | + current_dir = shlex.quote(str(Path.cwd())) |
| 361 | + commands.append("cd " + current_dir) |
| 362 | + |
| 363 | + env_exports = [] |
| 364 | + for k, v in os.environ.items(): |
| 365 | + if any(fnmatch.fnmatch(k, e) for e in env_vars): |
| 366 | + env_exports.append(shlex.quote(f"{k}={v}")) |
| 367 | + |
| 368 | + if len(env_exports) > 0: |
| 369 | + commands.append("export " + " ".join(env_exports)) |
| 370 | + |
| 371 | + if env_file is not None: |
| 372 | + commands.append("source " + shlex.quote(str(env_file))) |
| 373 | + |
| 374 | + python = shlex.quote(sys.executable) |
| 375 | + launcher_hostname = shlex.quote(launcher_hostname) |
| 376 | + |
| 377 | + commands.append( |
| 378 | + f"{python} -u -m torchrunx " |
| 379 | + f"--launcher-hostname {launcher_hostname} " |
| 380 | + f"--launcher-port {launcher_port} " |
| 381 | + f"--logger-port {logger_port} " |
| 382 | + f"--world-size {world_size} " |
| 383 | + f"--rank {rank}", |
| 384 | + ) |
| 385 | + |
| 386 | + return " && ".join(commands) |
| 387 | + |
| 388 | + |
| 389 | +def execute_command( |
| 390 | + command: str, |
| 391 | + hostname: str, |
| 392 | + ssh_config_file: str | os.PathLike | None = None, |
| 393 | +) -> None: |
| 394 | + is_localhost = True |
| 395 | + _hostname_or_ip = hostname |
| 396 | + try: |
| 397 | + _ip = ipaddress.ip_address(_hostname_or_ip) |
| 398 | + except ValueError: |
| 399 | + _ip = ipaddress.ip_address(socket.gethostbyname(_hostname_or_ip)) |
| 400 | + if not _ip.is_loopback: |
| 401 | + # compare local interface addresses between host and localhost |
| 402 | + _host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(_ip), None)] |
| 403 | + _localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)] |
| 404 | + is_localhost = len(set(_host_addrs) & set(_localhost_addrs)) > 0 |
| 405 | + |
| 406 | + if is_localhost: |
| 407 | + # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations) |
| 408 | + # Made sure to shlex.quote arguments in build_command to prevent shell injection |
| 409 | + subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # noqa: S602 |
| 410 | + else: |
| 411 | + runtime_ssh_path = ssh_config_file |
| 412 | + if isinstance(ssh_config_file, os.PathLike): |
| 413 | + runtime_ssh_path = str(ssh_config_file) |
| 414 | + |
| 415 | + with fabric.Connection( |
| 416 | + host=hostname, |
| 417 | + config=fabric.Config(runtime_ssh_path=runtime_ssh_path), |
| 418 | + ) as conn: |
| 419 | + conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True) |
0 commit comments