
Commit c0baede

Moved failed errors to utils; raising AgentFailedError in _all_gather
1 parent: 10fa1a0

File tree

    src/torchrunx/__init__.py
    src/torchrunx/errors.py
    src/torchrunx/launcher.py
    src/torchrunx/utils.py

4 files changed: +20 -19 lines


src/torchrunx/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
-from .errors import AgentFailedError, WorkerFailedError
 from .launcher import Launcher, LaunchResult, launch
 from .logging_utils import add_filter_to_handler, file_handler, stream_handler
+from .utils import AgentFailedError, WorkerFailedError
 
 __all__ = [
     "AgentFailedError",

src/torchrunx/errors.py

Lines changed: 0 additions & 5 deletions
This file was deleted; per the commit message, the AgentFailedError and WorkerFailedError definitions it held now live in src/torchrunx/utils.py.

src/torchrunx/launcher.py

Lines changed: 3 additions & 6 deletions

@@ -21,13 +21,13 @@
 import torch.distributed as dist
 
 from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
-from .errors import AgentFailedError, WorkerFailedError
 from .logging_utils import LogRecordSocketReceiver, default_handlers
 from .utils import (
     AgentStatus,
     ExceptionFromWorker,
     LauncherAgentGroup,
     LauncherPayload,
+    WorkerFailedError,
     get_open_port,
 )
 
@@ -144,11 +144,8 @@ def run( # noqa: C901, PLR0912
         # loop to monitor agent statuses (until failed or done)
 
         while True:
-            try:
-                agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
-            except RuntimeError as e:
-                # occurs if any agent dies and communication times out
-                raise AgentFailedError from e
+            # could raise AgentFailedError
+            agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
 
             # raises specific exception if any agent fails
             for s in agent_statuses:
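
With the try/except hoisted out of the monitoring loop, sync_agent_statuses now propagates AgentFailedError directly from _all_gather (see utils.py below), and the loop that follows can stay focused on per-agent worker failures. A caller-side sketch of telling the two error types apart; the launch(...) arguments shown are illustrative assumptions, not something this diff defines:

    # Hypothetical caller-side handling of the two error types.
    import torchrunx
    from torchrunx import AgentFailedError, WorkerFailedError

    def train() -> None:
        """Placeholder entry point to run on every worker."""

    try:
        # launch() drives the monitoring loop shown above.
        torchrunx.launch(func=train, hostnames=["node1", "node2"], workers_per_host=2)
    except AgentFailedError:
        print("an agent process died and communication timed out")
    except WorkerFailedError:
        print("a worker process failed")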

src/torchrunx/utils.py

Lines changed: 16 additions & 7 deletions

@@ -10,8 +10,6 @@
 import torch.distributed as dist
 from typing_extensions import Self
 
-from .errors import WorkerFailedError
-
 if TYPE_CHECKING:
     from torch.distributed.elastic.multiprocessing.api import RunProcsResult
 
@@ -22,6 +20,13 @@ def get_open_port() -> int:
         return s.getsockname()[1]
 
 
+class AgentFailedError(Exception):
+    pass
+
+class WorkerFailedError(Exception):
+    pass
+
+
 @dataclass
 class LauncherAgentGroup:
     launcher_hostname: str
 
@@ -52,11 +57,15 @@ def _deserialize(self, serialized: bytes) -> Any:
 
     def _all_gather(self, obj: Any) -> list:
         """gather object from every rank to list on every rank"""
-        object_bytes = self._serialize(obj)
-        object_list = [b""] * self.world_size
-        # raises RuntimeError if timeout
-        dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
-        return [self._deserialize(o) for o in object_list]
+        try:
+            object_bytes = self._serialize(obj)
+            object_list = [b""] * self.world_size
+            # raises RuntimeError if timeout
+            dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
+            return [self._deserialize(o) for o in object_list]
+        except RuntimeError as e:
+            # occurs if launcher or any agent dies and communication times out
+            raise AgentFailedError from e
 
     def sync_payloads(
         self,
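
The raise AgentFailedError from e pattern preserves the low-level RuntimeError as __cause__, so the original collective-timeout traceback still prints beneath the domain-specific error. A self-contained sketch of the chaining pattern (plain Python, no torch dependency; names are illustrative):

    class AgentFailedError(Exception):
        """Raised when an agent dies and collective communication times out."""

    def all_gather_stub() -> None:
        # Stand-in for dist.all_gather_object() hitting its timeout.
        raise RuntimeError("Socket Timeout")

    try:
        try:
            all_gather_stub()
        except RuntimeError as e:
            # Chain the transport-level error onto the domain-specific one.
            raise AgentFailedError from e
    except AgentFailedError as err:
        assert isinstance(err.__cause__, RuntimeError)  # original cause preserved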
