
Commit e1b5016

trigger profiling on abort
Summary: record the profiler trace if the training process receives SIGABRT, e.g. when the Process Group watchdog aborts the process.
1 parent: 9333989

File tree

3 files changed: +39, -20 lines


torchtitan/experiments/forge/example_train.py
Lines changed: 6 additions & 5 deletions

@@ -277,12 +277,13 @@ def train(self):
         self.checkpointer.load(step=job_config.checkpoint.load_step)
         logger.info(f"Training starts at step {self.step + 1}.")

+        torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
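With this change, maybe_enable_profiling no longer behaves as a context manager: it returns an already-started torch.profiler.profile object (or None when profiling is disabled), and the caller drives it explicitly. A minimal sketch of the new calling pattern; profiling_config, num_steps, and run_training_step are placeholder stand-ins, not names from this commit:

    # Hypothetical caller: create the profiler up front, step it manually.
    torch_profiler = maybe_enable_profiling(
        profiling_config,             # placeholder ProfilingConfig
        global_step=0,
        base_folder="outputs",        # placeholder dump folder
    )

    for _ in range(num_steps):        # placeholder loop bound
        run_training_step()           # placeholder training step
        if torch_profiler:
            torch_profiler.step()     # advance the wait/warmup/active schedule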

torchtitan/tools/profiling.py
Lines changed: 6 additions & 7 deletions

@@ -18,7 +18,6 @@
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000


-@contextlib.contextmanager
 def maybe_enable_profiling(
     profiling_config: ProfilingConfig,
     *,
@@ -68,20 +67,20 @@ def trace_handler(prof):
             gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
         elif torch.xpu.is_available():
             gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
-        with torch.profiler.profile(
+        torch_profiler = torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 gpu_device_profiled,
             ],
             schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
             on_trace_ready=trace_handler,
             record_shapes=True,
-        ) as torch_profiler:
-            torch_profiler.step_num = global_step
-            yield torch_profiler
+        )
+        torch_profiler.step_num = global_step
+        torch_profiler.start()
+        return torch_profiler
     else:
-        torch_profiler = contextlib.nullcontext()
-        yield None
+        return None


 @contextlib.contextmanager
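Because the function now returns the profiler instead of yielding it, the caller is responsible for stopping it (or letting an abort handler flush it) rather than relying on context-manager exit. For reference, a self-contained sketch of the manual start/step/stop pattern that torch.profiler supports; the schedule values and workload here are placeholders, not values from this commit:

    import torch

    prof = torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
        on_trace_ready=lambda p: p.export_chrome_trace("/tmp/trace.json"),
        record_shapes=True,
    )
    prof.start()                      # replaces the context manager's __enter__
    for _ in range(5):
        torch.randn(64, 64) @ torch.randn(64, 64)  # stand-in workload
        prof.step()                   # trace is emitted after the active window
    prof.stop()                       # must now be called explicitly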

torchtitan/train.py
Lines changed: 27 additions & 8 deletions

@@ -4,8 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import ctypes
 import importlib
 import os
+import signal
 import time
 from datetime import timedelta
 from typing import Any, Generator, Iterable, Optional
@@ -33,8 +35,12 @@
     maybe_enable_profiling,
 )

+c_globals = ctypes.CDLL(None)  # POSIX
+

 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
+    torch_profiler: torch.profiler.profile | None = None
+
     # core configs
     job_config: JobConfig
     parallel_dims: ParallelDims
@@ -555,13 +561,14 @@ def train(self):
             if not self.ft_manager.enabled
             else f"replica_{self.ft_manager.replica_id}"
         )
+        self.torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+            leaf_folder=leaf_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-                leaf_folder=leaf_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
@@ -585,6 +592,16 @@ def train(self):
                 ),
             ),
         ):
+            if self.torch_profiler:
+
+                @ctypes.CFUNCTYPE(None, ctypes.c_int)
+                def sigabrt_handler(signal):
+                    logger.info("SIGABRT received. Stopping profiler")
+                    self.torch_profiler.profiler.enabled = False
+                    self.torch_profiler.export_chrome_trace("trace.json")
+
+                c_globals.signal(signal.SIGABRT, sigabrt_handler)
+
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():
                 self.step += 1
@@ -608,8 +625,8 @@ def train(self):
                     self.validator.validate(self.model_parts, self.step)

                 # signal the profiler that the next profiling step has started
-                if torch_profiler:
-                    torch_profiler.step()
+                if self.torch_profiler:
+                    self.torch_profiler.step()
                 if memory_profiler:
                     memory_profiler.step()

@@ -667,10 +684,12 @@ def close(self) -> None:
         else:
             trainer.train()
     except Exception:
+        logger.info("Torchtitan training threw an exception")
         if trainer:
             trainer.close()
         raise
     else:
+        logger.info("Torchtitan training completed")
         trainer.close()
         torch.distributed.destroy_process_group()
         logger.info("Process group destroyed")
