From ac73938051d70c4ca9ac8f8b2ad3195d93fe5f7a Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 5 Nov 2024 18:19:29 +0000 Subject: [PATCH 01/16] Fix checkpoint progress for fit loop and batch loop --- .../pytorch/loops/evaluation_loop.py | 9 ++ src/lightning/pytorch/loops/fit_loop.py | 42 ++++++++-- .../pytorch/loops/training_epoch_loop.py | 20 +++++ tests/tests_pytorch/loops/test_loops.py | 84 ++++++++++++++++++- 4 files changed, 146 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 0ab3901cf072d..a94791c93a919 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -197,6 +197,15 @@ def setup_data(self) -> None: # this depends on the data used, so reset it too self._seen_batches_per_dataloader = defaultdict(int) + @property + def restarting_on_evaluation_end(self) -> bool: + return ( + self.restarting + and self.batch.progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started + and self.batch_progress.total.completed == self.batch_progress.total.processed - 1 + ) + def reset(self) -> None: """Resets the internal state of the loop.""" trainer = self.trainer diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index eb30e32757c9a..e1b9d42f6e09e 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -302,12 +302,39 @@ def setup_data(self) -> None: category=PossibleUserWarning, ) + @property + def restarting_on_epoch_start(self) -> bool: + return ( + self.restarting + and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1 + and self.epoch_progress.total.processed == self.epoch_progress.total.started + and self.epoch_progress.total.completed == self.epoch_progress.total.processed + ) + + @property + def restarting_mid_epoch(self) -> bool: + return ( + self.restarting + and self.epoch_progress.total.started == self.epoch_progress.total.ready + and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1 + and self.epoch_progress.total.completed == self.epoch_progress.total.processed + ) + + @property + def restarting_on_epoch_end(self) -> bool: + return ( + self.restarting + and self.epoch_progress.total.started == self.epoch_progress.total.ready + and self.epoch_progress.total.processed == self.epoch_progress.total.started + and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 + ) + def reset(self) -> None: """Resets the internal state of this loop.""" assert self.trainer.model is not None torch.set_grad_enabled(True) - if self.restarting: + if self.restarting_on_epoch_start: self.epoch_progress.reset_on_restart() def on_run_start(self) -> None: @@ -340,12 +367,14 @@ def on_advance_start(self) -> None: for i, dl in enumerate(self._combined_loader.flattened): _set_sampler_epoch(dl, self.epoch_progress.current.processed) - self.epoch_progress.increment_ready() + if not self.restarting_mid_epoch and not self.restarting_on_epoch_end: + if not self.restarting_on_epoch_start: + self.epoch_progress.increment_ready() - call._call_callback_hooks(trainer, "on_train_epoch_start") - call._call_lightning_module_hook(trainer, "on_train_epoch_start") + call._call_callback_hooks(trainer, "on_train_epoch_start") + call._call_lightning_module_hook(trainer, "on_train_epoch_start") - 
self.epoch_progress.increment_started() + self.epoch_progress.increment_started() def advance(self) -> None: """Runs one whole epoch.""" @@ -379,8 +408,7 @@ def on_advance_end(self) -> None: trainer._logger_connector.on_epoch_end() - if self.epoch_loop._num_ready_batches_reached(): - # if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint. + if not self.restarting and self.epoch_loop._num_ready_batches_reached(): # since metric-based schedulers require access to metrics and those are not currently saved in the # checkpoint, the plateau schedulers shouldn't be updated self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 9e36ee65176c8..70699c6b208d2 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -144,8 +144,28 @@ def run(self, data_fetcher: _DataFetcher) -> None: break self._restarting = False + @property + def restarting_on_train_batch_end(self) -> bool: + return ( + self.restarting + and self.batch_progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started + and self.batch_progress.total.completed == self.batch_progress.total.processed - 1 + ) + def reset(self) -> None: """Resets the internal state of the loop for a new run.""" + if self.restarting_on_train_batch_end: + self.batch_progress.increment_completed() + # handle situation in which save happened on_train_batch_end and epoch is at end + if self.batch_progress.current.completed >= self.trainer.num_training_batches: + self.batch_progress.reset_on_run() + self.scheduler_progress.reset_on_run() + self.automatic_optimization.optim_progress.reset_on_run() + self.val_loop.batch_progress.total.reset() + if not self._should_accumulate(): + self._batches_that_stepped += 1 + if self.restarting: self.batch_progress.reset_on_restart() self.scheduler_progress.reset_on_restart() diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index ff317cd2e18ba..1e945b19f81b0 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -564,7 +564,7 @@ def test_fit_loop_reset(tmp_path): assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch - assert fit_loop.epoch_progress.current.ready == 0 + assert fit_loop.epoch_progress.current.ready == 1 assert fit_loop.epoch_progress.current.completed == 0 assert epoch_loop.restarting @@ -594,7 +594,7 @@ def test_fit_loop_reset(tmp_path): assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes - assert fit_loop.epoch_progress.current.ready == 0 + assert fit_loop.epoch_progress.current.ready == 1 assert fit_loop.epoch_progress.current.completed == 0 assert epoch_loop.restarting @@ -606,6 +606,86 @@ def test_fit_loop_reset(tmp_path): assert epoch_loop.batch_progress.current.completed == 3 +def compare_state_dicts(dict1, dict2): + def compare_leaves(d1, d2): + result = {} + all_keys = set(d1.keys()).union(d2.keys()) + + for key in all_keys: + val1 = d1.get(key, None) + val2 = d2.get(key, None) + + if isinstance(val1, dict) and isinstance(val2, dict): 
+ res = compare_leaves(val1, val2) + if res: + result[key] = res + elif isinstance(val1, dict) or isinstance(val2, dict): + raise ValueError("dicts have different leaves") + elif type(val1) == float and type(val2) == float: + if abs(val1 - val2) > 1e-8: + result[key] = f"{val1} != {val2}" + elif val1 != val2: + result[key] = f"{val1} != {val2}" + return result + return compare_leaves(dict1, dict2) + + +def test_restart_at_batch_end(tmp_path): + """ + TODO + """ + + model = BoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + every_n_train_steps=2, + save_top_k=-1, + ) + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + ) + trainer.fit(model) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt")) + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + + mid_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=6.ckpt"), weights_only=True) + mid_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=6-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(mid_epoch_ckpt["loops"], mid_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(mid_epoch_ckpt["lr_schedulers"][0], mid_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert mid_epoch_ckpt["epoch"] == mid_epoch_ckpt_v1["epoch"] + assert mid_epoch_ckpt["global_step"] == mid_epoch_ckpt_v1["global_step"] + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=8.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=8-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + + @pytest.mark.parametrize( ("train_datasets", "val_datasets"), [([RandomDataset], [RandomDataset]), ([RandomDataset], [RandomDataset, RandomDataset])], From 2d5576dbf05de3ff0c6c141c77c83db5471be669 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 5 Nov 2024 18:55:28 +0000 Subject: [PATCH 02/16] Check loss parity --- tests/tests_pytorch/loops/test_loops.py | 52 ++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 1e945b19f81b0..0e1de9aca155a 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -14,7 +14,7 @@ import os from copy import deepcopy from dataclasses import dataclass -from typing import Dict, Iterator +from typing import Dict, Iterator, Any from unittest.mock import ANY, Mock import 
pytest @@ -25,6 +25,7 @@ from lightning.pytorch.loops import _Loop from lightning.pytorch.loops.progress import _BaseProgress from lightning.pytorch.utilities import CombinedLoader +from lightning.pytorch.utilities.types import STEP_OUTPUT from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter from tests_pytorch.helpers.runif import RunIf @@ -621,6 +622,10 @@ def compare_leaves(d1, d2): result[key] = res elif isinstance(val1, dict) or isinstance(val2, dict): raise ValueError("dicts have different leaves") + elif type(val1) == torch.Tensor and type(val2) == torch.Tensor: + diff = torch.norm(val1 - val2) + if diff > 1e-8: + result[key] = f"{diff} > 1e-8" elif type(val1) == float and type(val2) == float: if abs(val1 - val2) > 1e-8: result[key] = f"{val1} != {val2}" @@ -630,12 +635,48 @@ def compare_leaves(d1, d2): return compare_leaves(dict1, dict2) +class RangeDataset(torch.utils.data.Dataset): + def __init__(self, size: int, length: int): + self.len = length + data = torch.arange(0, size) / size + self.data = data.unsqueeze(0).repeat(length, 1) + + def __getitem__(self, index: int) -> torch.Tensor: + return self.data[index] + + def __len__(self) -> int: + return self.len + + +class PredictableBoringModel(BoringModel): + def __init__(self) -> None: + super().__init__() + self.last_loss = float("inf") + + def train_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def val_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def test_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def predict_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: + loss = self.step(batch) + self.last_loss = loss + return {"loss": loss} + + def test_restart_at_batch_end(tmp_path): """ TODO """ - model = BoringModel() + model = PredictableBoringModel() checkpoint_callback = ModelCheckpoint( dirpath=tmp_path, every_n_train_steps=2, @@ -650,6 +691,7 @@ def test_restart_at_batch_end(tmp_path): enable_model_summary=False, ) trainer.fit(model) + loss = model.last_loss trainer = Trainer( default_root_dir=tmp_path, @@ -660,6 +702,9 @@ def test_restart_at_batch_end(tmp_path): enable_model_summary=False, ) trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt")) + loss_v1 = model.last_loss + + assert(abs(loss - loss_v1) < 1e-8) end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True) end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True) @@ -668,6 +713,7 @@ def test_restart_at_batch_end(tmp_path): assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} mid_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=6.ckpt"), weights_only=True) mid_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=6-v1.ckpt"), weights_only=True) @@ -676,6 +722,7 @@ def test_restart_at_batch_end(tmp_path): assert compare_state_dicts(mid_epoch_ckpt["lr_schedulers"][0], mid_epoch_ckpt_v1["lr_schedulers"][0]) == {} assert mid_epoch_ckpt["epoch"] == mid_epoch_ckpt_v1["epoch"] assert mid_epoch_ckpt["global_step"] == mid_epoch_ckpt_v1["global_step"] + assert 
compare_state_dicts(mid_epoch_ckpt["state_dict"], mid_epoch_ckpt_v1["state_dict"]) == {} end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=8.ckpt"), weights_only=True) end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=8-v1.ckpt"), weights_only=True) @@ -684,6 +731,7 @@ def test_restart_at_batch_end(tmp_path): assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} @pytest.mark.parametrize( From 0750b2ea72275a8fdee53a53fbf7a2f5706ae943 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 5 Nov 2024 18:57:53 +0000 Subject: [PATCH 03/16] Rename test --- tests/tests_pytorch/loops/test_loops.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 0e1de9aca155a..9cf571bfa136f 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -671,11 +671,7 @@ def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: return {"loss": loss} -def test_restart_at_batch_end(tmp_path): - """ - TODO - """ - +def test_restart_parity(tmp_path): model = PredictableBoringModel() checkpoint_callback = ModelCheckpoint( dirpath=tmp_path, From cbb9fb534e4f5e7340e05546698f5779d0665158 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 6 Nov 2024 12:03:21 +0000 Subject: [PATCH 04/16] Fix validation loop handling on restart --- .../pytorch/loops/evaluation_loop.py | 10 ++- src/lightning/pytorch/loops/progress.py | 25 +++++++ .../pytorch/loops/training_epoch_loop.py | 5 +- tests/tests_pytorch/loops/test_loops.py | 65 +++++++++++++++++++ 4 files changed, 102 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index a94791c93a919..7eaf7cf50b8a0 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -201,7 +201,7 @@ def setup_data(self) -> None: def restarting_on_evaluation_end(self) -> bool: return ( self.restarting - and self.batch.progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.started == self.batch_progress.total.ready and self.batch_progress.total.processed == self.batch_progress.total.started and self.batch_progress.total.completed == self.batch_progress.total.processed - 1 ) @@ -245,6 +245,14 @@ def reset(self) -> None: data_fetcher._stop_profiler = self._on_after_fetch self._data_fetcher = data_fetcher + def increment_progress_to_evaluation_end(self) -> None: + self.setup_data() + if self.skip: + return + self.reset() + max_batch = max(self.max_batches) + self.batch_progress.increment_by(max_batch, True) + def on_run_start(self) -> None: """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start`` hooks.""" diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py index 3d34653122329..807d150a26fb2 100644 --- a/src/lightning/pytorch/loops/progress.py +++ b/src/lightning/pytorch/loops/progress.py @@ -59,6 +59,7 @@ def reset(self) -> None: self.ready = 0 self.completed = 0 + @override def reset_on_restart(self) -> None: """Reset the progress on restart. 
@@ -68,6 +69,11 @@ def reset_on_restart(self) -> None: """ self.ready = self.completed + @override + def increment_by(self, n) -> None: + self.ready += n + self.completed += n + @dataclass class _StartedTracker(_ReadyCompletedTracker): @@ -94,6 +100,11 @@ def reset_on_restart(self) -> None: super().reset_on_restart() self.started = self.completed + @override + def increment_by(self, n) -> None: + super().increment_by(n) + self.started += n + @dataclass class _ProcessedTracker(_StartedTracker): @@ -121,6 +132,11 @@ def reset_on_restart(self) -> None: super().reset_on_restart() self.processed = self.completed + @override + def increment_by(self, n) -> None: + super().increment_by(n) + self.processed += n + @dataclass class _Progress(_BaseProgress): @@ -175,6 +191,11 @@ def reset_on_run(self) -> None: def reset_on_restart(self) -> None: self.current.reset_on_restart() + @override + def increment_by(self, n) -> None: + self.total.increment_by(n) + self.current.increment_by(n) + @override def load_state_dict(self, state_dict: dict) -> None: self.total.load_state_dict(state_dict["total"]) @@ -206,6 +227,10 @@ def reset_on_run(self) -> None: super().reset_on_run() self.is_last_batch = False + def increment_by(self, n, is_last_batch=False) -> None: + super().increment_by(n) + self.is_last_batch = is_last_batch + @override def load_state_dict(self, state_dict: dict) -> None: super().load_state_dict(state_dict) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 70699c6b208d2..5fb98a82b9950 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -217,8 +217,9 @@ def advance(self, data_fetcher: _DataFetcher) -> None: """ if self.restarting and self._should_check_val_fx(data_fetcher): - # skip training and run validation in `on_advance_end` - return + # fast forward progress counters to end of validation + self.val_loop.increment_progress_to_evaluation_end() + # we are going to train first so the val loop does not need to restart self.val_loop.restarting = False diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 9cf571bfa136f..abb0343bbf3b1 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -730,6 +730,71 @@ def test_restart_parity(tmp_path): assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} +def test_restart_parity_with_val(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + every_n_train_steps=2, + save_top_k=-1, + ) + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=4, + val_check_interval=2, + ) + trainer.fit(model) + loss = model.last_loss + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=4, + val_check_interval=2, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt")) + loss_v1 = model.last_loss + + assert(abs(loss - loss_v1) < 1e-8) + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / 
"epoch=0-step=4-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} + + mid_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=6.ckpt"), weights_only=True) + mid_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=6-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(mid_epoch_ckpt["loops"], mid_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(mid_epoch_ckpt["lr_schedulers"][0], mid_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert mid_epoch_ckpt["epoch"] == mid_epoch_ckpt_v1["epoch"] + assert mid_epoch_ckpt["global_step"] == mid_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(mid_epoch_ckpt["state_dict"], mid_epoch_ckpt_v1["state_dict"]) == {} + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=8.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=8-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} + + @pytest.mark.parametrize( ("train_datasets", "val_datasets"), [([RandomDataset], [RandomDataset]), ([RandomDataset], [RandomDataset, RandomDataset])], From 7d0b5a147bc81071bdbdc7e0d5626a37527258f8 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 6 Nov 2024 12:18:16 +0000 Subject: [PATCH 05/16] Fix loop reset test --- tests/tests_pytorch/loops/test_loops.py | 39 ++++++++++++++++++------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index abb0343bbf3b1..d5faf07425215 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -558,9 +558,6 @@ def test_fit_loop_reset(tmp_path): # we load exactly what was saved - no reset yet fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"]) - # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 - fit_loop.reset() - epoch_loop.reset() assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 @@ -568,14 +565,32 @@ def test_fit_loop_reset(tmp_path): assert fit_loop.epoch_progress.current.ready == 1 assert fit_loop.epoch_progress.current.completed == 0 - assert epoch_loop.restarting assert epoch_loop.batch_progress.total.ready == 2 assert epoch_loop.batch_progress.total.processed == 2 assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 1 # currents get set to the completed value - assert epoch_loop.batch_progress.current.processed == 1 + assert epoch_loop.batch_progress.current.ready == 2 # currents get set to the completed value + assert epoch_loop.batch_progress.current.processed == 2 assert 
epoch_loop.batch_progress.current.completed == 1 + fit_loop.reset() + epoch_loop.reset() + + # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 + assert fit_loop.restarting + assert fit_loop.epoch_progress.total.ready == 1 + assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch + assert fit_loop.epoch_progress.current.ready == 1 + assert fit_loop.epoch_progress.current.completed == 0 + + # however it should increment completed batch progress, since it was saved immediately prior + assert epoch_loop.restarting + assert epoch_loop.batch_progress.total.ready == 2 + assert epoch_loop.batch_progress.total.processed == 2 + assert epoch_loop.batch_progress.total.completed == 2 + assert epoch_loop.batch_progress.current.ready == 2 + assert epoch_loop.batch_progress.current.processed == 2 + assert epoch_loop.batch_progress.current.completed == 2 + assert optimizer_loop.restarting # reset state loaded from a checkpoint from the end of an epoch @@ -592,19 +607,21 @@ def test_fit_loop_reset(tmp_path): fit_loop.reset() epoch_loop.reset() + # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 - assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes + assert fit_loop.epoch_progress.total.completed == 0 assert fit_loop.epoch_progress.current.ready == 1 assert fit_loop.epoch_progress.current.completed == 0 + # however it should increment completed batch progress, since it was saved immediately prior assert epoch_loop.restarting assert epoch_loop.batch_progress.total.ready == 4 assert epoch_loop.batch_progress.total.processed == 4 - assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 3 # currents get set to the completed value - assert epoch_loop.batch_progress.current.processed == 3 - assert epoch_loop.batch_progress.current.completed == 3 + assert epoch_loop.batch_progress.total.completed == 4 + assert epoch_loop.batch_progress.current.ready == 0 + assert epoch_loop.batch_progress.current.processed == 0 + assert epoch_loop.batch_progress.current.completed == 0 def compare_state_dicts(dict1, dict2): From e8bd2d79a02b117474ad243d3dcabe38b207ae1f Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 6 Nov 2024 13:36:25 +0000 Subject: [PATCH 06/16] Avoid skipping to val end if saved mid validation --- src/lightning/pytorch/loops/evaluation_loop.py | 9 +++++++++ src/lightning/pytorch/loops/training_epoch_loop.py | 2 ++ tests/tests_pytorch/loops/test_loops.py | 7 ++++--- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 7eaf7cf50b8a0..329ddaabd5fb9 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -197,6 +197,15 @@ def setup_data(self) -> None: # this depends on the data used, so reset it too self._seen_batches_per_dataloader = defaultdict(int) + @property + def restarting_mid_evaluation(self) -> bool: + return ( + self.restarting + and self.batch_progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started - 1 + and self.batch_progress.total.completed == self.batch_progress.total.processed + ) + @property def 
restarting_on_evaluation_end(self) -> bool: return ( diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 5fb98a82b9950..440643c24cb7b 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -217,6 +217,8 @@ def advance(self, data_fetcher: _DataFetcher) -> None: """ if self.restarting and self._should_check_val_fx(data_fetcher): + if self.val_loop.restarting_mid_evaluation: + return # fast forward progress counters to end of validation self.val_loop.increment_progress_to_evaluation_end() diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index d5faf07425215..7c0f8694d04ac 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -397,12 +397,13 @@ def training_step(self, batch, batch_idx): assert state_dict == checkpoint["loops"]["fit_loop"] trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) - # test resetting manually, we expect all `ready` counters to be reset to `completed` + # test resetting manually, we expect the `ready` counter for batch to be reset to `completed` + # but the `ready` counter for epoch to not be reset, since we are still mid epoch trainer.fit_loop.reset() trainer.fit_loop.epoch_loop.reset() epoch_progress = trainer.fit_loop.epoch_progress - assert epoch_progress.current.ready == stop_epoch + assert epoch_progress.current.ready == stop_epoch + 1 assert epoch_progress.current.completed == stop_epoch batch_progress = trainer.fit_loop.epoch_loop.batch_progress @@ -418,7 +419,7 @@ def training_step(self, batch, batch_idx): state_dict = trainer.fit_loop.state_dict() assert state_dict != checkpoint["loops"]["fit_loop"] assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1 - assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch + assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch + 1 def test_loop_state_on_complete_run(tmp_path): From 0012dcb2bfb58dc6c499d4e7b8094f09d0512198 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 6 Nov 2024 13:39:30 +0000 Subject: [PATCH 07/16] Fix type checks in compare state dicts --- tests/tests_pytorch/loops/test_loops.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 7c0f8694d04ac..95fd75d7c6674 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -14,7 +14,7 @@ import os from copy import deepcopy from dataclasses import dataclass -from typing import Dict, Iterator, Any +from typing import Any, Dict, Iterator from unittest.mock import ANY, Mock import pytest @@ -575,7 +575,7 @@ def test_fit_loop_reset(tmp_path): fit_loop.reset() epoch_loop.reset() - + # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 @@ -629,27 +629,28 @@ def compare_state_dicts(dict1, dict2): def compare_leaves(d1, d2): result = {} all_keys = set(d1.keys()).union(d2.keys()) - + for key in all_keys: val1 = d1.get(key, None) val2 = d2.get(key, None) - + if isinstance(val1, dict) and isinstance(val2, dict): res = compare_leaves(val1, val2) if res: result[key] = res elif isinstance(val1, dict) or isinstance(val2, dict): raise ValueError("dicts have different leaves") - elif type(val1) == torch.Tensor 
and type(val2) == torch.Tensor: + elif isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor): diff = torch.norm(val1 - val2) if diff > 1e-8: result[key] = f"{diff} > 1e-8" - elif type(val1) == float and type(val2) == float: + elif isinstance(val1, float) and isinstance(val2, float): if abs(val1 - val2) > 1e-8: result[key] = f"{val1} != {val2}" elif val1 != val2: result[key] = f"{val1} != {val2}" return result + return compare_leaves(dict1, dict2) @@ -718,7 +719,7 @@ def test_restart_parity(tmp_path): trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt")) loss_v1 = model.last_loss - assert(abs(loss - loss_v1) < 1e-8) + assert abs(loss - loss_v1) < 1e-8 end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True) end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True) @@ -783,7 +784,7 @@ def test_restart_parity_with_val(tmp_path): trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt")) loss_v1 = model.last_loss - assert(abs(loss - loss_v1) < 1e-8) + assert abs(loss - loss_v1) < 1e-8 end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True) end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True) From e64c200a0e30b8dd96ff554f29cd59e207d6425b Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Thu, 7 Nov 2024 18:44:10 +0000 Subject: [PATCH 08/16] Fix edge cases and start from last with and without val --- .../pytorch/loops/evaluation_loop.py | 9 -- src/lightning/pytorch/loops/fit_loop.py | 29 +++++ .../pytorch/loops/training_epoch_loop.py | 46 +++++-- tests/tests_pytorch/loops/test_loops.py | 118 +++++++++++++++++- 4 files changed, 181 insertions(+), 21 deletions(-) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 329ddaabd5fb9..0b800b9046fef 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -206,15 +206,6 @@ def restarting_mid_evaluation(self) -> bool: and self.batch_progress.total.completed == self.batch_progress.total.processed ) - @property - def restarting_on_evaluation_end(self) -> bool: - return ( - self.restarting - and self.batch_progress.total.started == self.batch_progress.total.ready - and self.batch_progress.total.processed == self.batch_progress.total.started - and self.batch_progress.total.completed == self.batch_progress.total.processed - 1 - ) - def reset(self) -> None: """Resets the internal state of the loop.""" trainer = self.trainer diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index e1b9d42f6e09e..908677a7d1fcd 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -329,14 +329,43 @@ def restarting_on_epoch_end(self) -> bool: and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 ) + @property + def progress_at_epoch_end(self) -> bool: + # TODO LUCA comment for restart last without val + return ( + self.epoch_progress.total.started == self.epoch_progress.total.ready + and self.epoch_progress.total.processed == self.epoch_progress.total.started + and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 + ) + def reset(self) -> None: """Resets the internal state of this loop.""" assert self.trainer.model is not None torch.set_grad_enabled(True) + self.epoch_loop.reset_restarting_states() + if self.restarting_on_epoch_start: 
self.epoch_progress.reset_on_restart() + if self.progress_at_epoch_end: + self.epoch_progress.increment_completed() + + # TODO LUCA: refactor restarting for fit_loop + restarting_mid_epoch = self.restarting_mid_epoch + + if (self.epoch_loop.restarting_on_train_batch_end + and self.restarting_mid_epoch + and self.epoch_loop.batch_progress.is_last_batch): + self.epoch_progress.increment_processed() + self.epoch_progress.increment_completed() + + if (self.epoch_loop.restarting_on_train_batch_end + and self.epoch_loop.batch_progress.is_last_batch + and not restarting_mid_epoch + and not self.epoch_loop.val_loop.batch_progress.is_last_batch): + self.epoch_progress.increment_completed() + def on_run_start(self) -> None: """Calls the ``on_train_start`` hook.""" # update the current_epoch in-case of checkpoint reload diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 440643c24cb7b..9e14b12e04139 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -81,6 +81,8 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s self._results = _ResultCollection(training=True) self._warning_cache = WarningCache() self._batches_that_stepped: int = 0 + self._restarting_on_train_batch_end: bool = None + self._restarting_on_last: bool = None @property def total_batch_idx(self) -> int: @@ -146,15 +148,43 @@ def run(self, data_fetcher: _DataFetcher) -> None: @property def restarting_on_train_batch_end(self) -> bool: - return ( - self.restarting - and self.batch_progress.total.started == self.batch_progress.total.ready - and self.batch_progress.total.processed == self.batch_progress.total.started - and self.batch_progress.total.completed == self.batch_progress.total.processed - 1 - ) + if self._restarting_on_train_batch_end is None: + self._restarting_on_train_batch_end = ( + self.restarting + and self.batch_progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started + and self.batch_progress.total.completed == self.batch_progress.total.processed - 1 + ) + return self._restarting_on_train_batch_end + + @property + def restarting_on_last(self) -> bool: + if self._restarting_on_last is None: + self._restarting_on_last = ( + self.restarting + and self.batch_progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started + and self.batch_progress.total.completed == self.batch_progress.total.processed + ) + return self._restarting_on_last + + def reset_restarting_states(self) -> None: + self._restarting_on_train_batch_end = None + self._restarting_on_last = None + self.restarting_on_train_batch_end + self.restarting_on_last def reset(self) -> None: + self.reset_restarting_states() """Resets the internal state of the loop for a new run.""" + if self.restarting and not self._should_accumulate(): + # batches_that_stepped is never set prior to saving a checkpoint, even when saving + # happens on_validation_end + # we could set it in the checkpoint but we prefer to keep checkpoints backward compatible + if self.restarting_on_train_batch_end or not self.restarting_on_last: + # if not self.restarting_on_train_batch_end and not self.restarting_on_last: + self._batches_that_stepped += 1 + if self.restarting_on_train_batch_end: self.batch_progress.increment_completed() # handle situation in which save happened 
on_train_batch_end and epoch is at end @@ -163,8 +193,6 @@ def reset(self) -> None: self.scheduler_progress.reset_on_run() self.automatic_optimization.optim_progress.reset_on_run() self.val_loop.batch_progress.total.reset() - if not self._should_accumulate(): - self._batches_that_stepped += 1 if self.restarting: self.batch_progress.reset_on_restart() @@ -217,7 +245,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None: """ if self.restarting and self._should_check_val_fx(data_fetcher): - if self.val_loop.restarting_mid_evaluation: + if self.val_loop.restarting_mid_evaluation or self.restarting_on_last: return # fast forward progress counters to end of validation self.val_loop.increment_progress_to_evaluation_end() diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 95fd75d7c6674..8d94275b5b245 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -604,16 +604,24 @@ def test_fit_loop_reset(tmp_path): # we load exactly what was saved - no reset yet fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"]) + + assert fit_loop.restarting + assert fit_loop.epoch_progress.total.ready == 1 + assert fit_loop.epoch_progress.total.completed == 0 + assert fit_loop.epoch_progress.current.ready == 1 + assert fit_loop.epoch_progress.current.completed == 0 + # resetting from a end-of-epoch checkpoint SHOULD reset the current counters to 0 fit_loop.reset() epoch_loop.reset() # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 + # since we are restarting at the end of epoch, we need to see `completed` being updated after reset assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 - assert fit_loop.epoch_progress.total.completed == 0 + assert fit_loop.epoch_progress.total.completed == 1 assert fit_loop.epoch_progress.current.ready == 1 - assert fit_loop.epoch_progress.current.completed == 0 + assert fit_loop.epoch_progress.current.completed == 1 # however it should increment completed batch progress, since it was saved immediately prior assert epoch_loop.restarting @@ -704,6 +712,7 @@ def test_restart_parity(tmp_path): callbacks=[checkpoint_callback], logger=False, enable_model_summary=False, + enable_progress_bar=False, ) trainer.fit(model) loss = model.last_loss @@ -715,6 +724,7 @@ def test_restart_parity(tmp_path): callbacks=[checkpoint_callback], logger=False, enable_model_summary=False, + enable_progress_bar=False, ) trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt")) loss_v1 = model.last_loss @@ -749,7 +759,7 @@ def test_restart_parity(tmp_path): assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} -def test_restart_parity_with_val(tmp_path): +def test_restart_with_val_parity(tmp_path): model = PredictableBoringModel() checkpoint_callback = ModelCheckpoint( dirpath=tmp_path, @@ -814,6 +824,108 @@ def test_restart_parity_with_val(tmp_path): assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} +def test_restart_from_last_parity(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + save_last=True, + save_top_k=-1, + ) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model) + + last_ckpt_1 = 
torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=2, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "last.ckpt")) + + last_ckpt_2 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + assert compare_state_dicts(last_ckpt_1["loops"], last_ckpt_2["loops"]) == {} + + +def test_restart_from_last_with_val_parity(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + save_last=True, + save_top_k=-1, + ) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=2, + val_check_interval=2, + ) + trainer.fit(model) + + last_ckpt_1 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=2, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=2, + val_check_interval=2, + ) + trainer.fit(model) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=2, + val_check_interval=2, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "last.ckpt")) + + last_ckpt_2 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + assert compare_state_dicts(last_ckpt_1["loops"], last_ckpt_2["loops"]) == {} + + @pytest.mark.parametrize( ("train_datasets", "val_datasets"), [([RandomDataset], [RandomDataset]), ([RandomDataset], [RandomDataset, RandomDataset])], From 59e41de707cd76c3450a677bc889c166f997ec6c Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Fri, 8 Nov 2024 10:38:25 +0000 Subject: [PATCH 09/16] Clean up --- .../pytorch/loops/evaluation_loop.py | 25 ++++- src/lightning/pytorch/loops/fit_loop.py | 104 ++++++++++++------ src/lightning/pytorch/loops/loop.py | 10 ++ .../pytorch/loops/training_epoch_loop.py | 80 ++++++++------ 4 files changed, 146 insertions(+), 73 deletions(-) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 0b800b9046fef..277deddba2a50 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -15,6 +15,7 @@ import shutil import sys from collections import ChainMap, OrderedDict, defaultdict +from dataclasses import dataclass from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union from lightning_utilities.core.apply_func import apply_to_collection @@ -45,6 +46,12 @@ from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature +@dataclass +class RestartStage: + NONE = "none" + RESTARTED_MID_EVALUATION = "restarted_mid_evaluation" + + class _EvaluationLoop(_Loop): """Top-level loop where validation/testing starts.""" @@ -73,6 +80,7 @@ def __init__( self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int) 
self._last_val_dl_reload_epoch = float("-inf") self._module_mode = _ModuleMode() + self._restart_stage = RestartStage.NONE @property def num_dataloaders(self) -> int: @@ -137,7 +145,7 @@ def run(self) -> List[_OUT_DICT]: # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support break finally: - self._restarting = False + self.on_iteration_done() self._store_dataloader_outputs() return self.on_run_end() @@ -198,14 +206,23 @@ def setup_data(self) -> None: self._seen_batches_per_dataloader = defaultdict(int) @property - def restarting_mid_evaluation(self) -> bool: - return ( + def restarted_mid_evaluation(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_MID_EVALUATION + + def update_restart_stage(self) -> None: + if ( self.restarting and self.batch_progress.total.started == self.batch_progress.total.ready and self.batch_progress.total.processed == self.batch_progress.total.started - 1 and self.batch_progress.total.completed == self.batch_progress.total.processed - ) + ): + self._restart_stage = RestartStage.RESTARTED_MID_EVALUATION + else: + self._restart_stage = RestartStage.NONE + def reset_restart_stage(self) -> None: + self._restart_stage = RestartStage.NONE + def reset(self) -> None: """Resets the internal state of the loop.""" trainer = self.trainer diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 908677a7d1fcd..e20088acd0af3 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union import torch @@ -45,6 +46,15 @@ log = logging.getLogger(__name__) +@dataclass +class RestartStage: + NONE = "none" + RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start" + RESTARTED_MID_EPOCH = "restarted_mid_epoch" + RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end" + RESUMED_ON_EPOCH_END = "resumed_on_epoch_end" + + class _FitLoop(_Loop): """This loop is the top-level loop where training starts. @@ -97,6 +107,7 @@ def __init__( self._combined_loader_states_to_load: List[Dict[str, Any]] = [] self._data_fetcher: Optional[_DataFetcher] = None self._last_train_dl_reload_epoch = float("-inf") + self._restart_stage = RestartStage.NONE @property def total_batch_idx(self) -> int: @@ -204,9 +215,10 @@ def run(self) -> None: self.on_advance_start() self.advance() self.on_advance_end() - self._restarting = False except StopIteration: break + finally: + self.on_iteration_done() self._restarting = False self.on_run_end() @@ -303,67 +315,89 @@ def setup_data(self) -> None: ) @property - def restarting_on_epoch_start(self) -> bool: - return ( + def restarted_on_epoch_start(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_START + + @property + def restarted_mid_epoch(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_MID_EPOCH + + @property + def restarted_on_epoch_end(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_END + + @property + def resumed_on_epoch_end(self) -> bool: + # This case happens when restarting from last without validation at + # the end of epoch. In this case self.restarting is False. 
+ return self._restart_stage == RestartStage.RESUMED_ON_EPOCH_END + + def update_restart_stage(self) -> None: + if ( self.restarting and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1 and self.epoch_progress.total.processed == self.epoch_progress.total.started and self.epoch_progress.total.completed == self.epoch_progress.total.processed - ) - - @property - def restarting_mid_epoch(self) -> bool: - return ( + ): + self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_START + elif ( self.restarting and self.epoch_progress.total.started == self.epoch_progress.total.ready and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1 and self.epoch_progress.total.completed == self.epoch_progress.total.processed - ) - - @property - def restarting_on_epoch_end(self) -> bool: - return ( + ): + self._restart_stage = RestartStage.RESTARTED_MID_EPOCH + elif ( self.restarting and self.epoch_progress.total.started == self.epoch_progress.total.ready and self.epoch_progress.total.processed == self.epoch_progress.total.started and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 - ) - - @property - def progress_at_epoch_end(self) -> bool: - # TODO LUCA comment for restart last without val - return ( - self.epoch_progress.total.started == self.epoch_progress.total.ready + ): + self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_END + elif ( + self._loaded_from_state_dict + and self.epoch_progress.total.started == self.epoch_progress.total.ready and self.epoch_progress.total.processed == self.epoch_progress.total.started and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 - ) + ): + self._restart_stage = RestartStage.RESUMED_ON_EPOCH_END + else: + self._restart_stage = RestartStage.NONE + + self.epoch_loop.update_restart_stage() + + def reset_restart_stage(self) -> None: + self._restart_stage = RestartStage.NONE def reset(self) -> None: """Resets the internal state of this loop.""" assert self.trainer.model is not None torch.set_grad_enabled(True) - self.epoch_loop.reset_restarting_states() + self.update_restart_stage() - if self.restarting_on_epoch_start: + if self.restarted_on_epoch_start: self.epoch_progress.reset_on_restart() - if self.progress_at_epoch_end: + if self.resumed_on_epoch_end: + # when restarting from last without validation at end of epoch, + # self.restarting is False but it's still resuming self.epoch_progress.increment_completed() - # TODO LUCA: refactor restarting for fit_loop - restarting_mid_epoch = self.restarting_mid_epoch - - if (self.epoch_loop.restarting_on_train_batch_end - and self.restarting_mid_epoch - and self.epoch_loop.batch_progress.is_last_batch): + if ( + self.epoch_loop.restarted_on_train_batch_end + and self.restarted_mid_epoch + and self.epoch_loop.batch_progress.is_last_batch + ): self.epoch_progress.increment_processed() self.epoch_progress.increment_completed() - if (self.epoch_loop.restarting_on_train_batch_end + if ( + self.epoch_loop.restarted_on_train_batch_end and self.epoch_loop.batch_progress.is_last_batch - and not restarting_mid_epoch - and not self.epoch_loop.val_loop.batch_progress.is_last_batch): + and not self.restarted_mid_epoch + and not self.epoch_loop.val_loop.batch_progress.is_last_batch + ): self.epoch_progress.increment_completed() def on_run_start(self) -> None: @@ -396,8 +430,8 @@ def on_advance_start(self) -> None: for i, dl in enumerate(self._combined_loader.flattened): _set_sampler_epoch(dl, self.epoch_progress.current.processed) - 
if not self.restarting_mid_epoch and not self.restarting_on_epoch_end: - if not self.restarting_on_epoch_start: + if not self.restarted_mid_epoch and not self.restarted_on_epoch_end: + if not self.restarted_on_epoch_start: self.epoch_progress.increment_ready() call._call_callback_hooks(trainer, "on_train_epoch_start") diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py index 56d520800c447..111377a222b3f 100644 --- a/src/lightning/pytorch/loops/loop.py +++ b/src/lightning/pytorch/loops/loop.py @@ -22,6 +22,7 @@ class _Loop: def __init__(self, trainer: "pl.Trainer") -> None: self._restarting = False + self._loaded_from_state_dict = False self.trainer = trainer @property @@ -37,6 +38,9 @@ def restarting(self, restarting: bool) -> None: if isinstance(loop, _Loop): loop.restarting = restarting + def reset_restart_stage(self) -> None: + pass + def on_save_checkpoint(self) -> Dict: """Called when saving a model checkpoint, use to persist loop state. @@ -82,6 +86,7 @@ def load_state_dict( if isinstance(v, _Loop): v.load_state_dict(state_dict.copy(), prefix + k + ".") self.restarting = True + self._loaded_from_state_dict = True def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None: for k, v in self.__dict__.items(): @@ -93,3 +98,8 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None: v.load_state_dict(state_dict[key]) if prefix + "state_dict" in state_dict: # compatibility with old checkpoints self.on_load_checkpoint(state_dict[prefix + "state_dict"]) + + def on_iteration_done(self) -> None: + self._restarting = False + self._loaded_from_state_dict = False + self.reset_restart_stage() diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 9e14b12e04139..6d8f165be812f 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -13,6 +13,7 @@ # limitations under the License. import math from collections import OrderedDict +from dataclasses import dataclass from typing import Any, Dict, Optional, Union from typing_extensions import override @@ -37,6 +38,13 @@ _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] +@dataclass +class RestartStage: + NONE = "none" + RESTARTED_ON_TRAIN_BATCH_END = "restarted_on_train_batch_end" + RESTARTED_ON_LAST = "restarted_on_last" + + class _TrainingEpochLoop(loops._Loop): """Iterates over all batches in the dataloader (one epoch) that the user returns in their :meth:`~lightning.pytorch.core.LightningModule.train_dataloader` method. 
@@ -81,8 +89,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
         self._results = _ResultCollection(training=True)
         self._warning_cache = WarningCache()
         self._batches_that_stepped: int = 0
-        self._restarting_on_train_batch_end: bool = None
-        self._restarting_on_last: bool = None
+        self._restart_stage = RestartStage.NONE

     @property
     def total_batch_idx(self) -> int:
@@ -141,51 +148,56 @@ def run(self, data_fetcher: _DataFetcher) -> None:
             try:
                 self.advance(data_fetcher)
                 self.on_advance_end(data_fetcher)
-                self._restarting = False
             except StopIteration:
                 break
-        self._restarting = False
+            finally:
+                self.on_iteration_done()

     @property
-    def restarting_on_train_batch_end(self) -> bool:
-        if self._restarting_on_train_batch_end is None:
-            self._restarting_on_train_batch_end = (
-                self.restarting
-                and self.batch_progress.total.started == self.batch_progress.total.ready
-                and self.batch_progress.total.processed == self.batch_progress.total.started
-                and self.batch_progress.total.completed == self.batch_progress.total.processed - 1
-            )
-        return self._restarting_on_train_batch_end
+    def restarted_on_train_batch_end(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_ON_TRAIN_BATCH_END

     @property
-    def restarting_on_last(self) -> bool:
-        if self._restarting_on_last is None:
-            self._restarting_on_last = (
-                self.restarting
-                and self.batch_progress.total.started == self.batch_progress.total.ready
-                and self.batch_progress.total.processed == self.batch_progress.total.started
-                and self.batch_progress.total.completed == self.batch_progress.total.processed
-            )
-        return self._restarting_on_last
+    def restarted_on_last(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_ON_LAST
+
+    def update_restart_stage(self) -> None:
+        if (
+            self.restarting
+            and self.batch_progress.total.started == self.batch_progress.total.ready
+            and self.batch_progress.total.processed == self.batch_progress.total.started
+            and self.batch_progress.total.completed == self.batch_progress.total.processed - 1
+        ):
+            self._restart_stage = RestartStage.RESTARTED_ON_TRAIN_BATCH_END
+        elif (
+            self.restarting
+            and self.batch_progress.total.started == self.batch_progress.total.ready
+            and self.batch_progress.total.processed == self.batch_progress.total.started
+            and self.batch_progress.total.completed == self.batch_progress.total.processed
+        ):
+            self._restart_stage = RestartStage.RESTARTED_ON_LAST
+        else:
+            self._restart_stage = RestartStage.NONE
+
+        self.val_loop.update_restart_stage()

-    def reset_restarting_states(self) -> None:
-        self._restarting_on_train_batch_end = None
-        self._restarting_on_last = None
-        self.restarting_on_train_batch_end
-        self.restarting_on_last
+    def reset_restart_stage(self):
+        self._restart_stage = RestartStage.NONE

     def reset(self) -> None:
-        self.reset_restarting_states()
         """Resets the internal state of the loop for a new run."""
-        if self.restarting and not self._should_accumulate():
+        if (
+            self.restarting
+            and not self._should_accumulate()
+            and self.restarted_on_train_batch_end
+            or not self.restarted_on_last
+        ):
             # batches_that_stepped is never set prior to saving a checkpoint, even when saving
             # happens on_validation_end
             # we could set it in the checkpoint but we prefer to keep checkpoints backward compatible
-            if self.restarting_on_train_batch_end or not self.restarting_on_last:
-                # if not self.restarting_on_train_batch_end and not self.restarting_on_last:
-                self._batches_that_stepped += 1
+            self._batches_that_stepped += 1

-        if self.restarting_on_train_batch_end:
+        if self.restarted_on_train_batch_end:
             self.batch_progress.increment_completed()
             # handle situation in which save happened on_train_batch_end and epoch is at end
             if self.batch_progress.current.completed >= self.trainer.num_training_batches:
@@ -245,7 +257,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         """
         if self.restarting and self._should_check_val_fx(data_fetcher):
-            if self.val_loop.restarting_mid_evaluation or self.restarting_on_last:
+            if self.val_loop.restarted_mid_evaluation or self.restarted_on_last:
                 return
             # fast forward progress counters to end of validation
             self.val_loop.increment_progress_to_evaluation_end()

From c3469beb66fd1045b16f153400dba03d0fc10700 Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Fri, 8 Nov 2024 10:38:54 +0000
Subject: [PATCH 10/16] Formatting

---
 src/lightning/pytorch/loops/evaluation_loop.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index 277deddba2a50..3a15be7973fb1 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -222,7 +222,7 @@ def update_restart_stage(self) -> None:

     def reset_restart_stage(self) -> None:
         self._restart_stage = RestartStage.NONE
-    
+
     def reset(self) -> None:
         """Resets the internal state of the loop."""
         trainer = self.trainer

From 9c58810e6ba9c04b266da73200cadde9b0680db9 Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Mon, 11 Nov 2024 20:00:29 +0100
Subject: [PATCH 11/16] Avoid running validation when restarting from last

---
 .../pytorch/loops/training_epoch_loop.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index 6d8f165be812f..bbde24b761399 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -90,6 +90,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
         self._warning_cache = WarningCache()
         self._batches_that_stepped: int = 0
         self._restart_stage = RestartStage.NONE
+        self._skip_next_val = False

     @property
     def total_batch_idx(self) -> int:
@@ -257,8 +258,15 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         """
         if self.restarting and self._should_check_val_fx(data_fetcher):
-            if self.val_loop.restarted_mid_evaluation or self.restarted_on_last:
+            if self.val_loop.restarted_mid_evaluation:
+                # Go back and finish running validation
                 return
+
+            if self.restarted_on_last:
+                # Avoid running validation again if we saved on last
+                self._skip_next_val = True
+                return
+
             # fast forward progress counters to end of validation
             self.val_loop.increment_progress_to_evaluation_end()
@@ -345,6 +353,11 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
         # VALIDATE IF NEEDED
         # -----------------------------------------
         should_check_val = self._should_check_val_fx(data_fetcher)
+
+        if self._skip_next_val:
+            should_check_val = False
+            self._skip_next_val = False
+
         if should_check_val:
             # this needs to be set so the correct `trainer._active_loop` is picked
             self.trainer.validating = True

From fae3aa505c4fe617e3207cf9a57aa02ca24da5f6 Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Mon, 11 Nov 2024 20:22:38 +0100
Subject: [PATCH 12/16] Fix type annotations

---
 src/lightning/pytorch/loops/progress.py            | 12 +++++-------
 src/lightning/pytorch/loops/training_epoch_loop.py |  2 +-
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py
index 807d150a26fb2..736fbd11803cf 100644
--- a/src/lightning/pytorch/loops/progress.py
+++ b/src/lightning/pytorch/loops/progress.py
@@ -59,7 +59,6 @@ def reset(self) -> None:
         self.ready = 0
         self.completed = 0

-    @override
     def reset_on_restart(self) -> None:
         """Reset the progress on restart.

@@ -69,8 +68,7 @@ def reset_on_restart(self) -> None:
         """
         self.ready = self.completed

-    @override
-    def increment_by(self, n) -> None:
+    def increment_by(self, n: int) -> None:
         self.ready += n
         self.completed += n

@@ -101,7 +99,7 @@ def reset_on_restart(self) -> None:
         self.started = self.completed

     @override
-    def increment_by(self, n) -> None:
+    def increment_by(self, n: int) -> None:
         super().increment_by(n)
         self.started += n

@@ -133,7 +131,7 @@ def reset_on_restart(self) -> None:
         self.processed = self.completed

     @override
-    def increment_by(self, n) -> None:
+    def increment_by(self, n: int) -> None:
         super().increment_by(n)
         self.processed += n

@@ -192,7 +190,7 @@ def reset_on_restart(self) -> None:
         self.current.reset_on_restart()

     @override
-    def increment_by(self, n) -> None:
+    def increment_by(self, n: int) -> None:
         self.total.increment_by(n)
         self.current.increment_by(n)

@@ -227,7 +225,7 @@ def reset_on_run(self) -> None:
         super().reset_on_run()
         self.is_last_batch = False

-    def increment_by(self, n, is_last_batch=False) -> None:
+    def increment_by(self, n: int, is_last_batch: bool=False) -> None:
         super().increment_by(n)
         self.is_last_batch = is_last_batch

diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index bbde24b761399..08685c227c598 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -182,7 +182,7 @@ def update_restart_stage(self) -> None:

         self.val_loop.update_restart_stage()

-    def reset_restart_stage(self):
+    def reset_restart_stage(self) -> None:
         self._restart_stage = RestartStage.NONE

     def reset(self) -> None:

From c111f25509df0e3158a38b9ad7c93111a9cd4cd6 Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Mon, 11 Nov 2024 20:23:02 +0100
Subject: [PATCH 13/16] Fix formatting

---
 src/lightning/pytorch/loops/progress.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py
index 736fbd11803cf..5d8747d272324 100644
--- a/src/lightning/pytorch/loops/progress.py
+++ b/src/lightning/pytorch/loops/progress.py
@@ -225,7 +225,7 @@ def reset_on_run(self) -> None:
         super().reset_on_run()
         self.is_last_batch = False

-    def increment_by(self, n: int, is_last_batch: bool=False) -> None:
+    def increment_by(self, n: int, is_last_batch: bool = False) -> None:
         super().increment_by(n)
         self.is_last_batch = is_last_batch

From bf6b17d407c7dfb76947170e06904d446bdf3ce0 Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Mon, 11 Nov 2024 20:33:08 +0100
Subject: [PATCH 14/16] Ensure int max_batch

---
 src/lightning/pytorch/loops/evaluation_loop.py | 4 +++-
 src/lightning/pytorch/loops/progress.py        | 1 -
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index 3a15be7973fb1..78573c45aab76 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -267,7 +267,9 @@ def increment_progress_to_evaluation_end(self) -> None:
         if self.skip:
             return
         self.reset()
-        max_batch = max(self.max_batches)
+        max_batch = int(max(self.max_batches))
+        if max_batch == -1:
+            return
         self.batch_progress.increment_by(max_batch, True)

     def on_run_start(self) -> None:
diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py
index 5d8747d272324..6880b24f70c65 100644
--- a/src/lightning/pytorch/loops/progress.py
+++ b/src/lightning/pytorch/loops/progress.py
@@ -189,7 +189,6 @@ def reset_on_run(self) -> None:

     def reset_on_restart(self) -> None:
         self.current.reset_on_restart()

-    @override
     def increment_by(self, n: int) -> None:
         self.total.increment_by(n)
         self.current.increment_by(n)

From c59ab400a20b8ea14ebc5356b505b728d4572170 Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Tue, 12 Nov 2024 13:26:05 +0100
Subject: [PATCH 15/16] Fix condition on batches that stepped

---
 src/lightning/pytorch/loops/training_epoch_loop.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index 08685c227c598..1c749de3a1b6d 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -190,8 +190,7 @@ def reset(self) -> None:
         if (
             self.restarting
             and not self._should_accumulate()
-            and self.restarted_on_train_batch_end
-            or not self.restarted_on_last
+            and (self.restarted_on_train_batch_end or not self.restarted_on_last)
         ):
             # batches_that_stepped is never set prior to saving a checkpoint, even when saving
             # happens on_validation_end

From 9ad4200251af82b4b41548bca6a122481c991867 Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Tue, 12 Nov 2024 14:56:35 +0100
Subject: [PATCH 16/16] Remove expected on_train_epoch_start when restarting mid epoch

---
 tests/tests_pytorch/models/test_hooks.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index 5a175e181dd9e..685bd6c0bdaef 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -660,8 +660,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
         {"name": "train_dataloader"},
         {"name": "Callback.on_train_start", "args": (trainer, model)},
         {"name": "on_train_start"},
-        {"name": "Callback.on_train_epoch_start", "args": (trainer, model)},
-        {"name": "on_train_epoch_start"},
         *model._train_batch(trainer, model, steps_after_reload, trainer.strategy.root_device, current_batch=1),
         {"name": "Callback.on_train_epoch_end", "args": (trainer, model)},
         {"name": "on_train_epoch_end"},  # before ModelCheckpoint because it's a "monitoring callback"