38
38
39
39
def launch (
40
40
func : Callable ,
41
- func_args : tuple | None = None ,
42
- func_kwargs : dict [str , Any ] | None = None ,
41
+ args : tuple | None = None ,
42
+ kwargs : dict [str , Any ] | None = None ,
43
+ * ,
43
44
hostnames : list [str ] | Literal ["auto" , "slurm" ] = "auto" ,
44
45
workers_per_host : int | list [int ] | Literal ["auto" , "slurm" ] = "auto" ,
45
46
ssh_config_file : str | os .PathLike | None = None ,
@@ -57,50 +58,55 @@ def launch(
57
58
),
58
59
extra_env_vars : tuple [str , ...] = (),
59
60
env_file : str | os .PathLike | None = None ,
61
+ propagate_exceptions : bool = True ,
60
62
handler_factory : Callable [[], list [Handler ]] | Literal ["auto" ] | None = "auto" ,
61
63
) -> LaunchResult :
62
- """Launch a distributed PyTorch function on the specified nodes.
64
+ """Distribute and parallelize a function onto specified nodes and workers .
63
65
64
66
Arguments:
65
- func: Function to run on each worker.
66
- func_args : Positional arguments for ``func``.
67
- func_kwargs : Keyword arguments for ``func``.
67
+ func: Function to launch on each node and replicate for each worker.
68
+ args : Positional arguments for ``func``.
69
+ kwargs : Keyword arguments for ``func``.
68
70
hostnames: Nodes on which to launch the function.
69
- Defaults to nodes inferred from a SLURM environment or localhost.
70
- workers_per_host: Number of processes to run per node.
71
- Can specify different counts per node with a list.
71
+ Default: infer from localhost or SLURM.
72
+ workers_per_host: Number of processes to run (e.g. # of GPUs) per node.
72
73
ssh_config_file: Path to an SSH configuration file for connecting to nodes.
73
- Defaults to ``~/.ssh/config`` or ``/etc/ssh/ssh_config``.
74
+ Default: ``~/.ssh/config`` or ``/etc/ssh/ssh_config``.
74
75
backend: `Backend <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend>`_
75
- for worker process group. Defaults to NCCL (GPU) or GLOO (CPU). Set `None` to disable .
76
+ for worker process group. Set `None` to disable. Default: NCCL (GPU) or GLOO (CPU).
76
77
timeout: Worker process group timeout (seconds).
77
78
default_env_vars: Environment variables to copy from the launcher process to workers.
78
79
Supports bash pattern matching syntax.
79
80
extra_env_vars: Additional user-specified environment variables to copy.
80
- env_file: Path to a file (e.g., `.env`) with additional environment variables to copy.
81
- handler_factory: Function to build logging handlers that process agent and worker logs.
82
- Defaults to an automatic basic logging scheme.
81
+ env_file: Path to a file (e.g., ``.env``) with additional environment variables to copy.
82
+ propagate_exceptions: Raise exceptions from worker processes in the launcher.
83
+ If false, raises :obj:`WorkerFailedError` instead.
84
+ handler_factory: Function to customize processing of agent and worker logs with handlers.
83
85
84
86
Raises:
85
87
RuntimeError: If there are configuration issues.
86
88
AgentFailedError: If an agent fails, e.g. from an OS signal.
87
89
WorkerFailedError: If a worker fails, e.g. from a segmentation fault.
88
90
Exception: Any exception raised in a worker process is propagated.
89
91
"""
90
- return Launcher (
91
- hostnames = hostnames ,
92
- workers_per_host = workers_per_host ,
93
- ssh_config_file = ssh_config_file ,
94
- backend = backend ,
95
- timeout = timeout ,
96
- default_env_vars = default_env_vars ,
97
- extra_env_vars = extra_env_vars ,
98
- env_file = env_file ,
99
- ).run (
100
- func = func ,
101
- func_args = func_args ,
102
- func_kwargs = func_kwargs ,
103
- handler_factory = handler_factory ,
92
+ return (
93
+ Launcher (
94
+ hostnames = hostnames ,
95
+ workers_per_host = workers_per_host ,
96
+ ssh_config_file = ssh_config_file ,
97
+ backend = backend ,
98
+ timeout = timeout ,
99
+ default_env_vars = default_env_vars ,
100
+ extra_env_vars = extra_env_vars ,
101
+ env_file = env_file ,
102
+ propagate_exceptions = propagate_exceptions ,
103
+ )
104
+ .set_handler_factory (handler_factory )
105
+ .run (
106
+ func ,
107
+ args ,
108
+ kwargs ,
109
+ )
104
110
)
105
111
106
112
@@ -125,13 +131,24 @@ class Launcher:
125
131
)
126
132
extra_env_vars : tuple [str , ...] = ()
127
133
env_file : str | os .PathLike | None = None
134
+ propagate_exceptions : bool = True
135
+
136
+ def __post_init__ (self ) -> None :
137
+ """Initialize ``handler_factory``. Inclusion in ``__init__`` inhibits CLI generation."""
138
+ self .handler_factory : Callable [[], list [Handler ]] | Literal ["auto" ] | None = "auto"
139
+
140
+ def set_handler_factory (
141
+ self , factory : Callable [[], list [Handler ]] | Literal ["auto" ] | None
142
+ ) -> Launcher :
143
+ """Setter for log handler factory."""
144
+ self .handler_factory = factory
145
+ return self
128
146
129
147
def run ( # noqa: C901, PLR0912
130
148
self ,
131
149
func : Callable ,
132
- func_args : tuple | None = None ,
133
- func_kwargs : dict [str , Any ] | None = None ,
134
- handler_factory : Callable [[], list [Handler ]] | Literal ["auto" ] | None = "auto" ,
150
+ args : tuple | None = None ,
151
+ kwargs : dict [str , Any ] | None = None ,
135
152
) -> LaunchResult :
136
153
"""Run a function using the :mod:`torchrunx.Launcher` configuration."""
137
154
if not dist .is_available ():
@@ -155,7 +172,7 @@ def run( # noqa: C901, PLR0912
155
172
# Start logging server (receives LogRecords from agents/workers)
156
173
157
174
logging_server_args = LoggingServerArgs (
158
- handler_factory = handler_factory ,
175
+ handler_factory = self . handler_factory ,
159
176
logging_hostname = launcher_hostname ,
160
177
logging_port = logging_port ,
161
178
hostnames = hostnames ,
@@ -211,7 +228,7 @@ def run( # noqa: C901, PLR0912
211
228
]
212
229
213
230
payload = LauncherPayload (
214
- fn = partial (func , * (func_args or ()), ** (func_kwargs or {})),
231
+ fn = partial (func , * (args or ()), ** (kwargs or {})),
215
232
hostnames = hostnames ,
216
233
worker_global_ranks = worker_global_ranks ,
217
234
worker_world_size = sum (workers_per_host ),
@@ -231,7 +248,9 @@ def run( # noqa: C901, PLR0912
231
248
for s in agent_statuses :
232
249
for value in s .return_values :
233
250
if isinstance (value , ExceptionFromWorker ):
234
- raise value .exception
251
+ if self .propagate_exceptions :
252
+ raise value .exception
253
+ raise WorkerFailedError from value .exception
235
254
if isinstance (value , WorkerFailedError ):
236
255
raise value
237
256
@@ -268,9 +287,9 @@ def __init__(self, hostnames: list[str], return_values: list[list[Any]]) -> None
268
287
"""Initialize from corresponding lists of hostnames and worker return values."""
269
288
self .results : dict [str , list [Any ]] = dict (zip (hostnames , return_values ))
270
289
271
- def index (self , hostname : str , rank : int ) -> Any :
290
+ def index (self , hostname : str , local_rank : int ) -> Any :
272
291
"""Get return value from worker by host and local rank."""
273
- return self .results [hostname ][rank ]
292
+ return self .results [hostname ][local_rank ]
274
293
275
294
def rank (self , i : int ) -> Any :
276
295
"""Get return value from worker by global rank."""
0 commit comments