Commit 705204f

allow disabling ft checkpoints
Summary: Allows disabling the storage of torchft-related checkpoints. Users then don't have to rely on any external storage, which reduces the setup time needed to get training up and running, and model checkpoints aren't strictly needed when torchft is in use. If checkpoint storage has issues, this flag also works as a killswitch that completely disables storage so it doesn't impact training.
1 parent 54d2a8b commit 705204f

File tree: 2 files changed, +33 −1 lines

torchtitan/components/checkpoint.py

Lines changed: 11 additions & 1 deletion

```diff
@@ -193,8 +193,18 @@ def __init__(
         self.load_only = checkpoint_config.load_only

         self.ft_manager = (
-            ft_manager.manager if ft_manager and ft_manager.enabled else None
+            ft_manager.manager
+            if ft_manager
+            and ft_manager.enabled
+            and checkpoint_config.enable_ft_dataloader_checkpoints
+            else None
         )
+
+        if ft_manager and ft_manager.enabled and not self.ft_manager:
+            logger.warn(
+                "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. This means replicas can retrain over the same data multiple times, which can result in overfitting."
+            )
+
         if self.ft_manager:
             optimizers.init_cache_state_dict()
```
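The gating logic above can be sketched in isolation. This is a minimal, self-contained model of the condition added by the commit; `FTManager` and `resolve_ft_manager` here are hypothetical stand-ins, not torchtitan's real classes:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FTManager:
    """Hypothetical stand-in for torchtitan's fault-tolerance wrapper."""
    manager: Any = None
    enabled: bool = False


def resolve_ft_manager(
    ft_manager: Optional[FTManager],
    enable_ft_dataloader_checkpoints: bool,
) -> Any:
    """Return the torchft manager only when FT is enabled AND
    dataloader checkpointing is not disabled (the killswitch)."""
    if (
        ft_manager
        and ft_manager.enabled
        and enable_ft_dataloader_checkpoints
    ):
        return ft_manager.manager
    return None


ft = FTManager(manager="mgr", enabled=True)
# Killswitch engaged: FT stays on, but no checkpoint storage is used.
assert resolve_ft_manager(ft, enable_ft_dataloader_checkpoints=False) is None
# Normal operation: the manager is kept for checkpointing.
assert resolve_ft_manager(ft, enable_ft_dataloader_checkpoints=True) == "mgr"
```

Note that with the flag off, `self.ft_manager` is `None`, so later code paths such as `optimizers.init_cache_state_dict()` are skipped entirely.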

torchtitan/config/job_config.py

Lines changed: 22 additions & 0 deletions

```diff
@@ -422,6 +422,28 @@ class Checkpoint:
     enable: bool = False
     """Whether to enable checkpoint"""

+    enable_ft_dataloader_checkpoints: bool = True
+    """
+    Warning: disabling this can cause fault tolerant replicas to train
+    over the same data multiple times. Use it with caution, and only if
+    training over the same data is acceptable.
+
+    Used to enable checkpointing of the dataloader index for fault
+    tolerant training with torchft.
+
+    Fault tolerant training stores the data loader index in the
+    checkpoints, so that training can resume without going over the
+    same batch twice.
+
+    If enabled, data loader state is checkpointed. Otherwise, replicas
+    will train over the same data multiple times, which can result in
+    overfitting.
+
+    The failed replica will still recover other state, e.g. model
+    parameters, from other replicas.
+
+    Note: if regular checkpointing is enabled, the data loader state is
+    also checkpointed. But when not using fault tolerance, the entire
+    training starts from scratch.
+    """
+
     folder: str = "checkpoint"
     """
     The folder to store the checkpoints.
```
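Since the `Checkpoint` dataclass backs the job config, the new flag should be settable from a TOML job file. A sketch, assuming torchtitan's usual `[checkpoint]` config section (the section name here is an assumption inferred from the dataclass, not taken from this commit):

```toml
[checkpoint]
enable = true
# Hypothetical example: engage the killswitch so torchft does not
# store dataloader checkpoints and no external storage is required.
# Replicas may then retrain over the same data (risk of overfitting).
enable_ft_dataloader_checkpoints = false
```

Because the field defaults to `true`, existing configs that omit it keep the previous checkpointing behavior unchanged.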

0 commit comments