
Commit 22239d9

allow disabling ft checkpoints
Summary: Allows disabling the storage of checkpoints related to torchft. Users then don't have to rely on any external storage, which reduces the setup time needed to get things up and running, and model checkpoints aren't strictly necessary when torchft is in use. If checkpoint storage has issues, this also works as a kill switch to completely disable the storage so it doesn't impact training.
1 parent: cf47976

File tree: 2 files changed, +29 -1 lines

torchtitan/components/checkpoint.py

Lines changed: 7 additions & 1 deletion
@@ -192,8 +192,14 @@ def __init__(
         self.enable = checkpoint_config.enable
         self.load_only = checkpoint_config.load_only

+        # Warning: If fault tolerance is enabled and enable_ft_dataloader_checkpoints is False, replicas can
+        # retrain over the same data multiple times, which can result in overfitting.
         self.ft_manager = (
-            ft_manager.manager if ft_manager and ft_manager.enabled else None
+            ft_manager.manager
+            if ft_manager
+            and ft_manager.enabled
+            and checkpoint_config.enable_ft_dataloader_checkpoints
+            else None
         )
         if self.ft_manager:
             optimizers.init_cache_state_dict()
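
To make the new gating concrete, here is a minimal, self-contained sketch of the condition above. It uses hypothetical stub classes (FTManagerStub, CheckpointConfigStub) rather than torchtitan's real CheckpointManager/FTManager types: the torchft manager is kept only when fault tolerance is enabled and enable_ft_dataloader_checkpoints is True; otherwise it is dropped and no FT state is checkpointed.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FTManagerStub:
    # Hypothetical stand-in for torchtitan's FT manager wrapper.
    enabled: bool
    manager: Any = "torchft-manager"  # stand-in for a torchft Manager instance


@dataclass
class CheckpointConfigStub:
    # Hypothetical stand-in for the Checkpoint config section.
    enable: bool = False
    enable_ft_dataloader_checkpoints: bool = True


def resolve_ft_manager(
    ft_manager: Optional[FTManagerStub], cfg: CheckpointConfigStub
) -> Optional[Any]:
    # Mirrors the condition in CheckpointManager.__init__ above: keep the
    # manager only when fault tolerance is on AND FT dataloader checkpoints
    # are allowed; otherwise act as a kill switch and return None.
    return (
        ft_manager.manager
        if ft_manager
        and ft_manager.enabled
        and cfg.enable_ft_dataloader_checkpoints
        else None
    )


# Default: the flag is True, so an enabled FT manager is kept.
assert resolve_ft_manager(FTManagerStub(enabled=True), CheckpointConfigStub()) is not None
# Kill switch: the flag is False, so the manager is dropped even though FT is enabled.
assert resolve_ft_manager(
    FTManagerStub(enabled=True),
    CheckpointConfigStub(enable_ft_dataloader_checkpoints=False),
) is None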

torchtitan/config/job_config.py

Lines changed: 22 additions & 0 deletions
@@ -422,6 +422,28 @@ class Checkpoint:
     enable: bool = False
     """Whether to enable checkpoint"""

+    enable_ft_dataloader_checkpoints: bool = True
+    """
+    Warning: Disabling this can cause fault-tolerant replicas to train
+    over the same data multiple times. Only disable it if training over
+    the same data is acceptable.
+
+    Enables checkpointing of the dataloader index for fault-tolerant training with torchft.
+
+    Fault-tolerant training stores the data loader index in the checkpoints, so that training can
+    resume without going over the same batch twice.
+
+    If enabled, data loader state is checkpointed. Otherwise, replicas
+    will train over the same data multiple times, which can result in
+    overfitting.
+
+    A failed replica will still recover other state, e.g. model
+    parameters, from the other replicas.
+
+    Note: if regular checkpointing is enabled, the data loader state is
+    checkpointed as well. But when not using fault tolerance, the entire training starts from scratch.
+    """
+
     folder: str = "checkpoint"
     """
     The folder to store the checkpoints.
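
As a usage sketch: the new field defaults to True, so existing configs keep today's behavior, and flipping it to False disables FT checkpoint storage while leaving fault tolerance itself on. The snippet below assumes the remaining Checkpoint fields all have defaults (as the ones shown here do) and that the dataclass is importable from torchtitan.config.job_config; in a real run the value would normally come from the job config file or CLI rather than being constructed by hand.

from torchtitan.config.job_config import Checkpoint

# Keep regular checkpointing on, but disable torchft checkpoint storage
# (the kill-switch case described in the docstring above).
cfg = Checkpoint(
    enable=True,
    enable_ft_dataloader_checkpoints=False,
)
assert cfg.enable_ft_dataloader_checkpoints is False
assert cfg.folder == "checkpoint"  # other fields keep their defaults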
