3 files changed: +17 -4 lines

@@ -34,6 +34,12 @@ class Profiling:
     profile_freq: int = 10
     """How often to collect profile traces, in iterations"""

+    profiler_active: int = 1
+    """The number of steps the profiler is active for in each profiling cycle"""
+
+    profiler_warmup: int = 3
+    """The number of warmup steps before the active steps in each profiling cycle"""
+
     enable_memory_snapshot: bool = False
     """Whether to dump memory snapshot"""

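Together with the existing profile_freq, the two new fields describe one profiling cycle: the profiler idles, warms up for profiler_warmup steps, then records for profiler_active steps. A minimal sketch of the arithmetic with the defaults above, mapped onto torch.profiler.schedule (plain values here rather than a real Profiling instance):

import torch

# Defaults introduced here: a trace every 10 iterations, 3 warmup steps, 1 active step.
profile_freq, profiler_warmup, profiler_active = 10, 3, 1

# Whatever is left of the cycle after warmup + active is idle ("wait") time,
# matching the split computed in the profiler setup code below.
wait = profile_freq - (profiler_active + profiler_warmup)
assert wait >= 0, "profile_freq must cover warmup + active"

sched = torch.profiler.schedule(wait=wait, warmup=profiler_warmup, active=profiler_active)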

@@ -14,9 +14,6 @@
 from torchtitan.config import Profiling as ProfilingConfig
 from torchtitan.tools.logging import logger

-# the number of warmup steps before the active step in each profiling cycle
-WARMUP = 3
-
 # how much memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000

@@ -58,7 +55,7 @@ def trace_handler(prof):
     if not os.path.exists(trace_dir):
         os.makedirs(trace_dir, exist_ok=True)

-    warmup, active = WARMUP, 1
+    warmup, active = profiling_config.profiler_warmup, profiling_config.profiler_active
     wait = profile_freq - (active + warmup)
     assert (
         wait >= 0
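With the hardcoded WARMUP constant gone, both knobs flow from the config into the schedule that drives the profiler. A self-contained sketch of the same mechanism outside torchtitan, assuming a CPU-only toy workload and an illustrative trace.json path (not what maybe_enable_profiling does verbatim):

import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Values that would come from profiling_config in torchtitan.
profile_freq, warmup, active = 10, 3, 1
wait = profile_freq - (active + warmup)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=wait, warmup=warmup, active=active),
    on_trace_ready=lambda prof: prof.export_chrome_trace("trace.json"),
) as prof:
    for _ in range(2 * profile_freq):
        torch.randn(256, 256) @ torch.randn(256, 256)  # stand-in for a training step
        prof.step()  # advances the wait -> warmup -> active state machine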

@@ -32,6 +32,7 @@
     maybe_enable_memory_snapshot,
     maybe_enable_profiling,
 )
+import signal


 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
@@ -585,6 +586,15 @@ def train(self):
                 ),
             ),
         ):
+            if torch_profiler:
+                def sigabrt_handler(signal, frame):
+                    logger.info("SIGABRT received. Stopping profiler")
+                    for _ in range(config.profiling.profiler_active):
+                        # Step the profiler enough times to trigger a trace
+                        torch_profiler.step()
+                    torch_profiler.stop()
+                signal.signal(signal.SIGABRT, sigabrt_handler)
+
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():
                 self.step += 1
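The intent of the handler: if something sends SIGABRT to the trainer (for example a watchdog or job manager aborting the run), step the profiler far enough that a trace gets written before the process goes down, instead of losing the in-flight profiling data. A self-contained sketch of the same pattern with illustrative names (the toy schedule and aborted_trace.json are not the trainer's); it steps through a full cycle so an active window is reached no matter where the schedule currently is:

import os
import signal

import torch
from torch.profiler import ProfilerActivity, profile, schedule

wait, warmup, active = 6, 3, 1  # one 10-step profiling cycle
prof = profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=wait, warmup=warmup, active=active),
    on_trace_ready=lambda p: p.export_chrome_trace("aborted_trace.json"),
)
prof.start()

def sigabrt_handler(signum, frame):
    print("SIGABRT received. Stopping profiler")
    # Step through a full cycle so the active window completes and
    # on_trace_ready fires, then stop the profiler.
    for _ in range(wait + warmup + active):
        prof.step()
    prof.stop()

signal.signal(signal.SIGABRT, sigabrt_handler)

torch.randn(256, 256) @ torch.randn(256, 256)  # stand-in training step
prof.step()
os.kill(os.getpid(), signal.SIGABRT)  # simulate an external abort; the handler flushes the trace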