Skip to content

Commit 10c0514

Browse files
committed
fix for logging server serialization problems
1 parent 677edcb commit 10c0514

File tree

3 files changed

+77
-41
lines changed

3 files changed

+77
-41
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "torchrunx"
7-
version = "0.2.0"
7+
version = "0.2.1"
88
authors = [
99
{name = "Apoorv Khandelwal", email = "[email protected]"},
1010
{name = "Peter Curtin", email = "[email protected]"},

src/torchrunx/launcher.py

Lines changed: 16 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from dataclasses import dataclass
1717
from functools import partial, reduce
1818
from logging import Handler
19-
from multiprocessing import Process
19+
from multiprocessing import Event, Process
2020
from operator import add
2121
from pathlib import Path
2222
from typing import Any, Callable, Literal
@@ -34,7 +34,7 @@
3434
ExceptionFromWorker,
3535
WorkerFailedError,
3636
)
37-
from .utils.logging import LogRecordSocketReceiver, default_handlers
37+
from .utils.logging import LoggingServerArgs, start_logging_server
3838

3939

4040
@dataclass
@@ -76,27 +76,32 @@ def run( # noqa: C901, PLR0912
7676

7777
launcher_hostname = socket.getfqdn()
7878
launcher_port = get_open_port()
79+
logging_port = get_open_port()
7980
world_size = len(hostnames) + 1
8081

81-
log_receiver = None
82+
stop_logging_event = None
8283
log_process = None
8384
launcher_agent_group = None
8485
agent_payloads = None
8586

8687
try:
8788
# Start logging server (recieves LogRecords from agents/workers)
8889

89-
log_receiver = _build_logging_server(
90+
logging_server_args = LoggingServerArgs(
9091
log_handlers=log_handlers,
91-
launcher_hostname=launcher_hostname,
92+
logging_hostname=launcher_hostname,
93+
logging_port=logging_port,
9294
hostnames=hostnames,
9395
workers_per_host=workers_per_host,
9496
log_dir=Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")),
9597
log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")], # noqa: SLF001
9698
)
9799

100+
stop_logging_event = Event()
101+
98102
log_process = Process(
99-
target=log_receiver.serve_forever,
103+
target=start_logging_server,
104+
args=(logging_server_args.serialize(), stop_logging_event),
100105
daemon=True,
101106
)
102107

@@ -109,7 +114,7 @@ def run( # noqa: C901, PLR0912
109114
command=_build_launch_command(
110115
launcher_hostname=launcher_hostname,
111116
launcher_port=launcher_port,
112-
logger_port=log_receiver.port,
117+
logger_port=logging_port,
113118
world_size=world_size,
114119
rank=i + 1,
115120
env_vars=(self.default_env_vars + self.extra_env_vars),
@@ -166,11 +171,10 @@ def run( # noqa: C901, PLR0912
166171
if all(s.state == "done" for s in agent_statuses):
167172
break
168173
finally:
169-
if log_receiver is not None:
170-
log_receiver.shutdown()
171-
if log_process is not None:
172-
log_receiver.server_close()
173-
log_process.kill()
174+
if stop_logging_event is not None:
175+
stop_logging_event.set()
176+
if log_process is not None:
177+
log_process.kill()
174178

175179
if launcher_agent_group is not None:
176180
launcher_agent_group.shutdown()
@@ -307,31 +311,6 @@ def _resolve_workers_per_host(
307311
return workers_per_host
308312

309313

310-
def _build_logging_server(
311-
log_handlers: list[Handler] | Literal["auto"] | None,
312-
launcher_hostname: str,
313-
hostnames: list[str],
314-
workers_per_host: list[int],
315-
log_dir: str | os.PathLike,
316-
log_level: int,
317-
) -> LogRecordSocketReceiver:
318-
if log_handlers is None:
319-
log_handlers = []
320-
elif log_handlers == "auto":
321-
log_handlers = default_handlers(
322-
hostnames=hostnames,
323-
workers_per_host=workers_per_host,
324-
log_dir=log_dir,
325-
log_level=log_level,
326-
)
327-
328-
return LogRecordSocketReceiver(
329-
host=launcher_hostname,
330-
port=get_open_port(),
331-
handlers=log_handlers,
332-
)
333-
334-
335314
def _build_launch_command(
336315
launcher_hostname: str,
337316
launcher_port: int,

src/torchrunx/utils/logging.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from __future__ import annotations
44

55
__all__ = [
6-
"LogRecordSocketReceiver",
6+
"LoggingServerArgs",
7+
"start_logging_server",
78
"redirect_stdio_to_logger",
89
"log_records_to_socket",
910
"add_filter_to_handler",
@@ -25,12 +26,14 @@
2526
from logging.handlers import SocketHandler
2627
from pathlib import Path
2728
from socketserver import StreamRequestHandler, ThreadingTCPServer
28-
from typing import TYPE_CHECKING
29+
from typing import TYPE_CHECKING, Literal
2930

31+
import cloudpickle
3032
from typing_extensions import Self
3133

3234
if TYPE_CHECKING:
3335
import os
36+
from multiprocessing.synchronize import Event as EventClass
3437

3538
## Handler utilities
3639

@@ -139,7 +142,7 @@ def default_handlers(
139142
## Launcher utilities
140143

141144

142-
class LogRecordSocketReceiver(ThreadingTCPServer):
145+
class _LogRecordSocketReceiver(ThreadingTCPServer):
143146
"""TCP server for recieving Agent/Worker log records in Launcher.
144147
145148
Uses threading to avoid bottlenecks (i.e. "out-of-order" logs in Launcher process).
@@ -180,6 +183,60 @@ def shutdown(self) -> None:
180183
self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue]
181184

182185

186+
@dataclass
187+
class LoggingServerArgs:
188+
log_handlers: list[Handler] | Literal["auto"] | None
189+
logging_hostname: str
190+
logging_port: int
191+
hostnames: list[str]
192+
workers_per_host: list[int]
193+
log_dir: str | os.PathLike
194+
log_level: int
195+
196+
def serialize(self) -> SerializedLoggingServerArgs:
197+
return SerializedLoggingServerArgs(args=self)
198+
199+
200+
class SerializedLoggingServerArgs:
201+
def __init__(self, args: LoggingServerArgs) -> None:
202+
self.bytes = cloudpickle.dumps(args)
203+
204+
def deserialize(self) -> LoggingServerArgs:
205+
return cloudpickle.loads(self.bytes)
206+
207+
208+
def start_logging_server(
209+
serialized_args: SerializedLoggingServerArgs,
210+
stop_event: EventClass,
211+
) -> None:
212+
args: LoggingServerArgs = serialized_args.deserialize()
213+
214+
log_handlers = args.log_handlers
215+
if log_handlers is None:
216+
log_handlers = []
217+
elif log_handlers == "auto":
218+
log_handlers = default_handlers(
219+
hostnames=args.hostnames,
220+
workers_per_host=args.workers_per_host,
221+
log_dir=args.log_dir,
222+
log_level=args.log_level,
223+
)
224+
225+
log_receiver = _LogRecordSocketReceiver(
226+
host=args.logging_hostname,
227+
port=args.logging_port,
228+
handlers=log_handlers,
229+
)
230+
231+
log_receiver.serve_forever()
232+
233+
while not stop_event.is_set():
234+
pass
235+
236+
log_receiver.shutdown()
237+
log_receiver.server_close()
238+
239+
183240
## Agent/worker utilities
184241

185242

0 commit comments

Comments
 (0)