@@ -193,17 +193,20 @@ def __init__(
193193 ft_manager .manager
194194 if ft_manager
195195 and ft_manager .enabled
196- and checkpoint_config .enable_ft_dataloader_checkpoints
197196 else None
198197 )
199198
200- if ft_manager and ft_manager .enabled and not self .ft_manager :
199+ self .enable_ft_dataloader_checkpoints = (
200+ self .ft_manager is not None and checkpoint_config .enable_ft_dataloader_checkpoints
201+ )
202+
203+ if self .ft_manager is not None and not self .enable_ft_dataloader_checkpoints :
201204 logger .warn (
202205 "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
203206 "This means replicas can retrain over the same data multiple times, which can result in overfitting."
204207 )
205208
206- if self .ft_manager :
209+ if self .ft_manager is not None :
207210 optimizers .init_cache_state_dict ()
208211
209212 def state_dict ():
@@ -229,20 +232,22 @@ def load_state_dict(state_dict):
229232 async_mode = checkpoint_config .async_mode .lower ()
230233 self .enable_staging = (
231234 self .enable and async_mode == AsyncMode .ASYNC_WITH_PINNED_MEM
232- ) or self .ft_manager
235+ ) or self .enable_ft_dataloader_checkpoints
236+
237+ if self .enable or self .ft_manager is not None :
238+ self .states = states
239+ self .states .update (
240+ {
241+ MODEL : ModelWrapper (model_parts ),
242+ OPTIMIZER : optimizers ,
243+ DATALOADER : dataloader ,
244+ LR_SCHEDULER : lr_schedulers ,
245+ }
246+ )
233247
234- if not self .enable and self .ft_manager is None :
248+ if not self .enable and not self .enable_ft_dataloader_checkpoints :
235249 return
236250
237- self .states = states
238- self .states .update (
239- {
240- MODEL : ModelWrapper (model_parts ),
241- OPTIMIZER : optimizers ,
242- DATALOADER : dataloader ,
243- LR_SCHEDULER : lr_schedulers ,
244- }
245- )
246251 self .ft_states = {DATALOADER : dataloader }
247252
248253 self .staging = False
@@ -279,7 +284,7 @@ def load_state_dict(state_dict):
279284 if (
280285 async_mode == AsyncMode .ASYNC
281286 or async_mode == AsyncMode .ASYNC_WITH_PINNED_MEM
282- or self .ft_manager
287+ or self .enable_ft_dataloader_checkpoints
283288 ):
284289 self .pg = dist .new_group (backend = "gloo" )
285290
@@ -480,7 +485,7 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
480485 None
481486 """
482487
483- if self .ft_manager :
488+ if self .enable_ft_dataloader_checkpoints :
484489 self ._ft_save (curr_step )
485490
486491 if not self ._should_save (curr_step , last_step ):
@@ -551,7 +556,7 @@ def load(self, step: int = -1) -> bool:
551556 bool: Whether the checkpoint was loaded successfully.
552557 """
553558
554- if self .ft_manager :
559+ if self .enable_ft_dataloader_checkpoints :
555560 self ._ft_load ()
556561
557562 if not self .enable :
@@ -749,7 +754,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
749754
750755 states_to_load = self ._flattened_model_states_sd (states_to_load )
751756
752- if self .ft_manager :
757+ if self .enable_ft_dataloader_checkpoints :
753758 states_to_load .pop (DATALOADER )
754759
755760 return states_to_load
0 commit comments