File tree Expand file tree Collapse file tree 3 files changed +17
-4
lines changed Expand file tree Collapse file tree 3 files changed +17
-4
lines changed Original file line number Diff line number Diff line change @@ -34,6 +34,12 @@ class Profiling:
3434 profile_freq : int = 10
3535 """How often to collect profile traces, in iterations"""
3636
37+ profiler_active : int = 1
38+ """The steps profiler is active for"""
39+
40+ profiler_warmup : int = 3
41+ """The number of warmup steps before the active step in each profiling cycle"""
42+
3743 enable_memory_snapshot : bool = False
3844 """Whether to dump memory snapshot"""
3945
Original file line number Diff line number Diff line change 1414from torchtitan .config import Profiling as ProfilingConfig
1515from torchtitan .tools .logging import logger
1616
17- # the number of warmup steps before the active step in each profiling cycle
18- WARMUP = 3
19-
2017# how much memory allocation/free ops to record in memory snapshots
2118MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
2219
@@ -58,7 +55,7 @@ def trace_handler(prof):
5855 if not os .path .exists (trace_dir ):
5956 os .makedirs (trace_dir , exist_ok = True )
6057
61- warmup , active = WARMUP , 1
58+ warmup , active = profiling_config . profiler_warmup , profiling_config . profiler_active
6259 wait = profile_freq - (active + warmup )
6360 assert (
6461 wait >= 0
Original file line number Diff line number Diff line change 3232 maybe_enable_memory_snapshot ,
3333 maybe_enable_profiling ,
3434)
35+ import signal
3536
3637
3738class Trainer (torch .distributed .checkpoint .stateful .Stateful ):
@@ -588,6 +589,15 @@ def train(self):
588589 ),
589590 ),
590591 ):
592+ if torch_profiler :
593+ def sigabrt_handler (signal , frame ):
594+ logger .info ("SIGABRT received. Stopping profiler" )
595+ for _ in range (config .profiling .profiler_active ):
596+ # Step the profiler enough times to trigger a trace
597+ torch_profiler .step ()
598+ torch_profiler .stop ()
599+ signal .signal (signal .SIGABRT , sigabrt_handler )
600+
591601 data_iterator = self .batch_generator (self .dataloader )
592602 while self .should_continue_training ():
593603 self .step += 1
You can’t perform that action at this time.
0 commit comments