Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions torchtitan/config/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ class Profiling:
profile_freq: int = 10
"""How often to collect profile traces, in iterations"""

profiler_active: int = 1
"""The number of steps the profiler is active for in each profiling cycle"""

profiler_warmup: int = 3
"""The number of warmup steps before the active step in each profiling cycle"""

enable_memory_snapshot: bool = False
"""Whether to dump memory snapshot"""

Expand Down
8 changes: 4 additions & 4 deletions torchtitan/tools/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
from torchtitan.config import Profiling as ProfilingConfig
from torchtitan.tools.logging import logger

# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3

# how much memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000

Expand Down Expand Up @@ -58,7 +55,10 @@ def trace_handler(prof):
if not os.path.exists(trace_dir):
os.makedirs(trace_dir, exist_ok=True)

warmup, active = WARMUP, 1
warmup, active = (
profiling_config.profiler_warmup,
profiling_config.profiler_active,
)
wait = profile_freq - (active + warmup)
assert (
wait >= 0
Expand Down
12 changes: 12 additions & 0 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import importlib
import os
import signal
import time
from datetime import timedelta
from typing import Any, Generator, Iterable, Optional
Expand Down Expand Up @@ -588,6 +589,17 @@ def train(self):
),
),
):
if torch_profiler:

def sigabrt_handler(signum, frame):
    """Flush the torch profiler when the process receives SIGABRT.

    Steps the profiler through `profiler_active` additional steps so the
    scheduler reaches the end of an active window and invokes the trace
    handler (writing the trace to disk), then stops the profiler.

    Args:
        signum: the received signal number (unused; renamed from
            `signal` to avoid shadowing the `signal` module).
        frame: the interrupted stack frame (unused).
    """
    logger.info("SIGABRT received. Stopping profiler")
    # NOTE(review): this assumes the profiler is in (or about to enter)
    # an active phase, so `profiler_active` extra steps suffice to
    # trigger the trace handler — confirm against the schedule's
    # wait/warmup configuration.
    for _ in range(config.profiling.profiler_active):
        # Step the profiler enough times to trigger a trace
        torch_profiler.step()
    torch_profiler.stop()

signal.signal(signal.SIGABRT, sigabrt_handler)

data_iterator = self.batch_generator(self.dataloader)
while self.should_continue_training():
self.step += 1
Expand Down
Loading