Skip to content

Commit e5606c9

Browse files
committed
fix setting ft state dicts when ft checkpointing is disabled
Summary:
- When FT dataloader checkpointing is disabled, the FT state dict was also not being set.
- Change it so that even when FT checkpointing is disabled, the state dict is still set, so the model, optimizer, etc. can be recovered from a different replica.
1 parent e43621c commit e5606c9

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

torchtitan/components/checkpoint.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -193,17 +193,20 @@ def __init__(
193193
ft_manager.manager
194194
if ft_manager
195195
and ft_manager.enabled
196-
and checkpoint_config.enable_ft_dataloader_checkpoints
197196
else None
198197
)
199198

200-
if ft_manager and ft_manager.enabled and not self.ft_manager:
199+
self.enable_ft_dataloader_checkpoints = (
200+
self.ft_manager is not None and checkpoint_config.enable_ft_dataloader_checkpoints
201+
)
202+
203+
if self.ft_manager is not None and not self.enable_ft_dataloader_checkpoints:
201204
logger.warn(
202205
"Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
203206
"This means replicas can retrain over the same data multiple times, which can result in overfitting."
204207
)
205208

206-
if self.ft_manager:
209+
if self.ft_manager is not None:
207210
optimizers.init_cache_state_dict()
208211

209212
def state_dict():
@@ -229,20 +232,22 @@ def load_state_dict(state_dict):
229232
async_mode = checkpoint_config.async_mode.lower()
230233
self.enable_staging = (
231234
self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
232-
) or self.ft_manager
235+
) or self.enable_ft_dataloader_checkpoints
236+
237+
if self.enable or self.ft_manager is not None:
238+
self.states = states
239+
self.states.update(
240+
{
241+
MODEL: ModelWrapper(model_parts),
242+
OPTIMIZER: optimizers,
243+
DATALOADER: dataloader,
244+
LR_SCHEDULER: lr_schedulers,
245+
}
246+
)
233247

234-
if not self.enable and self.ft_manager is None:
248+
if not self.enable and not self.enable_ft_dataloader_checkpoints:
235249
return
236250

237-
self.states = states
238-
self.states.update(
239-
{
240-
MODEL: ModelWrapper(model_parts),
241-
OPTIMIZER: optimizers,
242-
DATALOADER: dataloader,
243-
LR_SCHEDULER: lr_schedulers,
244-
}
245-
)
246251
self.ft_states = {DATALOADER: dataloader}
247252

248253
self.staging = False
@@ -279,7 +284,7 @@ def load_state_dict(state_dict):
279284
if (
280285
async_mode == AsyncMode.ASYNC
281286
or async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
282-
or self.ft_manager
287+
or self.enable_ft_dataloader_checkpoints
283288
):
284289
self.pg = dist.new_group(backend="gloo")
285290

@@ -480,7 +485,7 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
480485
None
481486
"""
482487

483-
if self.ft_manager:
488+
if self.enable_ft_dataloader_checkpoints:
484489
self._ft_save(curr_step)
485490

486491
if not self._should_save(curr_step, last_step):
@@ -551,7 +556,7 @@ def load(self, step: int = -1) -> bool:
551556
bool: Whether the checkpoint was loaded successfully.
552557
"""
553558

554-
if self.ft_manager:
559+
if self.enable_ft_dataloader_checkpoints:
555560
self._ft_load()
556561

557562
if not self.enable:
@@ -749,7 +754,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
749754

750755
states_to_load = self._flattened_model_states_sd(states_to_load)
751756

752-
if self.ft_manager:
757+
if self.enable_ft_dataloader_checkpoints:
753758
states_to_load.pop(DATALOADER)
754759

755760
return states_to_load

0 commit comments

Comments
 (0)