Skip to content

Commit 424b23c

Browse files
committed
improve profiler
1 parent 82f0287 commit 424b23c

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

torchtitan/config/job_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ class Profiling:
3434
profile_freq: int = 10
3535
"""How often to collect profile traces, in iterations"""
3636

37+
profiler_active: int = 1
38+
"""The number of steps the profiler is active for in each profiling cycle"""
39+
40+
profiler_warmup: int = 3
41+
"""The number of warmup steps before the active step in each profiling cycle"""
42+
3743
enable_memory_snapshot: bool = False
3844
"""Whether to dump memory snapshot"""
3945

torchtitan/tools/profiling.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
from torchtitan.config import Profiling as ProfilingConfig
1515
from torchtitan.tools.logging import logger
1616

17-
# the number of warmup steps before the active step in each profiling cycle
18-
WARMUP = 3
19-
2017
# how many memory allocation/free ops to record in memory snapshots
2118
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
2219

@@ -58,7 +55,10 @@ def trace_handler(prof):
5855
if not os.path.exists(trace_dir):
5956
os.makedirs(trace_dir, exist_ok=True)
6057

61-
warmup, active = WARMUP, 1
58+
warmup, active = (
59+
profiling_config.profiler_warmup,
60+
profiling_config.profiler_active,
61+
)
6262
wait = profile_freq - (active + warmup)
6363
assert (
6464
wait >= 0

torchtitan/train.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import importlib
88
import os
9+
import signal
910
import time
1011
from datetime import timedelta
1112
from typing import Any, Generator, Iterable, Optional
@@ -588,6 +589,17 @@ def train(self):
588589
),
589590
),
590591
):
592+
if torch_profiler:
593+
594+
def sigabrt_handler(signal, frame):
595+
logger.info("SIGABRT received. Stopping profiler")
596+
for _ in range(config.profiling.profiler_active):
597+
# Step the profiler enough times to trigger a trace
598+
torch_profiler.step()
599+
torch_profiler.stop()
600+
601+
signal.signal(signal.SIGABRT, sigabrt_handler)
602+
591603
data_iterator = self.batch_generator(self.dataloader)
592604
while self.should_continue_training():
593605
self.step += 1

0 commit comments

Comments
 (0)