Skip to content

Commit 2824b6a

Browse files
committed
some extra error docs
1 parent c0baede commit 2824b6a

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

src/torchrunx/agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,10 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
169169
status = None
170170
while True:
171171
if status is None or status.state == "running":
172-
status = AgentStatus.from_result(ctx.wait(5))
172+
# status can contain ExceptionFromWorker or WorkerFailedError
173+
status = AgentStatus.from_result(result=ctx.wait(5))
173174

175+
# can raise AgentFailedError in launcher and all agents
174176
agent_statuses = launcher_agent_group.sync_agent_statuses(status=status)
175177

176178
all_done = all(s.state == "done" for s in agent_statuses)

src/torchrunx/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def get_open_port() -> int:
2323
class AgentFailedError(Exception):
2424
pass
2525

26+
2627
class WorkerFailedError(Exception):
2728
pass
2829

@@ -108,7 +109,7 @@ class ExceptionFromWorker:
108109
@dataclass
109110
class AgentStatus:
110111
state: Literal["running", "failed", "done"]
111-
return_values: list[Any | ExceptionFromWorker] = field(
112+
return_values: list[Any | WorkerFailedError | ExceptionFromWorker] = field(
112113
default_factory=list
113114
) # indexed by local rank
114115

0 commit comments

Comments
 (0)