Skip to content

Commit 4ba2fd0

Browse files
committed
trigger profiling on abort
Summary: record the profile trace if the training process receives SIGABRT, e.g. when the Process Group watchdog aborts the process.
1 parent 116d407 commit 4ba2fd0

File tree

3 files changed

+38
-20
lines changed

3 files changed

+38
-20
lines changed

torchtitan/experiments/forge/example_train.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,13 @@ def train(self):
281281
self.checkpointer.load(step=job_config.checkpoint.load_step)
282282
logger.info(f"Training starts at step {self.step + 1}.")
283283

284+
torch_profiler = maybe_enable_profiling(
285+
job_config.profiling,
286+
global_step=self.step,
287+
base_folder=job_config.job.dump_folder,
288+
)
289+
284290
with (
285-
maybe_enable_profiling(
286-
job_config.profiling,
287-
global_step=self.step,
288-
base_folder=job_config.job.dump_folder,
289-
) as torch_profiler,
290291
maybe_enable_memory_snapshot(
291292
job_config.profiling,
292293
global_step=self.step,

torchtitan/tools/profiling.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
1919

2020

21-
@contextlib.contextmanager
2221
def maybe_enable_profiling(
2322
profiling_config: ProfilingConfig,
2423
*,
@@ -68,20 +67,20 @@ def trace_handler(prof):
6867
gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
6968
elif torch.xpu.is_available():
7069
gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
71-
with torch.profiler.profile(
70+
torch_profiler = torch.profiler.profile(
7271
activities=[
7372
torch.profiler.ProfilerActivity.CPU,
7473
gpu_device_profiled,
7574
],
7675
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
7776
on_trace_ready=trace_handler,
7877
record_shapes=True,
79-
) as torch_profiler:
80-
torch_profiler.step_num = global_step
81-
yield torch_profiler
78+
)
79+
torch_profiler.step_num = global_step
80+
torch_profiler.start()
81+
return torch_profiler
8282
else:
83-
torch_profiler = contextlib.nullcontext()
84-
yield None
83+
return None
8584

8685

8786
@contextlib.contextmanager

torchtitan/train.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import ctypes
78
import importlib
89
import os
10+
import signal
911
import time
1012
from datetime import timedelta
1113
from typing import Any, Generator, Iterable, Optional
@@ -32,8 +34,12 @@
3234
maybe_enable_profiling,
3335
)
3436

37+
c_globals = ctypes.CDLL(None) # POSIX
38+
3539

3640
class Trainer(torch.distributed.checkpoint.stateful.Stateful):
41+
torch_profiler: torch.profiler.profile | None = None
42+
3743
# core configs
3844
job_config: JobConfig
3945
parallel_dims: ParallelDims
@@ -582,13 +588,14 @@ def train(self):
582588
if not self.ft_manager.enabled
583589
else f"replica_{self.ft_manager.replica_id}"
584590
)
591+
self.torch_profiler = maybe_enable_profiling(
592+
job_config.profiling,
593+
global_step=self.step,
594+
base_folder=job_config.job.dump_folder,
595+
leaf_folder=leaf_folder,
596+
)
597+
585598
with (
586-
maybe_enable_profiling(
587-
job_config.profiling,
588-
global_step=self.step,
589-
base_folder=job_config.job.dump_folder,
590-
leaf_folder=leaf_folder,
591-
) as torch_profiler,
592599
maybe_enable_memory_snapshot(
593600
job_config.profiling,
594601
global_step=self.step,
@@ -612,6 +619,15 @@ def train(self):
612619
),
613620
),
614621
):
622+
if self.torch_profiler:
623+
624+
@ctypes.CFUNCTYPE(None, ctypes.c_int)
625+
def sigabrt_handler(signal):
626+
logger.info("SIGABRT received. Stopping profiler")
627+
self.torch_profiler.export_chrome_trace("trace.json")
628+
629+
c_globals.signal(signal.SIGABRT, sigabrt_handler)
630+
615631
data_iterator = self.batch_generator(self.dataloader)
616632
while self.should_continue_training():
617633
self.step += 1
@@ -635,8 +651,8 @@ def train(self):
635651
self.validator.validate(self.model_parts, self.step)
636652

637653
# signal the profiler that the next profiling step has started
638-
if torch_profiler:
639-
torch_profiler.step()
654+
if self.torch_profiler:
655+
self.torch_profiler.step()
640656
if memory_profiler:
641657
memory_profiler.step()
642658

@@ -694,10 +710,12 @@ def close(self) -> None:
694710
else:
695711
trainer.train()
696712
except Exception:
713+
logger.info("Torchtitan training threw an exception")
697714
if trainer:
698715
trainer.close()
699716
raise
700717
else:
718+
logger.info("Torchtitan training completed")
701719
trainer.close()
702720
torch.distributed.destroy_process_group()
703721
logger.info("Process group destroyed")

0 commit comments

Comments (0)