Skip to content

Commit 017dfdb

Browse files
committed
trigger profiling on abort
Summary: record the profile trace if the training process receives SIGABRT — e.g., when the Process Group watchdog aborts the process.
1 parent d880de2 commit 017dfdb

File tree

3 files changed

+39
-20
lines changed

3 files changed

+39
-20
lines changed

torchtitan/experiments/forge/example_train.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,13 @@ def train(self):
283283
self.checkpointer.load(step=job_config.checkpoint.load_step)
284284
logger.info(f"Training starts at step {self.step + 1}.")
285285

286+
torch_profiler = maybe_enable_profiling(
287+
job_config.profiling,
288+
global_step=self.step,
289+
base_folder=job_config.job.dump_folder,
290+
)
291+
286292
with (
287-
maybe_enable_profiling(
288-
job_config.profiling,
289-
global_step=self.step,
290-
base_folder=job_config.job.dump_folder,
291-
) as torch_profiler,
292293
maybe_enable_memory_snapshot(
293294
job_config.profiling,
294295
global_step=self.step,

torchtitan/tools/profiling.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
1919

2020

21-
@contextlib.contextmanager
2221
def maybe_enable_profiling(
2322
profiling_config: ProfilingConfig,
2423
*,
@@ -68,20 +67,20 @@ def trace_handler(prof):
6867
gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
6968
elif torch.xpu.is_available():
7069
gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
71-
with torch.profiler.profile(
70+
torch_profiler = torch.profiler.profile(
7271
activities=[
7372
torch.profiler.ProfilerActivity.CPU,
7473
gpu_device_profiled,
7574
],
7675
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
7776
on_trace_ready=trace_handler,
7877
record_shapes=True,
79-
) as torch_profiler:
80-
torch_profiler.step_num = global_step
81-
yield torch_profiler
78+
)
79+
torch_profiler.step_num = global_step
80+
torch_profiler.start()
81+
return torch_profiler
8282
else:
83-
torch_profiler = contextlib.nullcontext()
84-
yield None
83+
return None
8584

8685

8786
@contextlib.contextmanager

torchtitan/train.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import ctypes
78
import importlib
89
import os
10+
import signal
911
import time
1012
from datetime import timedelta
1113
from typing import Any, Generator, Iterable, Optional
@@ -32,8 +34,12 @@
3234
maybe_enable_profiling,
3335
)
3436

37+
c_globals = ctypes.CDLL(None) # POSIX
38+
3539

3640
class Trainer(torch.distributed.checkpoint.stateful.Stateful):
41+
torch_profiler: torch.profiler.profile | None = None
42+
3743
# core configs
3844
job_config: JobConfig
3945
parallel_dims: ParallelDims
@@ -569,13 +575,14 @@ def train(self):
569575
if not self.ft_manager.enabled
570576
else f"replica_{self.ft_manager.replica_id}"
571577
)
578+
self.torch_profiler = maybe_enable_profiling(
579+
job_config.profiling,
580+
global_step=self.step,
581+
base_folder=job_config.job.dump_folder,
582+
leaf_folder=leaf_folder,
583+
)
584+
572585
with (
573-
maybe_enable_profiling(
574-
job_config.profiling,
575-
global_step=self.step,
576-
base_folder=job_config.job.dump_folder,
577-
leaf_folder=leaf_folder,
578-
) as torch_profiler,
579586
maybe_enable_memory_snapshot(
580587
job_config.profiling,
581588
global_step=self.step,
@@ -599,6 +606,16 @@ def train(self):
599606
),
600607
),
601608
):
609+
if self.torch_profiler:
610+
611+
@ctypes.CFUNCTYPE(None, ctypes.c_int)
612+
def sigabrt_handler(signal):
613+
logger.info("SIGABRT received. Stopping profiler")
614+
self.torch_profiler.profiler.enabled = False
615+
self.torch_profiler.export_chrome_trace("trace.json")
616+
617+
c_globals.signal(signal.SIGABRT, sigabrt_handler)
618+
602619
data_iterator = self.batch_generator(self.dataloader)
603620
while self.should_continue_training():
604621
self.step += 1
@@ -622,8 +639,8 @@ def train(self):
622639
self.validator.validate(self.model_parts, self.step)
623640

624641
# signal the profiler that the next profiling step has started
625-
if torch_profiler:
626-
torch_profiler.step()
642+
if self.torch_profiler:
643+
self.torch_profiler.step()
627644
if memory_profiler:
628645
memory_profiler.step()
629646

@@ -681,10 +698,12 @@ def close(self) -> None:
681698
else:
682699
trainer.train()
683700
except Exception:
701+
logger.info("Torchtitan training threw an exception")
684702
if trainer:
685703
trainer.close()
686704
raise
687705
else:
706+
logger.info("Torchtitan training completed")
688707
trainer.close()
689708
torch.distributed.destroy_process_group()
690709
logger.info("Process group destroyed")

0 commit comments

Comments (0)