diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index eb477941ca..138672c739 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -34,6 +34,20 @@ class Profiling:
     profile_freq: int = 10
     """How often to collect profile traces, in iterations"""
 
+    profiler_active: int = 1
+    """
+    The number of steps the profiler is active for.
+
+    This is used to configure torch.profiler.schedule.
+    """
+
+    profiler_warmup: int = 3
+    """
+    The number of warmup steps before the active step in each profiling cycle.
+
+    This is used to configure torch.profiler.schedule.
+    """
+
     enable_memory_snapshot: bool = False
     """Whether to dump memory snapshot"""
 
diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py
index 0e851d335a..f398dba9b5 100644
--- a/torchtitan/tools/profiling.py
+++ b/torchtitan/tools/profiling.py
@@ -14,9 +14,6 @@
 from torchtitan.config import Profiling as ProfilingConfig
 from torchtitan.tools.logging import logger
 
-# the number of warmup steps before the active step in each profiling cycle
-WARMUP = 3
-
 # how much memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
@@ -34,7 +31,11 @@ def maybe_enable_profiling(
 
     if enable_profiling:
         trace_dir = os.path.join(base_folder, profiling_config.save_traces_folder)
-        profile_freq = profiling_config.profile_freq
+        profile_freq, warmup, active = (
+            profiling_config.profile_freq,
+            profiling_config.profiler_warmup,
+            profiling_config.profiler_active,
+        )
 
         rank = torch.distributed.get_rank()
 
@@ -58,7 +59,6 @@ def trace_handler(prof):
         if not os.path.exists(trace_dir):
             os.makedirs(trace_dir, exist_ok=True)
 
-        warmup, active = WARMUP, 1
         wait = profile_freq - (active + warmup)
         assert (
             wait >= 0