Skip to content

Commit

Permalink
Fix PTL2.2 saving multiple *-last.ckpt checkpoints in resumed training (NVIDIA#8480)
Browse files Browse the repository at this point in the history

* Fix PTL2.2 saving multiple `*-last.ckpt` checkpoints when resuming from previous run

Signed-off-by: He Huang (Steve) <[email protected]>

* Fix missing import

Signed-off-by: He Huang (Steve) <[email protected]>

* fix broken test

Signed-off-by: stevehuang52 <[email protected]>

---------

Signed-off-by: He Huang (Steve) <[email protected]>
Signed-off-by: stevehuang52 <[email protected]>
Co-authored-by: Abhishree Thittenamane <[email protected]>
  • Loading branch information
stevehuang52 and athitten authored Mar 23, 2024
1 parent 35f9d34 commit 11b7a73
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
28 changes: 27 additions & 1 deletion nemo/utils/callbacks/nemo_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pytorch_lightning
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol
from pytorch_lightning.utilities import rank_zero_info

from nemo.collections.common.callbacks import EMA
Expand Down Expand Up @@ -454,3 +454,29 @@ def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
# delete markers
for marker_path in existing_marker_filepaths:
os.remove(marker_path)

def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
"""Checks if the previous checkpoint should be deleted.
A checkpoint won't be deleted if any of the cases apply:
- The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
- The previous checkpoint is not in the current checkpoint directory and the filesystem is local
- The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
and the resumed from checkpoint is not the last checkpoint
"""
if previous == current:
return False
if not _is_local_file_protocol(previous):
return True
previous = Path(previous).absolute()
resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None

if resume_path is not None and previous == resume_path:
if str(current).endswith("-last.ckpt") and resume_path.name.endswith("-last.ckpt"):
# delete the previous `-last.ckpt` checkpoint when current saved checkpoint is also `-last.ckpt`, if they're in the same directory
pass
else:
return False
if self.dirpath is None:
raise ValueError(f"{self.__class__}.dirpath is None.")
dirpath = Path(self.dirpath).absolute()
return dirpath in previous.parents
5 changes: 2 additions & 3 deletions tests/core/test_exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,9 +946,8 @@ def test_invalid_checkpoints_removed_from_topk(self, tmp_path):
test_trainer2.fit(model)

ckpt_filenames = {f.name for f in checkpoints_dir.rglob("*.ckpt") if f.is_file()}
# 3 top + 1 last + 1 resume ckpt since PTL >= 2.1 ensures to never delete the resume ckpt
# (https://github.com/Lightning-AI/pytorch-lightning/pull/18750)
assert len(ckpt_filenames) == 5
# 3 top + 1 last
assert len(ckpt_filenames) == 4
assert 'epoch=9-last.ckpt' in ckpt_filenames
assert 'epoch=8.ckpt' in ckpt_filenames
assert 'epoch=7.ckpt' in ckpt_filenames
Expand Down

0 comments on commit 11b7a73

Please sign in to comment.