Skip to content

Commit e31f967

Browse files
authored
Merge pull request #78 from apoorvkh/worker-killed-error
Worker killed error
2 parents 58eb486 + 2824b6a commit e31f967

File tree

4 files changed

+47
-27
lines changed

4 files changed

+47
-27
lines changed

src/torchrunx/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from .launcher import AgentKilledError, Launcher, LaunchResult, launch
1+
from .launcher import Launcher, LaunchResult, launch
22
from .logging_utils import add_filter_to_handler, file_handler, stream_handler
3+
from .utils import AgentFailedError, WorkerFailedError
34

45
__all__ = [
5-
"AgentKilledError",
6+
"AgentFailedError",
7+
"WorkerFailedError",
68
"Launcher",
79
"launch",
810
"LaunchResult",

src/torchrunx/agent.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from .utils import (
2020
AgentPayload,
2121
AgentStatus,
22+
ExceptionFromWorker,
2223
LauncherAgentGroup,
23-
WorkerException,
2424
get_open_port,
2525
)
2626

@@ -52,7 +52,7 @@ def deserialize(self) -> WorkerArgs:
5252
return cloudpickle.loads(self.bytes)
5353

5454

55-
def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException:
55+
def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker:
5656
worker_args: WorkerArgs = serialized_worker_args.deserialize()
5757

5858
logger = logging.getLogger()
@@ -96,7 +96,7 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce
9696
return worker_args.function()
9797
except Exception as e:
9898
traceback.print_exc()
99-
return WorkerException(exception=e)
99+
return ExceptionFromWorker(exception=e)
100100
finally:
101101
sys.stdout.flush()
102102
sys.stderr.flush()
@@ -155,7 +155,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
155155
)
156156
for i in range(num_workers)
157157
},
158+
# environment variables from agent are already automatically copied to workers
158159
envs={i: {} for i in range(num_workers)},
160+
# we handle logging ourselves, so we can discard these
159161
**(
160162
{"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
161163
if torch.__version__ >= "2.3"
@@ -167,8 +169,10 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
167169
status = None
168170
while True:
169171
if status is None or status.state == "running":
170-
status = AgentStatus.from_result(ctx.wait(5))
172+
# status can contain ExceptionFromWorker or WorkerFailedError
173+
status = AgentStatus.from_result(result=ctx.wait(5))
171174

175+
# can raise AgentFailedError in launcher and all agents
172176
agent_statuses = launcher_agent_group.sync_agent_statuses(status=status)
173177

174178
all_done = all(s.state == "done" for s in agent_statuses)

src/torchrunx/launcher.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,14 @@
2222

2323
from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
2424
from .logging_utils import LogRecordSocketReceiver, default_handlers
25-
from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port
26-
27-
28-
class AgentKilledError(Exception):
29-
pass
25+
from .utils import (
26+
AgentStatus,
27+
ExceptionFromWorker,
28+
LauncherAgentGroup,
29+
LauncherPayload,
30+
WorkerFailedError,
31+
get_open_port,
32+
)
3033

3134

3235
@dataclass
@@ -141,17 +144,16 @@ def run( # noqa: C901, PLR0912
141144
# loop to monitor agent statuses (until failed or done)
142145

143146
while True:
144-
try:
145-
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
146-
except RuntimeError as e:
147-
# occurs if any agent dies and communication times out
148-
raise AgentKilledError from e
147+
# could raise AgentFailedError
148+
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
149149

150150
# raises specific exception if any agent fails
151151
for s in agent_statuses:
152152
for value in s.return_values:
153-
if isinstance(value, WorkerException):
153+
if isinstance(value, ExceptionFromWorker):
154154
raise value.exception
155+
if isinstance(value, WorkerFailedError):
156+
raise value
155157

156158
if all(s.state == "done" for s in agent_statuses):
157159
break

src/torchrunx/utils.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ def get_open_port() -> int:
2020
return s.getsockname()[1]
2121

2222

23+
class AgentFailedError(Exception):
24+
pass
25+
26+
27+
class WorkerFailedError(Exception):
28+
pass
29+
30+
2331
@dataclass
2432
class LauncherAgentGroup:
2533
launcher_hostname: str
@@ -50,11 +58,15 @@ def _deserialize(self, serialized: bytes) -> Any:
5058

5159
def _all_gather(self, obj: Any) -> list:
5260
"""gather object from every rank to list on every rank"""
53-
object_bytes = self._serialize(obj)
54-
object_list = [b""] * self.world_size
55-
# raises RuntimeError if timeout
56-
dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
57-
return [self._deserialize(o) for o in object_list]
61+
try:
62+
object_bytes = self._serialize(obj)
63+
object_list = [b""] * self.world_size
64+
# raises RuntimeError if timeout
65+
dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
66+
return [self._deserialize(o) for o in object_list]
67+
except RuntimeError as e:
68+
# occurs if launcher or any agent dies and communication times out
69+
raise AgentFailedError from e
5870

5971
def sync_payloads(
6072
self,
@@ -90,25 +102,25 @@ class AgentPayload:
90102

91103

92104
@dataclass
93-
class WorkerException:
105+
class ExceptionFromWorker:
94106
exception: Exception
95107

96108

97109
@dataclass
98110
class AgentStatus:
99111
state: Literal["running", "failed", "done"]
100-
return_values: list[Any | WorkerException] = field(
112+
return_values: list[Any | WorkerFailedError | ExceptionFromWorker] = field(
101113
default_factory=list
102114
) # indexed by local rank
103115

104116
@classmethod
105117
def from_result(cls, result: RunProcsResult | None) -> Self:
106118
if result is None:
107119
return cls(state="running")
108-
120+
for local_rank, failure in result.failures.items():
121+
result.return_values[local_rank] = WorkerFailedError(failure.message)
109122
return_values = list(result.return_values.values())
110-
111-
failed = any(isinstance(v, WorkerException) for v in return_values)
123+
failed = any(isinstance(v, ExceptionFromWorker) for v in return_values)
112124
state = "failed" if failed else "done"
113125

114126
return cls(

0 commit comments

Comments
 (0)