Skip to content

Commit d3b5640

Browse files
committed
fix setting ft state dicts when ft checkpointing is disabled
Summary: - Previously, when FT dataloader checkpointing was disabled, the FT state dict functions were not set either. - Now, when FT dataloader checkpointing is disabled, the state dict functions are still set, so that the model, optimizer, etc. can be recovered from a different replica.
1 parent e43621c commit d3b5640

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

torchtitan/components/checkpoint.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -189,21 +189,30 @@ def __init__(
189189
self.enable = checkpoint_config.enable
190190
self.load_only = checkpoint_config.load_only
191191

192+
self.states = states
193+
self.states.update(
194+
{
195+
MODEL: ModelWrapper(model_parts),
196+
OPTIMIZER: optimizers,
197+
DATALOADER: dataloader,
198+
LR_SCHEDULER: lr_schedulers,
199+
}
200+
)
201+
202+
ft_inner_manager = ft_manager.manager if ft_manager and ft_manager.enabled else None
192203
self.ft_manager = (
193-
ft_manager.manager
194-
if ft_manager
195-
and ft_manager.enabled
196-
and checkpoint_config.enable_ft_dataloader_checkpoints
204+
ft_inner_manager
205+
if ft_inner_manager and checkpoint_config.enable_ft_dataloader_checkpoints
197206
else None
198207
)
199208

200-
if ft_manager and ft_manager.enabled and not self.ft_manager:
209+
if ft_inner_manager and not self.ft_manager:
201210
logger.warn(
202211
"Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
203212
"This means replicas can retrain over the same data multiple times, which can result in overfitting."
204213
)
205214

206-
if self.ft_manager:
215+
if ft_inner_manager:
207216
optimizers.init_cache_state_dict()
208217

209218
def state_dict():
@@ -223,7 +232,7 @@ def load_state_dict(state_dict):
223232
for k, v in state_dict.items():
224233
self.states[k].load_state_dict(v)
225234

226-
self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
235+
ft_inner_manager.set_state_dict_fns(load_state_dict, state_dict)
227236
self.ft_replica_id = ft_manager.replica_id
228237

229238
async_mode = checkpoint_config.async_mode.lower()
@@ -234,15 +243,6 @@ def load_state_dict(state_dict):
234243
if not self.enable and self.ft_manager is None:
235244
return
236245

237-
self.states = states
238-
self.states.update(
239-
{
240-
MODEL: ModelWrapper(model_parts),
241-
OPTIMIZER: optimizers,
242-
DATALOADER: dataloader,
243-
LR_SCHEDULER: lr_schedulers,
244-
}
245-
)
246246
self.ft_states = {DATALOADER: dataloader}
247247

248248
self.staging = False

0 commit comments

Comments
 (0)