Merged — torchtitan/components/checkpoint.py (34 additions, 25 deletions)
@@ -189,15 +189,25 @@ def __init__(
         self.enable = checkpoint_config.enable
         self.load_only = checkpoint_config.load_only

+        self.states = states
+        self.states.update(
+            {
+                MODEL: ModelWrapper(model_parts),
+                OPTIMIZER: optimizers,
+                DATALOADER: dataloader,
+                LR_SCHEDULER: lr_schedulers,
+            }
+        )
+
         self.ft_manager = (
-            ft_manager.manager
-            if ft_manager
-            and ft_manager.enabled
-            and checkpoint_config.enable_ft_dataloader_checkpoints
-            else None
+            ft_manager.manager if ft_manager and ft_manager.enabled else None
         )

-        if ft_manager and ft_manager.enabled and not self.ft_manager:
+        self.enable_ft_dataloader_checkpoints = (
+            self.ft_manager and checkpoint_config.enable_ft_dataloader_checkpoints
+        )
+
+        if self.ft_manager and not self.enable_ft_dataloader_checkpoints:
             logger.warn(
                 "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
                 "This means replicas can retrain over the same data multiple times, which can result in overfitting."
@@ -229,20 +239,11 @@ def load_state_dict(state_dict):
         async_mode = checkpoint_config.async_mode.lower()
         self.enable_staging = (
             self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        ) or self.ft_manager
+        ) or self.enable_ft_dataloader_checkpoints

-        if not self.enable and self.ft_manager is None:
+        if not self.enable and not self.enable_ft_dataloader_checkpoints:
             return

-        self.states = states
-        self.states.update(
-            {
-                MODEL: ModelWrapper(model_parts),
-                OPTIMIZER: optimizers,
-                DATALOADER: dataloader,
-                LR_SCHEDULER: lr_schedulers,
-            }
-        )
         self.ft_states = {DATALOADER: dataloader}

         self.staging = False
@@ -279,7 +280,7 @@ def load_state_dict(state_dict):
         if (
             async_mode == AsyncMode.ASYNC
             or async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            or self.ft_manager
+            or self.enable_ft_dataloader_checkpoints
         ):
             self.pg = dist.new_group(backend="gloo")

@@ -480,14 +481,16 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             None
         """

-        if self.ft_manager:
+        if self.enable_ft_dataloader_checkpoints:
             self._ft_save(curr_step)

         if not self._should_save(curr_step, last_step):
             return

         begin = time.monotonic()
-        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+        if not self.enable_ft_dataloader_checkpoints or (
+            self.ft_manager and self.ft_manager.participating_rank() == 0
+        ):
             logger.info("Saving the checkpoint (or staging if async is enabled).")
             checkpoint_id = self._create_checkpoint_id(curr_step)
             self._async_wait()
Expand Down Expand Up @@ -530,7 +533,8 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
elif self.ft_manager:
elif self.enable_ft_dataloader_checkpoints:
assert self.ft_manager is not None
logger.info(
"Replica %d doesn't save checkpoint.",
self.ft_manager.participating_rank(),
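With the flag split out, the full checkpoint write in save() is gated on either FT dataloader checkpoints being disabled or this replica being participating rank 0; the FT dataloader save itself still runs on every replica. A condensed sketch of that guard follows; should_write_full_checkpoint is a made-up helper with a simplified signature, not part of the PR.

def should_write_full_checkpoint(ft_dataloader_ckpt: bool, participating_rank: int) -> bool:
    # Mirrors the save() guard: without FT dataloader checkpoints every caller
    # proceeds; with them, only the replica whose participating rank is 0 writes.
    return (not ft_dataloader_ckpt) or participating_rank == 0


print(should_write_full_checkpoint(True, participating_rank=1))   # False: this replica skips the write
print(should_write_full_checkpoint(False, participating_rank=3))  # True: normal non-FT path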
@@ -551,7 +555,7 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """

-        if self.ft_manager:
+        if self.enable_ft_dataloader_checkpoints:
             self._ft_load()

         if not self.enable:
Expand Down Expand Up @@ -749,7 +753,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:

states_to_load = self._flattened_model_states_sd(states_to_load)

if self.ft_manager:
if self.enable_ft_dataloader_checkpoints:
states_to_load.pop(DATALOADER)

return states_to_load
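Here the dataloader entry is dropped from the DCP load when FT already checkpoints it, so the two mechanisms do not overwrite each other's dataloader state. A small sketch of that filtering with a plain dict standing in for the real state objects; filter_states_to_load is illustrative only.

DATALOADER = "dataloader"  # stand-in for the constant imported in checkpoint.py


def filter_states_to_load(states: dict, ft_dataloader_ckpt: bool) -> dict:
    # Mirrors _states_to_load: when FT owns the dataloader checkpoint, drop it
    # from the set of states handed to the regular DCP load.
    states = dict(states)  # avoid mutating the caller's mapping
    if ft_dataloader_ckpt:
        states.pop(DATALOADER, None)
    return states


print(filter_states_to_load({"model": {}, "dataloader": {}}, ft_dataloader_ckpt=True))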
Expand Down Expand Up @@ -805,7 +809,9 @@ def _async_wait(self) -> None:
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
if self.save_future is not None:
self.save_future.result()
elif self.async_mode == AsyncMode.ASYNC or self.ft_manager is not None:
elif (
self.async_mode == AsyncMode.ASYNC or self.enable_ft_dataloader_checkpoints
):
if self.save_future is not None:
self.save_future.result()
self.save_future = None
@@ -820,7 +826,10 @@ def _purge_stale_checkpoints(self):
             self.keep_latest_k > 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
-            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
+            and (
+                not self.enable_ft_dataloader_checkpoints
+                or (self.ft_manager and self.ft_manager.participating_rank() == 0)
+            )
         ):
             discovered_checkpoints = []
             for filename in os.listdir(self.folder):