44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import ctypes
78import importlib
89import os
10+ import signal
911import time
1012from datetime import timedelta
1113from typing import Any , Generator , Iterable , Optional
3234 maybe_enable_profiling ,
3335)
3436
# Handle to the symbols of the current process (POSIX-only: CDLL(None) maps
# the running executable itself). Used below to reach the C library's raw
# signal() so a C-level SIGABRT handler can be installed for the profiler.
c_globals = ctypes.CDLL(None)  # POSIX

3539
3640class Trainer (torch .distributed .checkpoint .stateful .Stateful ):
41+ torch_profiler : torch .profiler .profile | None = None
42+
3743 # core configs
3844 job_config : JobConfig
3945 parallel_dims : ParallelDims
@@ -569,13 +575,14 @@ def train(self):
569575 if not self .ft_manager .enabled
570576 else f"replica_{ self .ft_manager .replica_id } "
571577 )
578+ self .torch_profiler = maybe_enable_profiling (
579+ job_config .profiling ,
580+ global_step = self .step ,
581+ base_folder = job_config .job .dump_folder ,
582+ leaf_folder = leaf_folder ,
583+ )
584+
572585 with (
573- maybe_enable_profiling (
574- job_config .profiling ,
575- global_step = self .step ,
576- base_folder = job_config .job .dump_folder ,
577- leaf_folder = leaf_folder ,
578- ) as torch_profiler ,
579586 maybe_enable_memory_snapshot (
580587 job_config .profiling ,
581588 global_step = self .step ,
@@ -599,6 +606,16 @@ def train(self):
599606 ),
600607 ),
601608 ):
# If profiling is active, install a C-level SIGABRT handler so that an
# abort (e.g. a NCCL/watchdog-triggered abort) still flushes a usable
# Chrome trace instead of losing the profile.
if self.torch_profiler:

    @ctypes.CFUNCTYPE(None, ctypes.c_int)
    def sigabrt_handler(signal):
        # NOTE(review): the parameter shadows the `signal` module inside
        # this closure (the module is only used outside, so it works, but
        # renaming the parameter to `signum` would be clearer).
        logger.info("SIGABRT received. Stopping profiler")
        # Disable further profiling, then dump whatever was collected.
        self.torch_profiler.profiler.enabled = False
        # NOTE(review): runs arbitrary Python (file I/O) from a signal
        # handler context — not async-signal-safe; best-effort only.
        self.torch_profiler.export_chrome_trace("trace.json")

    # Register via libc's signal(2) rather than Python's signal module so
    # the handler fires even when the abort bypasses the CPython runtime.
    # NOTE(review): ctypes docs require keeping a reference to the
    # CFUNCTYPE object for as long as C code may call it; this one is only
    # a local of train(), so it could be garbage-collected while still
    # registered — consider storing it on `self` (TODO confirm lifetime).
    c_globals.signal(signal.SIGABRT, sigabrt_handler)

602619 data_iterator = self .batch_generator (self .dataloader )
603620 while self .should_continue_training ():
604621 self .step += 1
@@ -622,8 +639,8 @@ def train(self):
622639 self .validator .validate (self .model_parts , self .step )
623640
624641 # signal the profiler that the next profiling step has started
625- if torch_profiler :
626- torch_profiler .step ()
642+ if self . torch_profiler :
643+ self . torch_profiler .step ()
627644 if memory_profiler :
628645 memory_profiler .step ()
629646
@@ -681,10 +698,12 @@ def close(self) -> None:
681698 else :
682699 trainer .train ()
683700 except Exception :
701+ logger .info ("Torchtitan training threw an exception" )
684702 if trainer :
685703 trainer .close ()
686704 raise
687705 else :
706+ logger .info ("Torchtitan training completed" )
688707 trainer .close ()
689708 torch .distributed .destroy_process_group ()
690709 logger .info ("Process group destroyed" )
0 commit comments