@@ -667,20 +667,26 @@ def _ft_load(self) -> None:
667
667
step = self ._find_load_step (folder = self ._ft_folder ())
668
668
if step == - 1 :
669
669
return
670
-
671
- begin = time .monotonic ()
672
- logger .info (f"Loading the FT checkpoint at step { step } ." )
673
- checkpoint_id = self ._create_checkpoint_id (step , folder = self ._ft_folder ())
674
- self .dcp_load (
675
- self .ft_states ,
676
- checkpoint_id = checkpoint_id ,
677
- # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
678
- from_hf = False ,
679
- )
680
- GarbageCollection .collect ("GC collection for checkpoint loading." )
681
- logger .info (
682
- f"Finished loading the ft checkpoint in { time .monotonic () - begin :.2f} seconds."
683
- )
670
+ try :
671
+ begin = time .monotonic ()
672
+ logger .info (f"Loading the FT checkpoint at step { step } ." )
673
+ checkpoint_id = self ._create_checkpoint_id (step , folder = self ._ft_folder ())
674
+ logger .info (f"Calling dcp_load for { checkpoint_id } " )
675
+ self .dcp_load (
676
+ self .ft_states ,
677
+ checkpoint_id = checkpoint_id ,
678
+ # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
679
+ from_hf = False ,
680
+ )
681
+ GarbageCollection .collect ("GC collection for checkpoint loading." )
682
+ logger .info (
683
+ f"Finished loading the ft checkpoint in { time .monotonic () - begin :.2f} seconds."
684
+ )
685
+ except Exception as e :
686
+ # The checkpoint is corrupt. We'll replay all data here.
687
+ # TODO: We can try to load checkpoint from previous steps.
688
+ logger .error ("Failed to load the FT checkpoint." )
689
+ return
684
690
685
691
def _flattened_model_states_sd (
686
692
self , state_dict : dict [str , Any ] | None = None
0 commit comments