
Commit cd260f5

updates to logging
1 parent 2b76c49 commit cd260f5

7 files changed: +51 additions, -65 deletions

src/torchrunx/agent.py

Lines changed: 13 additions & 18 deletions

@@ -19,7 +19,7 @@
     LauncherAgentGroup,
     get_open_port,
 )
-from .utils.logging import log_records_to_socket, redirect_stdio_to_logger
+from .utils.logs import log_records_to_socket, redirect_stdio_to_logger
 from .worker import WorkerArgs, worker_entrypoint


@@ -48,23 +48,19 @@ def main(
         hostname: Hostname of this agent.
     """
     # Stream logs to logging server
-
-    logger = logging.getLogger()
-    redirect_stdio_to_logger(logger)
+    logger = logging.getLogger(f"{__package__}.{hostname}")

     log_records_to_socket(
-        logger=logger,
-        hostname=hostname,
-        local_rank=None,
-        logger_hostname=logger_hostname,
-        logger_port=logger_port,
+        hostname=hostname, local_rank=None, logger_hostname=logger_hostname, logger_port=logger_port
     )

-    logging.debug("Agent logging setup.")
+    redirect_stdio_to_logger(logger)
+
+    logger.debug("Agent logging setup.")

     # Set up launcher-agent group

-    logging.debug("Initializing launcher-agent group.")
+    logger.debug("Initializing launcher-agent group.")

     launcher_agent_group = LauncherAgentGroup(
         launcher_hostname=launcher_hostname,
@@ -77,7 +73,7 @@ def main(

     # Communicate initial payloads between launcher/agents

-    logging.debug("Sending agent details to launcher.")
+    logger.debug("Sending agent details to launcher.")

     payload = AgentPayload(
         hostname=socket.getfqdn(),
@@ -86,7 +82,6 @@ def main(
     )

     launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload)
-    main_agent_payload = agent_payloads[0]

     hostname = launcher_payload.hostnames[agent_rank]
     worker_world_size = launcher_payload.worker_world_size
@@ -95,7 +90,7 @@ def main(

     # Spawn worker processes

-    logging.debug("Launching worker processes.")
+    logger.debug("Launching worker processes.")

     ctx = dist_mp.start_processes(
         name=f"{hostname}_",
@@ -106,8 +101,8 @@ def main(
                 function=launcher_payload.fn,
                 logger_hostname=logger_hostname,
                 logger_port=logger_port,
-                main_agent_hostname=main_agent_payload.hostname,
-                main_agent_port=main_agent_payload.port,
+                master_hostname=agent_payloads[0].hostname,
+                master_port=agent_payloads[0].port,
                 backend=launcher_payload.backend,
                 rank=worker_global_ranks[i],
                 local_rank=i,
@@ -146,12 +141,12 @@ def main(
             all_done = all(s.state == "done" for s in agent_statuses)
             any_failed = any(s.state == "failed" for s in agent_statuses)
             if all_done or any_failed:
-                logging.debug("Workers exiting %s.", "cleanly" if not any_failed else "with errors")
+                logger.debug("Workers exiting %s.", "cleanly" if not any_failed else "with errors")
                 break
     finally:
         ctx.close()
         sys.stdout.flush()
         sys.stderr.flush()
         launcher_agent_group.shutdown()

-    logging.debug("Agent exiting.")
+    logger.debug("Agent exiting.")
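
Note: the agent now logs through a named child of the package logger rather than the root logger, so its records carry the hostname in the logger hierarchy. A minimal standalone sketch of the propagation this relies on (the "torchrunx" package name is from the repo; the "node1" hostname is illustrative):

import logging

# Child loggers propagate records upward by default, so a handler on the
# package logger receives every per-host record.
package_logger = logging.getLogger("torchrunx")
package_logger.setLevel(logging.DEBUG)
package_logger.addHandler(logging.StreamHandler())

agent_logger = logging.getLogger("torchrunx.node1")  # hypothetical hostname
agent_logger.debug("Agent logging setup.")  # handled by the package handler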

src/torchrunx/launcher.py

Lines changed: 7 additions & 8 deletions

@@ -29,7 +29,7 @@
     resolve_environment,
 )
 from .utils.errors import ExceptionFromWorker, WorkerFailedError
-from .utils.logging import LoggingServerArgs, default_handlers, start_logging_server
+from .utils.logs import LoggingServerArgs, default_handlers, start_logging_server

 DEFAULT_ENV_VARS_FOR_COPY = (
     "PATH",
@@ -45,8 +45,6 @@
 FunctionP = ParamSpec("FunctionP")
 FunctionR = TypeVar("FunctionR")

-logger = logging.getLogger(__name__)
-

 @dataclass
 class Launcher:
@@ -104,6 +102,8 @@ def run(  # noqa: C901, PLR0912, PLR0915
             WorkerFailedError: If a worker fails (e.g. from a segmentation fault).
             AgentFailedError: If an agent fails, e.g. from an OS signal.
         """
+        logger = logging.getLogger(__package__)
+
         if not dist.is_available():
             msg = "The torch.distributed package is not available."
             raise RuntimeError(msg)
@@ -249,11 +249,6 @@ def handler_factory() -> list[logging.Handler]:
         if log_process is not None:
             log_process.kill()

-        logger.debug("Killing launcher-agent group.")
-
-        if launcher_agent_group is not None:
-            launcher_agent_group.shutdown()
-
         # cleanup: SIGTERM all agents
         if agent_payloads is not None:
             for agent_payload, agent_hostname in zip(agent_payloads, hostnames):
@@ -265,6 +260,10 @@ def handler_factory() -> list[logging.Handler]:
                     ssh_config_file=ssh_config_file,
                 )

+        if launcher_agent_group is not None:
+            logger.debug("Killing launcher-agent group.")
+            launcher_agent_group.shutdown()
+

 @dataclass
 class LaunchResult(Generic[FunctionR]):
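
Note: this also reorders teardown in Launcher.run: the launcher-agent group is now shut down after the agents are SIGTERMed, not before. A simplified sketch of the resulting cleanup order (not the verbatim launcher code; terminate_agent is a hypothetical stand-in for the SSH/SIGTERM logic elided above):

import logging

logger = logging.getLogger("torchrunx")

def cleanup(log_process, agent_payloads, hostnames, launcher_agent_group) -> None:
    # Order after this commit: stop the log server, signal every agent,
    # and only then tear down the launcher-agent process group.
    if log_process is not None:
        log_process.kill()
    if agent_payloads is not None:
        for agent_payload, agent_hostname in zip(agent_payloads, hostnames):
            terminate_agent(agent_payload, agent_hostname)  # hypothetical helper
    if launcher_agent_group is not None:
        logger.debug("Killing launcher-agent group.")
        launcher_agent_group.shutdown()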

src/torchrunx/utils/__init__.py

Lines changed: 0 additions & 3 deletions

@@ -1,3 +0,0 @@
-from .logging import add_filter_to_handler, file_handler, stream_handler
-
-__all__ = ["add_filter_to_handler", "file_handler", "stream_handler"]
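
Note: with these re-exports removed, the handler builders presumably must be imported from the renamed module directly, e.g.:

from torchrunx.utils.logs import add_filter_to_handler, file_handler, stream_handler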

src/torchrunx/utils/comm.py

Lines changed: 1 addition & 1 deletion

@@ -119,7 +119,7 @@ class LauncherPayload:
     hostnames: list[str]
     worker_global_ranks: list[list[int]]
     worker_world_size: int
-    backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
+    backend: Literal["nccl", "gloo", "mpi", "ucc"] | None
     timeout: int
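
Note: dropping "auto" from the backend literal matches the worker.py change below, which deletes the worker-side resolution of backend == "auto". For reference, the removed resolution was equivalent to the following (assumption: this choice now happens before the payload is built, e.g. launcher-side):

import torch

# Former "auto" behavior, removed from the worker in this commit.
backend = "nccl" if torch.cuda.is_available() else "gloo"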

src/torchrunx/utils/logging.py renamed to src/torchrunx/utils/logs.py

Lines changed: 21 additions & 22 deletions

@@ -1,4 +1,4 @@
-"""Utilities for intercepting logs in worker processes and handling these in the Launcher."""  # noqa: A005
+"""Utilities for intercepting logs in worker processes and handling these in the Launcher."""

 from __future__ import annotations

@@ -24,7 +24,7 @@
 from contextlib import redirect_stderr, redirect_stdout
 from dataclasses import dataclass
 from io import StringIO
-from logging import Handler, Logger
+from logging import Handler, Logger, LogRecord
 from logging.handlers import SocketHandler
 from multiprocessing.synchronize import Event as EventClass
 from pathlib import Path
@@ -55,40 +55,40 @@ def _filter(record: WorkerLogRecord) -> bool:
     handler.addFilter(_filter)  # pyright: ignore [reportArgumentType]


-def default_handlers(
-    hostnames: list[str],
-    workers_per_host: list[int],
-    log_level: int = logging.INFO,
-) -> list[logging.Handler]:
+def default_handlers(hostnames: list[str], workers_per_host: list[int]) -> list[logging.Handler]:
     """Constructs default :obj:`logging.Handler` objects.

     Logs for the rank 0 agent and worker are written to launcher process stdout.
     Logs for all hosts/workers are written to files in ``$TORCHRUNX_LOG_DIR`` (named by timestamp,
     hostname, local_rank).
     """
     log_dir = Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs"))
-    log_level = logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")]  # noqa: SLF001
+    file_log_level = logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")]  # noqa: SLF001
+
     return [
-        stream_handler(hostname=hostnames[0], local_rank=None, log_level=log_level),
-        stream_handler(hostname=hostnames[0], local_rank=0, log_level=log_level),
-        *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level),
+        RedirectHandler(hostname=hostnames[0], local_rank=None),
+        RedirectHandler(hostname=hostnames[0], local_rank=0),
+        *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=file_log_level),
     ]


+class RedirectHandler(logging.Handler):
+    def __init__(self, hostname: str, local_rank: int | None) -> None:
+        super().__init__()
+        add_filter_to_handler(self, hostname=hostname, local_rank=local_rank)
+
+    def emit(self, record: LogRecord) -> None:
+        logger = logging.getLogger(record.name)
+        if logger.isEnabledFor(record.levelno):
+            logger.handle(record)
+
+
 def stream_handler(
     hostname: str, local_rank: int | None, log_level: int = logging.NOTSET
 ) -> logging.Handler:
     """Handler builder function for writing logs from specified hostname/rank to stdout."""
     handler = logging.StreamHandler(stream=sys.stdout)
     add_filter_to_handler(handler, hostname, local_rank, log_level=log_level)
-    handler.setFormatter(
-        logging.Formatter(
-            "%(asctime)s:%(levelname)s:%(hostname)s[%(local_rank)s]: %(message)s"
-            if local_rank is not None
-            else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s",
-            datefmt="%Y-%m-%d %H:%M:%S",
-        ),
-    )
     return handler


@@ -259,14 +259,13 @@ def from_record(cls, record: logging.LogRecord, hostname: str, local_rank: int |


 def log_records_to_socket(
-    logger: Logger,
     hostname: str,
     local_rank: int | None,  # None indicates agent
     logger_hostname: str,
     logger_port: int,
 ) -> None:
     """Encode LogRecords with hostname/local_rank. Send to TCP socket on Launcher."""
-    logger.setLevel(logging.NOTSET)
+    logging.root.setLevel(logging.NOTSET)

     old_factory = logging.getLogRecordFactory()

@@ -276,4 +275,4 @@ def record_factory(*args, **kwargs) -> WorkerLogRecord:  # noqa: ANN002, ANN003

     logging.setLogRecordFactory(record_factory)

-    logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port))
+    logging.root.addHandler(SocketHandler(host=logger_hostname, port=logger_port))
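
Note: two behavioral shifts here. First, the new RedirectHandler re-emits rank-0 records through the launcher process's own logger tree (via logging.getLogger(record.name)) instead of formatting them itself, so whatever handlers and formatters the user has configured apply. Second, log_records_to_socket now attaches its SocketHandler to the root logger rather than to a passed-in logger, so records from every logger in the process are forwarded. A standalone sketch of that forwarding (host and port are illustrative):

import logging
from logging.handlers import SocketHandler

# With the handler on the root logger, records from any logger in the
# process (including third-party libraries) reach the TCP log server.
logging.root.setLevel(logging.NOTSET)
logging.root.addHandler(SocketHandler(host="launcher-host", port=5000))
logging.getLogger("third.party.lib").info("forwarded to the launcher too")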

src/torchrunx/worker.py

Lines changed: 9 additions & 13 deletions

@@ -11,12 +11,11 @@
 from typing import Any, Callable, Literal

 import cloudpickle
-import torch
 import torch.distributed as dist
 from typing_extensions import Self

 from .utils.errors import ExceptionFromWorker
-from .utils.logging import log_records_to_socket, redirect_stdio_to_logger
+from .utils.logs import log_records_to_socket, redirect_stdio_to_logger

 __all__ = ["WorkerArgs", "worker_entrypoint"]

@@ -28,9 +27,9 @@ class WorkerArgs:
     function: Callable
     logger_hostname: str
     logger_port: int
-    main_agent_hostname: str
-    main_agent_port: int
-    backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
+    master_hostname: str
+    master_port: int
+    backend: Literal["nccl", "gloo", "mpi", "ucc"] | None
     rank: int
     local_rank: int
     node_rank: int
@@ -60,10 +59,9 @@ def worker_entrypoint(serialized_worker_args: bytes) -> Any | ExceptionFromWorke

     # Start logging to the logging server (i.e. the launcher)

-    logger = logging.getLogger()
+    logger = logging.getLogger(f"{__package__}.{worker_args.hostname}.{worker_args.local_rank}")

     log_records_to_socket(
-        logger=logger,
         hostname=worker_args.hostname,
         local_rank=worker_args.local_rank,
         logger_hostname=worker_args.logger_hostname,
@@ -79,23 +77,21 @@ def worker_entrypoint(serialized_worker_args: bytes) -> Any | ExceptionFromWorke
     os.environ["GROUP_RANK"] = str(worker_args.node_rank)
     os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
     os.environ["WORLD_SIZE"] = str(worker_args.world_size)
-    os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
-    os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)
+    os.environ["MASTER_ADDR"] = worker_args.master_hostname
+    os.environ["MASTER_PORT"] = str(worker_args.master_port)

     # Prepare the process group (e.g. for communication within the user's function)

     if worker_args.backend is not None:
         backend = worker_args.backend
-        if backend == "auto":
-            backend = "nccl" if torch.cuda.is_available() else "gloo"

         dist.init_process_group(
             backend=backend,
             world_size=worker_args.world_size,
             rank=worker_args.rank,
             store=dist.TCPStore(  # pyright: ignore [reportPrivateImportUsage]
-                host_name=worker_args.main_agent_hostname,
-                port=worker_args.main_agent_port,
+                host_name=worker_args.master_hostname,
+                port=worker_args.master_port,
                 world_size=worker_args.world_size,
                 is_master=(worker_args.rank == 0),
             ),
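
Note: renaming main_agent_* to master_* aligns WorkerArgs with the MASTER_ADDR/MASTER_PORT convention of torch.distributed. A minimal sketch of the rendezvous pattern the worker performs (illustrative values, not the verbatim worker code):

import torch.distributed as dist

def init_worker(rank: int, world_size: int, master_hostname: str, master_port: int) -> None:
    # Rank 0 hosts the TCPStore; every rank connects to it and then joins
    # the process group ("gloo" chosen here for a CPU-only sketch).
    store = dist.TCPStore(
        host_name=master_hostname,
        port=master_port,
        world_size=world_size,
        is_master=(rank == 0),
    )
    dist.init_process_group(backend="gloo", world_size=world_size, rank=rank, store=store)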
