diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 0eb06a0d4..14a5218a4 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -34,6 +34,12 @@ class Profiling:
     profile_freq: int = 10
     """How often to collect profile traces, in iterations"""
 
+    profiler_active: int = 1
+    """The number of steps the profiler is active for in each profiling cycle"""
+
+    profiler_warmup: int = 3
+    """The number of warmup steps before the active steps in each profiling cycle"""
+
     enable_memory_snapshot: bool = False
     """Whether to dump memory snapshot"""
 
diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py
index 0e851d335..97c9272e2 100644
--- a/torchtitan/tools/profiling.py
+++ b/torchtitan/tools/profiling.py
@@ -14,9 +14,6 @@
 from torchtitan.config import Profiling as ProfilingConfig
 from torchtitan.tools.logging import logger
 
-# the number of warmup steps before the active step in each profiling cycle
-WARMUP = 3
-
 # how much memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
@@ -58,7 +55,10 @@ def trace_handler(prof):
         if not os.path.exists(trace_dir):
             os.makedirs(trace_dir, exist_ok=True)
 
-        warmup, active = WARMUP, 1
+        warmup, active = (
+            profiling_config.profiler_warmup,
+            profiling_config.profiler_active,
+        )
         wait = profile_freq - (active + warmup)
         assert (
             wait >= 0
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 0afcac8dc..7f2ceceb7 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -6,6 +6,7 @@
 
 import importlib
 import os
+import signal
 import time
 from datetime import timedelta
 from typing import Any, Generator, Iterable, Optional
@@ -588,6 +589,17 @@ def train(self):
                 ),
             ),
         ):
+            if torch_profiler:
+
+                def sigabrt_handler(signum, frame):
+                    logger.info("SIGABRT received. Stopping profiler")
+                    for _ in range(config.profiling.profiler_active):
+                        # Step the profiler enough times to trigger a trace
+                        torch_profiler.step()
+                    torch_profiler.stop()
+
+                signal.signal(signal.SIGABRT, sigabrt_handler)
+
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():
                 self.step += 1
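For context, a minimal sketch (not part of the diff) of how these knobs typically feed `torch.profiler.schedule`. The values mirror the defaults above (`profile_freq=10`, `profiler_warmup=3`, `profiler_active=1`); the trace path and step count are placeholders:

```python
import torch

# Defaults taken from the diff above; the trace path is a placeholder.
profile_freq, warmup, active = 10, 3, 1
wait = profile_freq - (active + warmup)  # 6 idle steps per 10-step cycle
assert wait >= 0, "profile_freq must cover warmup + active steps"

prof = torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
    on_trace_ready=lambda p: p.export_chrome_trace("/tmp/trace.json"),
)
prof.start()
for _ in range(3 * profile_freq):  # three full wait -> warmup -> active cycles
    # ... one training step would run here ...
    prof.step()  # advance the profiler's schedule state machine
prof.stop()
```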
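On the SIGABRT hook: with a schedule attached, `torch.profiler` only hands a trace to `on_trace_ready` after the schedule completes an active window, so a job that aborts mid-cycle would normally lose the in-flight data. Stepping the profiler `profiler_active` extra times before `stop()` appears intended to carry an in-progress cycle through that window so a trace is still written. Note that `signal.signal` may only be called from the main thread, which holds here since the handler is installed at the top of the training loop.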