Skip to content

Commit 182636a

Browse files
committed
improve profiler
1 parent 8d5f141 commit 182636a

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

torchtitan/config/job_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ class Profiling:
3434
profile_freq: int = 10
3535
"""How often to collect profile traces, in iterations"""
3636

37+
profiler_active: int = 1
38+
"""The steps profiler is active for"""
39+
40+
profiler_warmup: int = 3
41+
"""The number of warmup steps before the active step in each profiling cycle"""
42+
3743
enable_memory_snapshot: bool = False
3844
"""Whether to dump memory snapshot"""
3945

torchtitan/tools/profiling.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
from torchtitan.config import Profiling as ProfilingConfig
1515
from torchtitan.tools.logging import logger
1616

17-
# the number of warmup steps before the active step in each profiling cycle
18-
WARMUP = 3
19-
2017
# how much memory allocation/free ops to record in memory snapshots
2118
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
2219

@@ -58,7 +55,7 @@ def trace_handler(prof):
5855
if not os.path.exists(trace_dir):
5956
os.makedirs(trace_dir, exist_ok=True)
6057

61-
warmup, active = WARMUP, 1
58+
warmup, active = profiling_config.profiler_warmup, profiling_config.profiler_active
6259
wait = profile_freq - (active + warmup)
6360
assert (
6461
wait >= 0

torchtitan/train.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
maybe_enable_memory_snapshot,
3333
maybe_enable_profiling,
3434
)
35+
import signal
3536

3637

3738
class Trainer(torch.distributed.checkpoint.stateful.Stateful):
@@ -585,6 +586,15 @@ def train(self):
585586
),
586587
),
587588
):
589+
if torch_profiler:
590+
def sigabrt_handler(signal, frame):
591+
logger.info("SIGABRT received. Stopping profiler")
592+
for _ in range(config.profiling.profiler_active):
593+
# Step the profiler enough times to trigger a trace
594+
torch_profiler.step()
595+
torch_profiler.stop()
596+
signal.signal(signal.SIGABRT, sigabrt_handler)
597+
588598
data_iterator = self.batch_generator(self.dataloader)
589599
while self.should_continue_training():
590600
self.step += 1

0 commit comments

Comments
 (0)