44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import ctypes
78import importlib
89import os
10+ import signal
911import time
1012from datetime import timedelta
1113from typing import Any , Generator , Iterable , Optional
3234 maybe_enable_profiling ,
3335)
3436
37+ c_globals = ctypes .CDLL (None ) # POSIX
38+
3539
3640class Trainer (torch .distributed .checkpoint .stateful .Stateful ):
41+ torch_profiler : torch .profiler .profile | None = None
42+
3743 # core configs
3844 job_config : JobConfig
3945 parallel_dims : ParallelDims
@@ -582,13 +588,14 @@ def train(self):
582588 if not self .ft_manager .enabled
583589 else f"replica_{ self .ft_manager .replica_id } "
584590 )
591+ self .torch_profiler = maybe_enable_profiling (
592+ job_config .profiling ,
593+ global_step = self .step ,
594+ base_folder = job_config .job .dump_folder ,
595+ leaf_folder = leaf_folder ,
596+ )
597+
585598 with (
586- maybe_enable_profiling (
587- job_config .profiling ,
588- global_step = self .step ,
589- base_folder = job_config .job .dump_folder ,
590- leaf_folder = leaf_folder ,
591- ) as torch_profiler ,
592599 maybe_enable_memory_snapshot (
593600 job_config .profiling ,
594601 global_step = self .step ,
@@ -612,6 +619,15 @@ def train(self):
612619 ),
613620 ),
614621 ):
622+ if self .torch_profiler :
623+
624+ @ctypes .CFUNCTYPE (None , ctypes .c_int )
625+ def sigabrt_handler (signal ):
626+ logger .info ("SIGABRT received. Stopping profiler" )
627+ self .torch_profiler .export_chrome_trace ("trace.json" )
628+
629+ c_globals .signal (signal .SIGABRT , sigabrt_handler )
630+
615631 data_iterator = self .batch_generator (self .dataloader )
616632 while self .should_continue_training ():
617633 self .step += 1
@@ -635,8 +651,8 @@ def train(self):
635651 self .validator .validate (self .model_parts , self .step )
636652
637653 # signal the profiler that the next profiling step has started
638- if torch_profiler :
639- torch_profiler .step ()
654+ if self . torch_profiler :
655+ self . torch_profiler .step ()
640656 if memory_profiler :
641657 memory_profiler .step ()
642658
@@ -694,10 +710,12 @@ def close(self) -> None:
694710 else :
695711 trainer .train ()
696712 except Exception :
713+ logger .info ("Torchtitan training threw an exception" )
697714 if trainer :
698715 trainer .close ()
699716 raise
700717 else :
718+ logger .info ("Torchtitan training completed" )
701719 trainer .close ()
702720 torch .distributed .destroy_process_group ()
703721 logger .info ("Process group destroyed" )
0 commit comments