Skip to content

Commit 10fa1a0

Browse files
committed
Rename AgentKilledError and WorkerKilledError to AgentFailedError and WorkerFailedError
1 parent 5ec7fd9 commit 10fa1a0

File tree

4 files changed

+15
-15
lines changed

4 files changed

+15
-15
lines changed

src/torchrunx/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from .launcher import AgentKilledError, Launcher, LaunchResult, launch
1+
from .errors import AgentFailedError, WorkerFailedError
2+
from .launcher import Launcher, LaunchResult, launch
23
from .logging_utils import add_filter_to_handler, file_handler, stream_handler
34

45
__all__ = [
5-
"AgentKilledError",
6+
"AgentFailedError",
7+
"WorkerFailedError",
68
"Launcher",
79
"launch",
810
"LaunchResult",

src/torchrunx/errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Raised in launcher.py, chained from the RuntimeError caught around
# sync_agent_statuses(); the comment at that call site attributes the error
# to an agent dying and communication timing out.
class AgentFailedError(Exception):
2+
pass
3+
4+
# Built in utils.py (AgentStatus.from_result) from `failure.message` for each
# entry in `result.failures` and stored in place of that worker's return value;
# launcher.py re-raises it when it finds one among an agent's return values.
class WorkerFailedError(Exception):
5+
pass

src/torchrunx/launcher.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,17 @@
2121
import torch.distributed as dist
2222

2323
from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
24+
from .errors import AgentFailedError, WorkerFailedError
2425
from .logging_utils import LogRecordSocketReceiver, default_handlers
2526
from .utils import (
2627
AgentStatus,
2728
ExceptionFromWorker,
2829
LauncherAgentGroup,
2930
LauncherPayload,
30-
WorkerKilledError,
3131
get_open_port,
3232
)
3333

3434

35-
class AgentKilledError(Exception):
36-
pass
37-
38-
3935
@dataclass
4036
class Launcher:
4137
hostnames: list[str] | Literal["auto", "slurm"] = "auto"
@@ -152,14 +148,14 @@ def run( # noqa: C901, PLR0912
152148
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
153149
except RuntimeError as e:
154150
# occurs if any agent dies and communication times out
155-
raise AgentKilledError from e
151+
raise AgentFailedError from e
156152

157153
# re-raise the first worker exception (or WorkerFailedError) found in any agent's return values
158154
for s in agent_statuses:
159155
for value in s.return_values:
160156
if isinstance(value, ExceptionFromWorker):
161157
raise value.exception
162-
if isinstance(value, WorkerKilledError):
158+
if isinstance(value, WorkerFailedError):
163159
raise value
164160

165161
if all(s.state == "done" for s in agent_statuses):

src/torchrunx/utils.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import torch.distributed as dist
1111
from typing_extensions import Self
1212

13+
from .errors import WorkerFailedError
14+
1315
if TYPE_CHECKING:
1416
from torch.distributed.elastic.multiprocessing.api import RunProcsResult
1517

@@ -94,11 +96,6 @@ class ExceptionFromWorker:
9496
exception: Exception
9597

9698

97-
@dataclass
98-
class WorkerKilledError(Exception):
99-
failure: str
100-
101-
10299
@dataclass
103100
class AgentStatus:
104101
state: Literal["running", "failed", "done"]
@@ -111,7 +108,7 @@ def from_result(cls, result: RunProcsResult | None) -> Self:
111108
if result is None:
112109
return cls(state="running")
113110
for local_rank, failure in result.failures.items():
114-
result.return_values[local_rank] = WorkerKilledError(failure.message)
111+
result.return_values[local_rank] = WorkerFailedError(failure.message)
115112
return_values = list(result.return_values.values())
116113
failed = any(isinstance(v, ExceptionFromWorker) for v in return_values)
117114
state = "failed" if failed else "done"

0 commit comments

Comments
 (0)