Skip to content

Commit 21739fd

Browse files
authored
enhance profiler config (#1809)
Summary: allow users to specify the profiler schedule --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1809). * #1811 * #1810 * #1812 * __->__ #1809 Co-authored-by: Tushar Jain <[email protected]>
1 parent 41eff53 commit 21739fd

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

torchtitan/config/job_config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,20 @@ class Profiling:
3434
profile_freq: int = 10
3535
"""How often to collect profile traces, in iterations"""
3636

37+
profiler_active: int = 1
38+
"""
39+
The number of steps the profiler is active for in each profiling cycle.
40+
41+
This is used to configure torch.profiler.schedule.
42+
"""
43+
44+
profiler_warmup: int = 3
45+
"""
46+
The number of warmup steps before the active step in each profiling cycle.
47+
48+
This is used to configure torch.profiler.schedule.
49+
"""
50+
3751
enable_memory_snapshot: bool = False
3852
"""Whether to dump memory snapshot"""
3953

torchtitan/tools/profiling.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
from torchtitan.config import Profiling as ProfilingConfig
1515
from torchtitan.tools.logging import logger
1616

17-
# the number of warmup steps before the active step in each profiling cycle
18-
WARMUP = 3
19-
2017
# how much memory allocation/free ops to record in memory snapshots
2118
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
2219

@@ -34,7 +31,11 @@ def maybe_enable_profiling(
3431

3532
if enable_profiling:
3633
trace_dir = os.path.join(base_folder, profiling_config.save_traces_folder)
37-
profile_freq = profiling_config.profile_freq
34+
profile_freq, warmup, active = (
35+
profiling_config.profile_freq,
36+
profiling_config.profiler_warmup,
37+
profiling_config.profiler_active,
38+
)
3839

3940
rank = torch.distributed.get_rank()
4041

@@ -58,7 +59,6 @@ def trace_handler(prof):
5859
if not os.path.exists(trace_dir):
5960
os.makedirs(trace_dir, exist_ok=True)
6061

61-
warmup, active = WARMUP, 1
6262
wait = profile_freq - (active + warmup)
6363
assert (
6464
wait >= 0

0 commit comments

Comments
 (0)