Skip to content

Commit a8e24ed

Browse files
committed
trigger profiling on abort
Summary: record the profile trace if the training process receives SIGABRT, e.g., when the Process Group watchdog aborts the process.
1 parent e3bb189 commit a8e24ed

File tree

3 files changed

+38
-20
lines changed

3 files changed

+38
-20
lines changed

torchtitan/experiments/forge/example_train.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,13 @@ def train(self):
281281
self.checkpointer.load(step=job_config.checkpoint.load_step)
282282
logger.info(f"Training starts at step {self.step + 1}.")
283283

284+
torch_profiler = maybe_enable_profiling(
285+
job_config.profiling,
286+
global_step=self.step,
287+
base_folder=job_config.job.dump_folder,
288+
)
289+
284290
with (
285-
maybe_enable_profiling(
286-
job_config.profiling,
287-
global_step=self.step,
288-
base_folder=job_config.job.dump_folder,
289-
) as torch_profiler,
290291
maybe_enable_memory_snapshot(
291292
job_config.profiling,
292293
global_step=self.step,

torchtitan/tools/profiling.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
1919

2020

21-
@contextlib.contextmanager
2221
def maybe_enable_profiling(
2322
profiling_config: ProfilingConfig,
2423
*,
@@ -68,20 +67,20 @@ def trace_handler(prof):
6867
gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
6968
elif torch.xpu.is_available():
7069
gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
71-
with torch.profiler.profile(
70+
torch_profiler = torch.profiler.profile(
7271
activities=[
7372
torch.profiler.ProfilerActivity.CPU,
7473
gpu_device_profiled,
7574
],
7675
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
7776
on_trace_ready=trace_handler,
7877
record_shapes=True,
79-
) as torch_profiler:
80-
torch_profiler.step_num = global_step
81-
yield torch_profiler
78+
)
79+
torch_profiler.step_num = global_step
80+
torch_profiler.start()
81+
return torch_profiler
8282
else:
83-
torch_profiler = contextlib.nullcontext()
84-
yield None
83+
return None
8584

8685

8786
@contextlib.contextmanager

torchtitan/train.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import ctypes
78
import importlib
89
import os
10+
import signal
911
import time
1012
from datetime import timedelta
1113
from typing import Any, Generator, Iterable, Optional
@@ -32,8 +34,12 @@
3234
maybe_enable_profiling,
3335
)
3436

37+
c_globals = ctypes.CDLL(None) # POSIX
38+
3539

3640
class Trainer(torch.distributed.checkpoint.stateful.Stateful):
41+
torch_profiler: torch.profiler.profile | None = None
42+
3743
# core configs
3844
job_config: JobConfig
3945
parallel_dims: ParallelDims
@@ -572,13 +578,14 @@ def train(self):
572578
if not self.ft_manager.enabled
573579
else f"replica_{self.ft_manager.replica_id}"
574580
)
581+
self.torch_profiler = maybe_enable_profiling(
582+
job_config.profiling,
583+
global_step=self.step,
584+
base_folder=job_config.job.dump_folder,
585+
leaf_folder=leaf_folder,
586+
)
587+
575588
with (
576-
maybe_enable_profiling(
577-
job_config.profiling,
578-
global_step=self.step,
579-
base_folder=job_config.job.dump_folder,
580-
leaf_folder=leaf_folder,
581-
) as torch_profiler,
582589
maybe_enable_memory_snapshot(
583590
job_config.profiling,
584591
global_step=self.step,
@@ -602,6 +609,15 @@ def train(self):
602609
),
603610
),
604611
):
612+
if self.torch_profiler:
613+
614+
@ctypes.CFUNCTYPE(None, ctypes.c_int)
615+
def sigabrt_handler(signal):
616+
logger.info("SIGABRT received. Stopping profiler")
617+
self.torch_profiler.export_chrome_trace("trace.json")
618+
619+
c_globals.signal(signal.SIGABRT, sigabrt_handler)
620+
605621
data_iterator = self.batch_generator(self.dataloader)
606622
while self.should_continue_training():
607623
self.step += 1
@@ -625,8 +641,8 @@ def train(self):
625641
self.validator.validate(self.model_parts, self.step)
626642

627643
# signal the profiler that the next profiling step has started
628-
if torch_profiler:
629-
torch_profiler.step()
644+
if self.torch_profiler:
645+
self.torch_profiler.step()
630646
if memory_profiler:
631647
memory_profiler.step()
632648

@@ -684,10 +700,12 @@ def close(self) -> None:
684700
else:
685701
trainer.train()
686702
except Exception:
703+
logger.info("Torchtitan training threw an exception")
687704
if trainer:
688705
trainer.close()
689706
raise
690707
else:
708+
logger.info("Torchtitan training completed")
691709
trainer.close()
692710
torch.distributed.destroy_process_group()
693711
logger.info("Process group destroyed")

0 commit comments

Comments
 (0)