44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import ctypes
78import importlib
89import os
10+ import signal
911import time
1012from datetime import timedelta
1113from typing import Any , Generator , Iterable , Optional
3335 maybe_enable_profiling ,
3436)
3537
# Handle to the current process's own global symbol table (equivalent to
# dlopen(NULL)); gives ctypes access to libc functions such as signal().
# POSIX-only — CDLL(None) is not supported on Windows.
c_globals = ctypes.CDLL(None)  # POSIX

3640
3741class Trainer (torch .distributed .checkpoint .stateful .Stateful ):
42+ torch_profiler : torch .profiler .profile | None = None
43+
3844 # core configs
3945 job_config : JobConfig
4046 parallel_dims : ParallelDims
@@ -555,13 +561,14 @@ def train(self):
555561 if not self .ft_manager .enabled
556562 else f"replica_{ self .ft_manager .replica_id } "
557563 )
564+ self .torch_profiler = maybe_enable_profiling (
565+ job_config .profiling ,
566+ global_step = self .step ,
567+ base_folder = job_config .job .dump_folder ,
568+ leaf_folder = leaf_folder ,
569+ )
570+
558571 with (
559- maybe_enable_profiling (
560- job_config .profiling ,
561- global_step = self .step ,
562- base_folder = job_config .job .dump_folder ,
563- leaf_folder = leaf_folder ,
564- ) as torch_profiler ,
565572 maybe_enable_memory_snapshot (
566573 job_config .profiling ,
567574 global_step = self .step ,
@@ -585,6 +592,16 @@ def train(self):
585592 ),
586593 ),
587594 ):
if self.torch_profiler:
    # Install a raw C-level SIGABRT handler so that profiler data is
    # flushed even when the process is aborted (e.g. by a watchdog or
    # collective timeout) mid-step, which a Python-level signal handler
    # may never get a chance to run for.
    @ctypes.CFUNCTYPE(None, ctypes.c_int)
    def sigabrt_handler(signum):
        # NOTE: named `signum`, not `signal`, to avoid shadowing the
        # `signal` module. Best-effort: disable the profiler and dump
        # whatever trace has been collected so far.
        logger.info("SIGABRT received. Stopping profiler")
        self.torch_profiler.profiler.enabled = False
        self.torch_profiler.export_chrome_trace("trace.json")

    # libc stores only a raw function pointer; the CFUNCTYPE object must
    # be kept alive for as long as the handler is installed (per the
    # ctypes docs), otherwise a later SIGABRT would call freed memory.
    self._sigabrt_handler = sigabrt_handler
    # Register through libc's signal() rather than the Python `signal`
    # module so the handler fires directly at the C level.
    c_globals.signal(signal.SIGABRT, sigabrt_handler)
604+
588605 data_iterator = self .batch_generator (self .dataloader )
589606 while self .should_continue_training ():
590607 self .step += 1
@@ -608,8 +625,8 @@ def train(self):
608625 self .validator .validate (self .model_parts , self .step )
609626
610627 # signal the profiler that the next profiling step has started
611- if torch_profiler :
612- torch_profiler .step ()
628+ if self . torch_profiler :
629+ self . torch_profiler .step ()
613630 if memory_profiler :
614631 memory_profiler .step ()
615632
@@ -667,10 +684,12 @@ def close(self) -> None:
667684 else :
668685 trainer .train ()
669686 except Exception :
687+ logger .info ("Torchtitan training threw an exception" )
670688 if trainer :
671689 trainer .close ()
672690 raise
673691 else :
692+ logger .info ("Torchtitan training completed" )
674693 trainer .close ()
675694 torch .distributed .destroy_process_group ()
676695 logger .info ("Process group destroyed" )
0 commit comments