diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index d4c5416aa2..79918d0046 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -189,15 +189,25 @@ def __init__( self.enable = checkpoint_config.enable self.load_only = checkpoint_config.load_only + self.states = states + self.states.update( + { + MODEL: ModelWrapper(model_parts), + OPTIMIZER: optimizers, + DATALOADER: dataloader, + LR_SCHEDULER: lr_schedulers, + } + ) + self.ft_manager = ( - ft_manager.manager - if ft_manager - and ft_manager.enabled - and checkpoint_config.enable_ft_dataloader_checkpoints - else None + ft_manager.manager if ft_manager and ft_manager.enabled else None ) - if ft_manager and ft_manager.enabled and not self.ft_manager: + self.enable_ft_dataloader_checkpoints = ( + self.ft_manager and checkpoint_config.enable_ft_dataloader_checkpoints + ) + + if self.ft_manager and not self.enable_ft_dataloader_checkpoints: logger.warn( "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. " "This means replicas can retrain over the same data multiple times, which can result in overfitting." @@ -229,20 +239,11 @@ def load_state_dict(state_dict): async_mode = checkpoint_config.async_mode.lower() self.enable_staging = ( self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM - ) or self.ft_manager + ) or self.enable_ft_dataloader_checkpoints - if not self.enable and self.ft_manager is None: + if not self.enable and not self.enable_ft_dataloader_checkpoints: return - self.states = states - self.states.update( - { - MODEL: ModelWrapper(model_parts), - OPTIMIZER: optimizers, - DATALOADER: dataloader, - LR_SCHEDULER: lr_schedulers, - } - ) self.ft_states = {DATALOADER: dataloader} self.staging = False @@ -279,7 +280,7 @@ def load_state_dict(state_dict): if ( async_mode == AsyncMode.ASYNC or async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM - or self.ft_manager + or self.enable_ft_dataloader_checkpoints ): self.pg = dist.new_group(backend="gloo") @@ -480,14 +481,16 @@ def save(self, curr_step: int, last_step: bool = False) -> None: None """ - if self.ft_manager: + if self.enable_ft_dataloader_checkpoints: self._ft_save(curr_step) if not self._should_save(curr_step, last_step): return begin = time.monotonic() - if not self.ft_manager or self.ft_manager.participating_rank() == 0: + if not self.enable_ft_dataloader_checkpoints or ( + self.ft_manager and self.ft_manager.participating_rank() == 0 + ): logger.info("Saving the checkpoint (or staging if async is enabled).") checkpoint_id = self._create_checkpoint_id(curr_step) self._async_wait() @@ -530,7 +533,8 @@ def save(self, curr_step: int, last_step: bool = False) -> None: "Finished saving the checkpoint (or staging if async is enabled)" f"in {time.monotonic() - begin:.2f} seconds." ) - elif self.ft_manager: + elif self.enable_ft_dataloader_checkpoints: + assert self.ft_manager is not None logger.info( "Replica %d doesn't save checkpoint.", self.ft_manager.participating_rank(), @@ -551,7 +555,7 @@ def load(self, step: int = -1) -> bool: bool: Whether the checkpoint was loaded successfully. """ - if self.ft_manager: + if self.enable_ft_dataloader_checkpoints: self._ft_load() if not self.enable: @@ -749,7 +753,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]: states_to_load = self._flattened_model_states_sd(states_to_load) - if self.ft_manager: + if self.enable_ft_dataloader_checkpoints: states_to_load.pop(DATALOADER) return states_to_load @@ -805,7 +809,9 @@ def _async_wait(self) -> None: if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: if self.save_future is not None: self.save_future.result() - elif self.async_mode == AsyncMode.ASYNC or self.ft_manager is not None: + elif ( + self.async_mode == AsyncMode.ASYNC or self.enable_ft_dataloader_checkpoints + ): if self.save_future is not None: self.save_future.result() self.save_future = None @@ -820,7 +826,10 @@ def _purge_stale_checkpoints(self): self.keep_latest_k > 0 and dist.get_rank() == 0 and os.path.isdir(self.folder) - and (not self.ft_manager or self.ft_manager.participating_rank() == 0) + and ( + not self.enable_ft_dataloader_checkpoints + or (self.ft_manager and self.ft_manager.participating_rank() == 0) + ) ): discovered_checkpoints = [] for filename in os.listdir(self.folder):