Skip to content

Commit e2a1a98

Browse files
committed
Handle failure to load the FT checkpoint
Summary: Failing to load the FT checkpoint currently crashes the trainer. For now, catch the load failure, log an error, and skip loading the FT checkpoint so that training can continue.
1 parent 424b23c commit e2a1a98

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

torchtitan/components/checkpoint.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -667,20 +667,26 @@ def _ft_load(self) -> None:
667667
step = self._find_load_step(folder=self._ft_folder())
668668
if step == -1:
669669
return
670-
671-
begin = time.monotonic()
672-
logger.info(f"Loading the FT checkpoint at step {step}.")
673-
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
674-
self.dcp_load(
675-
self.ft_states,
676-
checkpoint_id=checkpoint_id,
677-
# FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
678-
from_hf=False,
679-
)
680-
GarbageCollection.collect("GC collection for checkpoint loading.")
681-
logger.info(
682-
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
683-
)
670+
try:
671+
begin = time.monotonic()
672+
logger.info(f"Loading the FT checkpoint at step {step}.")
673+
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
674+
logger.info(f"Calling dcp_load for {checkpoint_id}")
675+
self.dcp_load(
676+
self.ft_states,
677+
checkpoint_id=checkpoint_id,
678+
# FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
679+
from_hf=False,
680+
)
681+
GarbageCollection.collect("GC collection for checkpoint loading.")
682+
logger.info(
683+
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
684+
)
685+
except Exception as e:
686+
# The checkpoint is corrupt. We'll replay all data here.
687+
# TODO: We can try to load checkpoint from previous steps.
688+
logger.error("Failed to load the FT checkpoint.")
689+
return
684690

685691
def _flattened_model_states_sd(
686692
self, state_dict: dict[str, Any] | None = None

0 commit comments

Comments
 (0)