Ensure restarting from checkpoints leads to consistent internal counters #20379

Merged · 17 commits · Nov 13, 2024
38 changes: 37 additions & 1 deletion src/lightning/pytorch/loops/evaluation_loop.py
@@ -15,6 +15,7 @@
import shutil
import sys
from collections import ChainMap, OrderedDict, defaultdict
from dataclasses import dataclass
from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union

from lightning_utilities.core.apply_func import apply_to_collection
@@ -45,6 +46,12 @@
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature


@dataclass
class RestartStage:
NONE = "none"
RESTARTED_MID_EVALUATION = "restarted_mid_evaluation"


class _EvaluationLoop(_Loop):
"""Top-level loop where validation/testing starts."""

@@ -73,6 +80,7 @@ def __init__(
self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
self._last_val_dl_reload_epoch = float("-inf")
self._module_mode = _ModuleMode()
self._restart_stage = RestartStage.NONE

@property
def num_dataloaders(self) -> int:
@@ -137,7 +145,7 @@ def run(self) -> List[_OUT_DICT]:
# this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
break
finally:
self._restarting = False
self.on_iteration_done()
self._store_dataloader_outputs()
return self.on_run_end()

@@ -197,6 +205,24 @@ def setup_data(self) -> None:
# this depends on the data used, so reset it too
self._seen_batches_per_dataloader = defaultdict(int)

@property
def restarted_mid_evaluation(self) -> bool:
return self._restart_stage == RestartStage.RESTARTED_MID_EVALUATION

def update_restart_stage(self) -> None:
if (
self.restarting
and self.batch_progress.total.started == self.batch_progress.total.ready
and self.batch_progress.total.processed == self.batch_progress.total.started - 1
and self.batch_progress.total.completed == self.batch_progress.total.processed
):
self._restart_stage = RestartStage.RESTARTED_MID_EVALUATION
else:
self._restart_stage = RestartStage.NONE

def reset_restart_stage(self) -> None:
self._restart_stage = RestartStage.NONE

def reset(self) -> None:
"""Resets the internal state of the loop."""
trainer = self.trainer
@@ -236,6 +262,16 @@ def reset(self) -> None:
data_fetcher._stop_profiler = self._on_after_fetch
self._data_fetcher = data_fetcher

def increment_progress_to_evaluation_end(self) -> None:
self.setup_data()
if self.skip:
return
self.reset()
max_batch = int(max(self.max_batches))
if max_batch == -1:
return
self.batch_progress.increment_by(max_batch, True)

def on_run_start(self) -> None:
"""Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
hooks."""
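As a reading aid (not part of the diff), here is a standalone sketch of the counter pattern the new `update_restart_stage` looks for: a checkpoint taken mid-evaluation leaves the batch progress with `started == ready` while `processed == completed == started - 1`. The tracker class and the values below are illustrative only.

```python
from dataclasses import dataclass


@dataclass
class Tracker:
    """Illustrative stand-in for the ready/started/processed/completed counters."""

    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0


def restarted_mid_evaluation(t: Tracker) -> bool:
    # A batch was fetched and started, but never processed or completed.
    return (
        t.started == t.ready
        and t.processed == t.started - 1
        and t.completed == t.processed
    )


# Interrupted right after starting the third batch:
assert restarted_mid_evaluation(Tracker(ready=3, started=3, processed=2, completed=2))
# A clean stop at a batch boundary does not match:
assert not restarted_mid_evaluation(Tracker(ready=3, started=3, processed=3, completed=3))
```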
107 changes: 99 additions & 8 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
@@ -45,6 +46,15 @@
log = logging.getLogger(__name__)


@dataclass
class RestartStage:
NONE = "none"
RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start"
RESTARTED_MID_EPOCH = "restarted_mid_epoch"
RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end"
RESUMED_ON_EPOCH_END = "resumed_on_epoch_end"


class _FitLoop(_Loop):
"""This loop is the top-level loop where training starts.

@@ -97,6 +107,7 @@ def __init__(
self._combined_loader_states_to_load: List[Dict[str, Any]] = []
self._data_fetcher: Optional[_DataFetcher] = None
self._last_train_dl_reload_epoch = float("-inf")
self._restart_stage = RestartStage.NONE

@property
def total_batch_idx(self) -> int:
@@ -204,9 +215,10 @@ def run(self) -> None:
self.on_advance_start()
self.advance()
self.on_advance_end()
self._restarting = False
except StopIteration:
break
finally:
self.on_iteration_done()
self._restarting = False
self.on_run_end()

@@ -302,14 +314,92 @@ def setup_data(self) -> None:
category=PossibleUserWarning,
)

@property
def restarted_on_epoch_start(self) -> bool:
return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_START

@property
def restarted_mid_epoch(self) -> bool:
return self._restart_stage == RestartStage.RESTARTED_MID_EPOCH

@property
def restarted_on_epoch_end(self) -> bool:
return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_END

@property
def resumed_on_epoch_end(self) -> bool:
# This case happens when restarting from the `last` checkpoint without
# validation at the end of the epoch. In this case self.restarting is False.
return self._restart_stage == RestartStage.RESUMED_ON_EPOCH_END

def update_restart_stage(self) -> None:
if (
self.restarting
and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1
and self.epoch_progress.total.processed == self.epoch_progress.total.started
and self.epoch_progress.total.completed == self.epoch_progress.total.processed
):
self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_START
elif (
self.restarting
and self.epoch_progress.total.started == self.epoch_progress.total.ready
and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1
and self.epoch_progress.total.completed == self.epoch_progress.total.processed
):
self._restart_stage = RestartStage.RESTARTED_MID_EPOCH
elif (
self.restarting
and self.epoch_progress.total.started == self.epoch_progress.total.ready
and self.epoch_progress.total.processed == self.epoch_progress.total.started
and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
):
self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_END
elif (
self._loaded_from_state_dict
and self.epoch_progress.total.started == self.epoch_progress.total.ready
and self.epoch_progress.total.processed == self.epoch_progress.total.started
and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
):
self._restart_stage = RestartStage.RESUMED_ON_EPOCH_END
else:
self._restart_stage = RestartStage.NONE

self.epoch_loop.update_restart_stage()

def reset_restart_stage(self) -> None:
self._restart_stage = RestartStage.NONE

def reset(self) -> None:
"""Resets the internal state of this loop."""
assert self.trainer.model is not None
torch.set_grad_enabled(True)

if self.restarting:
self.update_restart_stage()

if self.restarted_on_epoch_start:
self.epoch_progress.reset_on_restart()

if self.resumed_on_epoch_end:
# when restarting from the `last` checkpoint without validation at the end of
# the epoch, self.restarting is False but we are still resuming
self.epoch_progress.increment_completed()

if (
self.epoch_loop.restarted_on_train_batch_end
and self.restarted_mid_epoch
and self.epoch_loop.batch_progress.is_last_batch
):
self.epoch_progress.increment_processed()
self.epoch_progress.increment_completed()

if (
self.epoch_loop.restarted_on_train_batch_end
and self.epoch_loop.batch_progress.is_last_batch
and not self.restarted_mid_epoch
and not self.epoch_loop.val_loop.batch_progress.is_last_batch
):
self.epoch_progress.increment_completed()

def on_run_start(self) -> None:
"""Calls the ``on_train_start`` hook."""
# update the current_epoch in-case of checkpoint reload
@@ -340,12 +430,14 @@ def on_advance_start(self) -> None:
for i, dl in enumerate(self._combined_loader.flattened):
_set_sampler_epoch(dl, self.epoch_progress.current.processed)

self.epoch_progress.increment_ready()
if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
if not self.restarted_on_epoch_start:
self.epoch_progress.increment_ready()

call._call_callback_hooks(trainer, "on_train_epoch_start")
call._call_lightning_module_hook(trainer, "on_train_epoch_start")
call._call_callback_hooks(trainer, "on_train_epoch_start")
call._call_lightning_module_hook(trainer, "on_train_epoch_start")

self.epoch_progress.increment_started()
self.epoch_progress.increment_started()

def advance(self) -> None:
"""Runs one whole epoch."""
Expand Down Expand Up @@ -379,8 +471,7 @@ def on_advance_end(self) -> None:

trainer._logger_connector.on_epoch_end()

if self.epoch_loop._num_ready_batches_reached():
# if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint.
if not self.restarting and self.epoch_loop._num_ready_batches_reached():
# since metric-based schedulers require access to metrics and those are not currently saved in the
# checkpoint, the plateau schedulers shouldn't be updated
self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)
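For reference, a standalone sketch of the four epoch-level situations the `elif` chain in `update_restart_stage` distinguishes; the function and argument names are illustrative, with `restarting` and `loaded` standing in for `self.restarting` and `self._loaded_from_state_dict`.

```python
def classify_epoch_restart(
    ready: int, started: int, processed: int, completed: int, restarting: bool, loaded: bool
) -> str:
    """Mirror of the elif chain: map epoch-progress counters to a restart stage."""
    if restarting and started == ready - 1 and processed == started and completed == processed:
        return "restarted_on_epoch_start"  # stopped before on_train_epoch_start ran
    if restarting and started == ready and processed == started - 1 and completed == processed:
        return "restarted_mid_epoch"  # stopped somewhere inside the epoch
    if restarting and started == ready and processed == started and completed == processed - 1:
        return "restarted_on_epoch_end"  # epoch processed but not yet marked completed
    if loaded and started == ready and processed == started and completed == processed - 1:
        return "resumed_on_epoch_end"  # resumed from `last` while self.restarting is False
    return "none"


print(classify_epoch_restart(2, 2, 2, 1, restarting=True, loaded=True))  # restarted_on_epoch_end
```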
10 changes: 10 additions & 0 deletions src/lightning/pytorch/loops/loop.py
@@ -22,6 +22,7 @@ class _Loop:

def __init__(self, trainer: "pl.Trainer") -> None:
self._restarting = False
self._loaded_from_state_dict = False
self.trainer = trainer

@property
@@ -37,6 +38,9 @@ def restarting(self, restarting: bool) -> None:
if isinstance(loop, _Loop):
loop.restarting = restarting

def reset_restart_stage(self) -> None:
pass

def on_save_checkpoint(self) -> Dict:
"""Called when saving a model checkpoint, use to persist loop state.

@@ -82,6 +86,7 @@ def load_state_dict(
if isinstance(v, _Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".")
self.restarting = True
self._loaded_from_state_dict = True

def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
for k, v in self.__dict__.items():
@@ -93,3 +98,8 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
v.load_state_dict(state_dict[key])
if prefix + "state_dict" in state_dict: # compatibility with old checkpoints
self.on_load_checkpoint(state_dict[prefix + "state_dict"])

def on_iteration_done(self) -> None:
self._restarting = False
self._loaded_from_state_dict = False
self.reset_restart_stage()
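A minimal, illustrative sketch (the class name is made up, not Lightning's API) of the lifecycle these additions establish: the restart-related flags set by `load_state_dict` survive only until `on_iteration_done` runs at the end of the first iteration after resuming.

```python
class SketchLoop:
    """Toy loop showing when the restart flags are cleared."""

    def __init__(self) -> None:
        self._restarting = False
        self._loaded_from_state_dict = False

    def reset_restart_stage(self) -> None:
        pass  # overridden by loops that track a RestartStage

    def on_iteration_done(self) -> None:
        # Clear restart bookkeeping once the first post-restore iteration finishes.
        self._restarting = False
        self._loaded_from_state_dict = False
        self.reset_restart_stage()


loop = SketchLoop()
loop._restarting = True              # as set when a checkpoint is loaded
loop._loaded_from_state_dict = True
loop.on_iteration_done()             # called from the loops' `finally` blocks
assert not loop._restarting and not loop._loaded_from_state_dict
```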
22 changes: 22 additions & 0 deletions src/lightning/pytorch/loops/progress.py
@@ -68,6 +68,10 @@ def reset_on_restart(self) -> None:
"""
self.ready = self.completed

def increment_by(self, n: int) -> None:
self.ready += n
self.completed += n


@dataclass
class _StartedTracker(_ReadyCompletedTracker):
@@ -94,6 +98,11 @@ def reset_on_restart(self) -> None:
super().reset_on_restart()
self.started = self.completed

@override
def increment_by(self, n: int) -> None:
super().increment_by(n)
self.started += n


@dataclass
class _ProcessedTracker(_StartedTracker):
@@ -121,6 +130,11 @@ def reset_on_restart(self) -> None:
super().reset_on_restart()
self.processed = self.completed

@override
def increment_by(self, n: int) -> None:
super().increment_by(n)
self.processed += n


@dataclass
class _Progress(_BaseProgress):
@@ -175,6 +189,10 @@ def reset_on_run(self) -> None:
def reset_on_restart(self) -> None:
self.current.reset_on_restart()

def increment_by(self, n: int) -> None:
self.total.increment_by(n)
self.current.increment_by(n)

@override
def load_state_dict(self, state_dict: dict) -> None:
self.total.load_state_dict(state_dict["total"])
@@ -206,6 +224,10 @@ def reset_on_run(self) -> None:
super().reset_on_run()
self.is_last_batch = False

def increment_by(self, n: int, is_last_batch: bool = False) -> None:
super().increment_by(n)
self.is_last_batch = is_last_batch

@override
def load_state_dict(self, state_dict: dict) -> None:
super().load_state_dict(state_dict)
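To illustrate the `increment_by` additions, here is a self-contained approximation; it flattens Lightning's `total`/`current` split into a single tracker, so the names and structure are simplified. Every counter is fast-forwarded by `n`, which is what `increment_progress_to_evaluation_end` relies on to make a skipped validation run look complete.

```python
from dataclasses import dataclass


@dataclass
class BatchProgressSketch:
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0
    is_last_batch: bool = False

    def increment_by(self, n: int, is_last_batch: bool = False) -> None:
        # Fast-forward all counters at once, e.g. past an evaluation we will not re-run.
        self.ready += n
        self.started += n
        self.processed += n
        self.completed += n
        self.is_last_batch = is_last_batch


progress = BatchProgressSketch()
progress.increment_by(10, is_last_batch=True)
assert (progress.ready, progress.completed, progress.is_last_batch) == (10, 10, True)
```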