@@ -108,8 +108,9 @@ def run( # noqa: C901, PLR0912, PLR0915
108
108
msg = "The torch.distributed package is not available."
109
109
raise RuntimeError (msg )
110
110
111
+ logger .debug ("Preparing launch environment." )
112
+
111
113
###
112
- logger .debug ("Resolving environment." )
113
114
114
115
hostnames , workers_per_host = resolve_environment (
115
116
self .hostnames , self .workers_per_host , ssh_config_file = self .ssh_config_file
@@ -183,11 +184,11 @@ def handler_factory() -> list[logging.Handler]:
183
184
184
185
log_process .start ()
185
186
186
- logger .debug ("Launching agents." )
187
-
188
187
# Start agents on each node
189
188
190
189
for i , hostname in enumerate (hostnames ):
190
+ logger .info (f'Launching "{ func .__name__ } " on { hostname } .' )
191
+
191
192
execute_command (
192
193
command = build_launch_command (
193
194
launcher_hostname = launcher_hostname ,
@@ -215,16 +216,15 @@ def handler_factory() -> list[logging.Handler]:
215
216
rank = 0 ,
216
217
)
217
218
218
- logger .debug ("Receiving agent details." )
219
-
220
219
# Sync initial payloads between launcher and agents
221
220
221
+ logger .debug ("Synchronizing launcher and agents." )
222
222
launcher_payload , agent_payloads = launcher_agent_group .sync_payloads (payload = payload )
223
223
224
- logger .debug ("Entering agent monitoring loop." )
225
-
226
224
# Monitor agent statuses (until failed or done)
227
225
226
+ logger .debug ("Entering agent monitoring loop." )
227
+
228
228
while True :
229
229
# could raise AgentFailedError
230
230
agent_statuses = launcher_agent_group .sync_agent_statuses (status = None )
@@ -238,17 +238,10 @@ def handler_factory() -> list[logging.Handler]:
238
238
raise v
239
239
240
240
if all (s .state == "done" for s in agent_statuses ):
241
- logger .debug ("All workers exited cleanly ." )
241
+ logger .info ("All workers completed successfully ." )
242
242
return_values : list [list [FunctionR ]] = [s .return_values for s in agent_statuses ] # pyright: ignore [reportAssignmentType]
243
243
return LaunchResult .from_returns (hostnames , return_values )
244
244
finally :
245
- logger .debug ("Stopping logging server." )
246
-
247
- if stop_logging_event is not None :
248
- stop_logging_event .set ()
249
- if log_process is not None :
250
- log_process .kill ()
251
-
252
245
# cleanup: SIGTERM all agents
253
246
if agent_payloads is not None :
254
247
for agent_payload , agent_hostname in zip (agent_payloads , hostnames ):
@@ -264,6 +257,13 @@ def handler_factory() -> list[logging.Handler]:
264
257
logger .debug ("Killing launcher-agent group." )
265
258
launcher_agent_group .shutdown ()
266
259
260
+ logger .debug ("Stopping logging server." )
261
+
262
+ if stop_logging_event is not None :
263
+ stop_logging_event .set ()
264
+ if log_process is not None :
265
+ log_process .kill ()
266
+
267
267
268
268
@dataclass
269
269
class LaunchResult (Generic [FunctionR ]):
0 commit comments