@@ -4,8 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import ctypes
 import importlib
 import os
+import signal
 import time
 from datetime import timedelta
 from typing import Any, Generator, Iterable, Optional
@@ -33,8 +35,12 @@
     maybe_enable_profiling,
 )
 
+c_globals = ctypes.CDLL(None)  # POSIX
+
 
 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
+    torch_profiler: torch.profiler.profile | None = None
+
     # core configs
     job_config: JobConfig
     parallel_dims: ParallelDims
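The import-level changes above set up a C-level signal path: a handler installed with Python's `signal.signal` only runs between bytecode instructions on the main thread, so it never fires when native code (for example a watchdog thread calling `abort()`) raises SIGABRT and the interpreter never regains control. Registering a function pointer through libc's `signal()` avoids that. A minimal sketch of the mechanism, assuming a POSIX libc; the `on_sigabrt` name is illustrative:

```python
import ctypes
import signal

# CDLL(None) returns a handle to the symbols already loaded into the
# process, which on POSIX includes libc's signal().
libc = ctypes.CDLL(None)

# The CFUNCTYPE object must stay referenced for as long as the handler is
# registered; if it is garbage-collected, signal delivery crashes the process.
@ctypes.CFUNCTYPE(None, ctypes.c_int)
def on_sigabrt(signum):
    # Only async-signal-safe work is strictly legal in a C signal handler;
    # anything beyond that is a calculated risk taken to salvage state.
    print(f"caught signal {signum}", flush=True)

libc.signal(signal.SIGABRT, on_sigabrt)
```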
@@ -555,13 +561,14 @@ def train(self):
             if not self.ft_manager.enabled
             else f"replica_{self.ft_manager.replica_id}"
         )
+        self.torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+            leaf_folder=leaf_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-                leaf_folder=leaf_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
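Hoisting `maybe_enable_profiling` out of the `with` header and onto `self.torch_profiler` is what lets code outside the block, in particular the signal handler installed further down, reach the live profiler; it also implies the helper now returns a started profiler (or `None`) rather than acting as a context manager. A sketch of what such a helper could look like; the `enable_profiling` flag and the trace-directory layout are assumptions, not the actual torchtitan config:

```python
import os
import torch

def maybe_enable_profiling(profiling_config, *, global_step, base_folder, leaf_folder):
    if not profiling_config.enable_profiling:  # hypothetical config flag
        return None
    trace_dir = os.path.join(base_folder, "profile_trace", leaf_folder)
    prof = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        # global_step could be used to offset the schedule; omitted here.
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
    )
    prof.start()  # caller owns the profiler and must call step()/stop()
    return prof
```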
@@ -585,6 +592,16 @@ def train(self):
                 ),
             ),
         ):
+            if self.torch_profiler:
+
+                @ctypes.CFUNCTYPE(None, ctypes.c_int)
+                def sigabrt_handler(signum):
+                    logger.info("SIGABRT received. Stopping profiler")
+                    self.torch_profiler.profiler.enabled = False
+                    self.torch_profiler.export_chrome_trace("trace.json")
+
+                c_globals.signal(signal.SIGABRT, sigabrt_handler)
+
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():
                 self.step += 1
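The handler mirrors the normal post-`stop()` export path: flipping the underlying `profiler.enabled` flag halts collection so that `export_chrome_trace` can serialize whatever was captured before the process dies. A standalone POSIX demo of the same sequence; running it writes `trace.json` and then terminates when glibc re-raises the abort:

```python
import ctypes
import os
import signal
import torch

libc = ctypes.CDLL(None)

prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU])
prof.start()

@ctypes.CFUNCTYPE(None, ctypes.c_int)
def dump_trace(signum):
    # Same sequence as the trainer's handler: stop collection, then export.
    prof.profiler.enabled = False
    prof.export_chrome_trace("trace.json")

libc.signal(signal.SIGABRT, dump_trace)

torch.randn(256, 256) @ torch.randn(256, 256)  # some work to record
os.abort()  # glibc runs the handler, then re-raises and kills the process
```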
@@ -608,8 +625,8 @@ def train(self):
                         self.validator.validate(self.model_parts, self.step)
 
                 # signal the profiler that the next profiling step has started
-                if torch_profiler:
-                    torch_profiler.step()
+                if self.torch_profiler:
+                    self.torch_profiler.step()
                 if memory_profiler:
                     memory_profiler.step()
 
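Reading the profiler from `self` is the only change the training loop needs, since `step()` is still what advances the profiler's `wait`/`warmup`/`active` schedule. For reference, a minimal loop showing how per-iteration `step()` calls drive a scheduled profile (schedule values are illustrative):

```python
import torch

with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
    on_trace_ready=lambda p: p.export_chrome_trace("step_trace.json"),
) as prof:
    for _ in range(8):
        torch.randn(64, 64).sum()  # stand-in for one training step
        prof.step()  # wait -> warmup -> active; trace fires after active ends
```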
@@ -667,10 +684,12 @@ def close(self) -> None:
         else:
             trainer.train()
     except Exception:
+        logger.info("Torchtitan training threw an exception")
         if trainer:
             trainer.close()
         raise
     else:
+        logger.info("Torchtitan training completed")
         trainer.close()
         torch.distributed.destroy_process_group()
         logger.info("Process group destroyed")