Commit c2d51bf

update launcher API
1 parent ffc45d1 commit c2d51bf

File tree: 1 file changed (+55, -36 lines)

src/torchrunx/launcher.py

Lines changed: 55 additions & 36 deletions
@@ -38,8 +38,9 @@
 
 def launch(
     func: Callable,
-    func_args: tuple | None = None,
-    func_kwargs: dict[str, Any] | None = None,
+    args: tuple | None = None,
+    kwargs: dict[str, Any] | None = None,
+    *,
     hostnames: list[str] | Literal["auto", "slurm"] = "auto",
     workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto",
     ssh_config_file: str | os.PathLike | None = None,
@@ -57,50 +58,55 @@ def launch(
     ),
     extra_env_vars: tuple[str, ...] = (),
     env_file: str | os.PathLike | None = None,
+    propagate_exceptions: bool = True,
     handler_factory: Callable[[], list[Handler]] | Literal["auto"] | None = "auto",
 ) -> LaunchResult:
-    """Launch a distributed PyTorch function on the specified nodes.
+    """Distribute and parallelize a function onto specified nodes and workers.
 
     Arguments:
-        func: Function to run on each worker.
-        func_args: Positional arguments for ``func``.
-        func_kwargs: Keyword arguments for ``func``.
+        func: Function to launch on each node and replicate for each worker.
+        args: Positional arguments for ``func``.
+        kwargs: Keyword arguments for ``func``.
         hostnames: Nodes on which to launch the function.
-            Defaults to nodes inferred from a SLURM environment or localhost.
-        workers_per_host: Number of processes to run per node.
-            Can specify different counts per node with a list.
+            Default: infer from localhost or SLURM.
+        workers_per_host: Number of processes to run (e.g. # of GPUs) per node.
         ssh_config_file: Path to an SSH configuration file for connecting to nodes.
-            Defaults to ``~/.ssh/config`` or ``/etc/ssh/ssh_config``.
+            Default: ``~/.ssh/config`` or ``/etc/ssh/ssh_config``.
         backend: `Backend <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend>`_
-            for worker process group. Defaults to NCCL (GPU) or GLOO (CPU). Set `None` to disable.
+            for worker process group. Set `None` to disable. Default: NCCL (GPU) or GLOO (CPU).
         timeout: Worker process group timeout (seconds).
         default_env_vars: Environment variables to copy from the launcher process to workers.
             Supports bash pattern matching syntax.
         extra_env_vars: Additional user-specified environment variables to copy.
-        env_file: Path to a file (e.g., `.env`) with additional environment variables to copy.
-        handler_factory: Function to build logging handlers that process agent and worker logs.
-            Defaults to an automatic basic logging scheme.
+        env_file: Path to a file (e.g., ``.env``) with additional environment variables to copy.
+        propagate_exceptions: Raise exceptions from worker processes in the launcher.
+            If false, raises :obj:`WorkerFailedError` instead.
+        handler_factory: Function to customize processing of agent and worker logs with handlers.
 
     Raises:
         RuntimeError: If there are configuration issues.
         AgentFailedError: If an agent fails, e.g. from an OS signal.
         WorkerFailedError: If a worker fails, e.g. from a segmentation fault.
         Exception: Any exception raised in a worker process is propagated.
     """
-    return Launcher(
-        hostnames=hostnames,
-        workers_per_host=workers_per_host,
-        ssh_config_file=ssh_config_file,
-        backend=backend,
-        timeout=timeout,
-        default_env_vars=default_env_vars,
-        extra_env_vars=extra_env_vars,
-        env_file=env_file,
-    ).run(
-        func=func,
-        func_args=func_args,
-        func_kwargs=func_kwargs,
-        handler_factory=handler_factory,
+    return (
+        Launcher(
+            hostnames=hostnames,
+            workers_per_host=workers_per_host,
+            ssh_config_file=ssh_config_file,
+            backend=backend,
+            timeout=timeout,
+            default_env_vars=default_env_vars,
+            extra_env_vars=extra_env_vars,
+            env_file=env_file,
+            propagate_exceptions=propagate_exceptions,
+        )
+        .set_handler_factory(handler_factory)
+        .run(
+            func,
+            args,
+            kwargs,
+        )
     )
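
Taken together, the changes above alter the public call convention: positional and keyword arguments for the target function now travel through args/kwargs, and every configuration option after them becomes keyword-only. A minimal sketch of a call against the new signature (the train function and hostnames here are hypothetical, and launch is assumed to be re-exported at the package root):

import os

import torchrunx


def train(batch_size: int = 32) -> str:
    # Replicated onto every worker; torchrunx populates the usual distributed
    # environment variables (RANK, LOCAL_RANK, WORLD_SIZE, ...) for each process.
    return f"rank {os.environ['RANK']}: trained with batch_size={batch_size}"


result = torchrunx.launch(
    train,
    kwargs={"batch_size": 64},        # formerly func_kwargs
    hostnames=["node1", "node2"],     # hypothetical hosts; keyword-only after this commit
    workers_per_host=2,
    propagate_exceptions=True,        # new flag introduced by this commit
)
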
@@ -125,13 +131,24 @@ class Launcher:
     )
     extra_env_vars: tuple[str, ...] = ()
     env_file: str | os.PathLike | None = None
+    propagate_exceptions: bool = True
+
+    def __post_init__(self) -> None:
+        """Initializing ``handler_factory``. Inclusion in ``__init__`` inhibits CLI generation."""
+        self.handler_factory: Callable[[], list[Handler]] | Literal["auto"] | None = "auto"
+
+    def set_handler_factory(
+        self, factory: Callable[[], list[Handler]] | Literal["auto"] | None
+    ) -> Launcher:
+        """Setter for log handler factory."""
+        self.handler_factory = factory
+        return self
 
     def run(  # noqa: C901, PLR0912
         self,
         func: Callable,
-        func_args: tuple | None = None,
-        func_kwargs: dict[str, Any] | None = None,
-        handler_factory: Callable[[], list[Handler]] | Literal["auto"] | None = "auto",
+        args: tuple | None = None,
+        kwargs: dict[str, Any] | None = None,
     ) -> LaunchResult:
         """Run a function using the :mod:`torchrunx.Launcher` configuration."""
         if not dist.is_available():
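
On the Launcher side, handler_factory is no longer a run() parameter: it lives on the instance, initialized in __post_init__ and set through the new set_handler_factory method, which returns self so it can be chained. A sketch of that flow, assuming Handler here refers to logging.Handler and reusing the hypothetical train function above:

import logging

from torchrunx import Launcher


def file_handlers() -> list[logging.Handler]:
    # Collect agent and worker log records into a single local file.
    return [logging.FileHandler("torchrunx.log")]


result = (
    Launcher(hostnames=["node1", "node2"], workers_per_host=2)
    .set_handler_factory(file_handlers)
    .run(train, kwargs={"batch_size": 64})
)
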
@@ -155,7 +172,7 @@ def run(  # noqa: C901, PLR0912
         # Start logging server (receives LogRecords from agents/workers)
 
         logging_server_args = LoggingServerArgs(
-            handler_factory=handler_factory,
+            handler_factory=self.handler_factory,
             logging_hostname=launcher_hostname,
             logging_port=logging_port,
             hostnames=hostnames,
@@ -211,7 +228,7 @@ def run(  # noqa: C901, PLR0912
         ]
 
         payload = LauncherPayload(
-            fn=partial(func, *(func_args or ()), **(func_kwargs or {})),
+            fn=partial(func, *(args or ()), **(kwargs or {})),
             hostnames=hostnames,
             worker_global_ranks=worker_global_ranks,
             worker_world_size=sum(workers_per_host),
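
The payload still binds the target function eagerly; only the parameter names feeding the or-guards changed, and those guards simply turn None defaults into empty argument lists. A tiny self-contained illustration of the idiom:

from functools import partial


def bind(func, args=None, kwargs=None):
    # Same idiom as the payload: None for args/kwargs collapses to "no arguments".
    return partial(func, *(args or ()), **(kwargs or {}))


bind(print, args=("hello",))()  # prints: hello
bind(print)()                   # prints an empty line; both defaults were None
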
@@ -231,7 +248,9 @@ def run(  # noqa: C901, PLR0912
         for s in agent_statuses:
             for value in s.return_values:
                 if isinstance(value, ExceptionFromWorker):
-                    raise value.exception
+                    if self.propagate_exceptions:
+                        raise value.exception
+                    raise WorkerFailedError from value.exception
                 if isinstance(value, WorkerFailedError):
                     raise value
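
This is where the new propagate_exceptions flag takes effect: with the default True, a worker's own exception is re-raised in the launcher; with False, it is wrapped as the cause of a WorkerFailedError. A hedged sketch from the caller's perspective, assuming WorkerFailedError is importable from the package root (it is referenced in the docstring above) and using a hypothetical failing function:

import torchrunx
from torchrunx import WorkerFailedError


def fail() -> None:
    # Hypothetical worker function that always raises.
    raise ValueError("boom")


try:
    torchrunx.launch(fail, hostnames=["node1"], propagate_exceptions=False)
except WorkerFailedError as e:
    # The original ValueError is not re-raised directly; it is attached
    # as the cause of the WorkerFailedError (raise ... from value.exception).
    print("worker failed:", e.__cause__)
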
@@ -268,9 +287,9 @@ def __init__(self, hostnames: list[str], return_values: list[list[Any]]) -> None
         """Initialize from corresponding lists of hostnames and worker return values."""
         self.results: dict[str, list[Any]] = dict(zip(hostnames, return_values))
 
-    def index(self, hostname: str, rank: int) -> Any:
+    def index(self, hostname: str, local_rank: int) -> Any:
         """Get return value from worker by host and local rank."""
-        return self.results[hostname][rank]
+        return self.results[hostname][local_rank]
 
     def rank(self, i: int) -> Any:
         """Get return value from worker by global rank."""

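For completeness, a short sketch of reading results back out of a LaunchResult, continuing the hypothetical two-node example above; index takes a hostname and local rank, while rank looks a value up by global rank:

result = torchrunx.launch(train, hostnames=["node1", "node2"], workers_per_host=2)

on_node1 = result.index("node1", 0)  # return value of the worker with local rank 0 on node1
by_rank = result.rank(0)             # lookup by global rank instead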