From 789afa8d80397c5c10278126074a01016729ea94 Mon Sep 17 00:00:00 2001
From: sudipto baral
Date: Mon, 28 Apr 2025 12:14:57 -0400
Subject: [PATCH 1/5] Fix double iteration bug when resumed from a checkpoint.

---
 src/lightning/pytorch/loops/fit_loop.py       |  5 ++
 .../pytorch/loops/training_epoch_loop.py      |  6 +-
 .../test_double_iter_in_iterable_dataset.py   | 81 +++++++++++++++++++
 3 files changed, 91 insertions(+), 1 deletion(-)
 create mode 100644 tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py

diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 31d6724a043a3..065237c5dffcd 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -204,6 +204,11 @@ def skip(self) -> bool:
         # so we cannot use it solely
         return self.done or self.trainer.limit_train_batches == 0
 
+    @property
+    def _is_resuming(self) -> bool:
+        """Whether we're resuming training from a checkpoint."""
+        return self._loaded_from_state_dict
+
     def run(self) -> None:
         self.setup_data()
         if self.skip:
diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index 599eccdc8ca91..75b5cdf02faba 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -237,7 +237,11 @@ def reset(self) -> None:
 
     def on_run_start(self, data_fetcher: _DataFetcher) -> None:
         # `iter()` was called once in `FitLoop.setup_data()` already
-        if self.trainer.current_epoch > 0 and not self.restarting:
+        # Only call iter() if:
+        # 1. Not restarting AND
+        # 2. Not resuming from checkpoint (not _is_resuming) AND
+        # 3. Past first epoch (current_epoch > 0)
+        if (self.trainer.current_epoch > 0 and not self.trainer.fit_loop._is_resuming) and not self.restarting:
             iter(data_fetcher)  # creates the iterator inside the fetcher
 
         # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
diff --git a/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py b/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py
new file mode 100644
index 0000000000000..338947b878668
--- /dev/null
+++ b/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py
@@ -0,0 +1,81 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This test covers resuming training from a checkpoint file using an IterableDataset.
+# It contains code mentioned in issue #19427.
+# Ref: https://github.com/Lightning-AI/pytorch-lightning/issues/19427
+import multiprocessing as mp
+import os
+from collections.abc import Iterator
+from pathlib import Path
+from queue import Queue
+
+import numpy as np
+from torch.utils.data import DataLoader, IterableDataset
+
+from lightning import Trainer
+from lightning.pytorch.demos.boring_classes import BoringModel
+
+
+class QueueDataset(IterableDataset):
+    def __init__(self, queue: Queue) -> None:
+        super().__init__()
+        self.queue = queue
+
+    def __iter__(self) -> Iterator:
+        for _ in range(5):
+            tensor, _ = self.queue.get(timeout=5)
+            yield tensor
+
+
+def create_queue() -> Queue:
+    q = mp.Queue()
+    arr = np.random.random([1, 32]).astype(np.float32)
+    for ind in range(20):
+        q.put((arr, ind))
+    return q
+
+
+def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> Trainer:
+    dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None, persistent_workers=True)
+    trainer = Trainer(
+        max_epochs=max_epochs,
+        enable_progress_bar=False,
+        enable_checkpointing=False,
+        devices=1,
+        logger=False,
+    )
+    if ckpt_path.exists():
+        trainer.fit(BoringModel(), dataloader, ckpt_path=str(ckpt_path))
+    else:
+        trainer.fit(BoringModel(), dataloader)
+        trainer.save_checkpoint(str(ckpt_path))
+    return trainer
+
+
+def test_resume_training_with(tmp_path):
+    """Test resuming training from a checkpoint file using an IterableDataset."""
+    queue = create_queue()
+    max_epoch = 2
+    ckpt_path = tmp_path / "model.ckpt"
+    trainer = train_model(queue, max_epoch, ckpt_path)
+    assert trainer is not None
+
+    assert os.path.exists(ckpt_path), f"Checkpoint file '{ckpt_path}' wasn't created"
+
+    ckpt_size = os.path.getsize(ckpt_path)
+    assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"
+
+    trainer = train_model(queue, max_epoch + 2, ckpt_path)
+    assert trainer is not None

From 1cb6a6b171bddcedfaf2f9b3856ac2c1d9419611 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Date: Wed, 30 Apr 2025 07:34:20 +0200
Subject: [PATCH 2/5] Apply suggestions from code review

---
 .../tests_pytorch/loops/test_double_iter_in_iterable_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py b/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py
index 338947b878668..1deb169da7206 100644
--- a/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py
+++ b/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py
@@ -24,7 +24,7 @@
 import numpy as np
 from torch.utils.data import DataLoader, IterableDataset
 
-from lightning import Trainer
+from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 
 

From 5a0c70c7f9082e93e2ebbb67747f411fc451bae4 Mon Sep 17 00:00:00 2001
From: sudipto baral
Date: Wed, 14 May 2025 00:37:52 -0400
Subject: [PATCH 3/5] update wording in the comments.

Signed-off-by: sudipto baral
---
 src/lightning/pytorch/loops/training_epoch_loop.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index 75b5cdf02faba..eba38d0728091 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -237,9 +237,9 @@ def reset(self) -> None:
 
     def on_run_start(self, data_fetcher: _DataFetcher) -> None:
         # `iter()` was called once in `FitLoop.setup_data()` already
-        # Only call iter() if:
-        # 1. Not restarting AND
-        # 2. Not resuming from checkpoint (not _is_resuming) AND
+        # Only call `iter()` if all of the following are true:
+        # 1. Not restarting
+        # 2. Not resuming from checkpoint (not _is_resuming)
         # 3. Past first epoch (current_epoch > 0)
         if (self.trainer.current_epoch > 0 and not self.trainer.fit_loop._is_resuming) and not self.restarting:
             iter(data_fetcher)  # creates the iterator inside the fetcher

From df95a0c82430b57af33875bb17ad322adc4cf6f5 Mon Sep 17 00:00:00 2001
From: sudipto baral
Date: Wed, 14 May 2025 00:51:49 -0400
Subject: [PATCH 4/5] update test

Signed-off-by: sudipto baral
---
 .../loops/test_double_iter_in_iterable_dataset.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py b/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py
index 1deb169da7206..9ae051bdc6da2 100644
--- a/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py
+++ b/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py
@@ -47,7 +47,7 @@ def create_queue() -> Queue:
     return q
 
 
-def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> Trainer:
+def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> None:
     dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None, persistent_workers=True)
     trainer = Trainer(
         max_epochs=max_epochs,
@@ -61,7 +61,6 @@ def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> None:
     else:
         trainer.fit(BoringModel(), dataloader)
         trainer.save_checkpoint(str(ckpt_path))
-    return trainer
 
 
 def test_resume_training_with(tmp_path):
@@ -69,13 +68,10 @@ def test_resume_training_with(tmp_path):
     queue = create_queue()
     max_epoch = 2
     ckpt_path = tmp_path / "model.ckpt"
-    trainer = train_model(queue, max_epoch, ckpt_path)
-    assert trainer is not None
+    train_model(queue, max_epoch, ckpt_path)
 
     assert os.path.exists(ckpt_path), f"Checkpoint file '{ckpt_path}' wasn't created"
-
     ckpt_size = os.path.getsize(ckpt_path)
     assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"
 
-    trainer = train_model(queue, max_epoch + 2, ckpt_path)
-    assert trainer is not None
+    train_model(queue, max_epoch + 2, ckpt_path)

From cef9b31b595aaba66ceed94d73cacc47b7389b0f Mon Sep 17 00:00:00 2001
From: sudipto baral
Date: Fri, 27 Jun 2025 03:49:13 -0400
Subject: [PATCH 5/5] Add independent flag to track checkpoint resumption.

Signed-off-by: sudipto baral
---
 src/lightning/pytorch/loops/fit_loop.py            | 5 -----
 src/lightning/pytorch/loops/loop.py                | 8 ++++++++
 src/lightning/pytorch/loops/training_epoch_loop.py | 4 ++--
 3 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 065237c5dffcd..31d6724a043a3 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -204,11 +204,6 @@ def skip(self) -> bool:
         # so we cannot use it solely
         return self.done or self.trainer.limit_train_batches == 0
 
-    @property
-    def _is_resuming(self) -> bool:
-        """Whether we're resuming training from a checkpoint."""
-        return self._loaded_from_state_dict
-
     def run(self) -> None:
         self.setup_data()
         if self.skip:
diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py
index daad309cd75d4..22207caa387dc 100644
--- a/src/lightning/pytorch/loops/loop.py
+++ b/src/lightning/pytorch/loops/loop.py
@@ -23,6 +23,7 @@ class _Loop:
     def __init__(self, trainer: "pl.Trainer") -> None:
         self._restarting = False
         self._loaded_from_state_dict = False
+        self._resuming_from_checkpoint = False
         self.trainer = trainer
 
     @property
@@ -30,6 +31,11 @@ def restarting(self) -> bool:
         """Whether the state of this loop was reloaded and it needs to restart."""
         return self._restarting
 
+    @property
+    def is_resuming(self) -> bool:
+        """Whether we're resuming training from a checkpoint."""
+        return self._resuming_from_checkpoint
+
     @restarting.setter
     def restarting(self, restarting: bool) -> None:
         """Connects this loop's restarting value and its children."""
@@ -87,6 +93,7 @@ def load_state_dict(
                 v.load_state_dict(state_dict.copy(), prefix + k + ".")
         self.restarting = True
         self._loaded_from_state_dict = True
+        self._resuming_from_checkpoint = True
 
     def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
         for k, v in self.__dict__.items():
@@ -102,4 +109,5 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
     def on_iteration_done(self) -> None:
         self._restarting = False
         self._loaded_from_state_dict = False
+        self._resuming_from_checkpoint = False
         self.reset_restart_stage()
diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index eba38d0728091..6f748779a9b3f 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -239,9 +239,9 @@ def on_run_start(self, data_fetcher: _DataFetcher) -> None:
         # `iter()` was called once in `FitLoop.setup_data()` already
         # Only call `iter()` if all of the following are true:
         # 1. Not restarting
-        # 2. Not resuming from checkpoint (not _is_resuming)
+        # 2. Not resuming from checkpoint (not is_resuming)
         # 3. Past first epoch (current_epoch > 0)
-        if (self.trainer.current_epoch > 0 and not self.trainer.fit_loop._is_resuming) and not self.restarting:
+        if (self.trainer.current_epoch > 0 and not self.trainer.fit_loop.is_resuming) and not self.restarting:
             iter(data_fetcher)  # creates the iterator inside the fetcher
 
         # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
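
Editor's note: the sketch below is a minimal, Lightning-free illustration of the failure mode the guarded `iter()` call addresses; it is not part of the patch. `QueueIterable` is a hypothetical stand-in for the queue-backed `QueueDataset` in the test above. The point is only that two iterators over a one-shot source do not see the same items, so a redundant `iter()` after restoring from a checkpoint can silently lose whatever the first, discarded iterator already pulled.

    from queue import Queue


    class QueueIterable:
        """Mimics an IterableDataset that drains a shared queue each time it is iterated."""

        def __init__(self, queue: Queue, batches_per_epoch: int = 5) -> None:
            self.queue = queue
            self.batches_per_epoch = batches_per_epoch

        def __iter__(self):
            for _ in range(self.batches_per_epoch):
                yield self.queue.get(timeout=5)


    q = Queue()
    for i in range(20):
        q.put(i)

    data = QueueIterable(q)

    # First `iter()` happens once during setup; a prefetching fetcher advances it.
    first = iter(data)
    prefetched = next(first)  # consumes item 0 from the shared queue

    # A second, redundant `iter()` builds a fresh iterator; the item already
    # prefetched by `first` is never replayed for the training loop.
    second = iter(data)
    seen = [next(second) for _ in range(5)]

    print(prefetched)  # 0
    print(seen)        # [1, 2, 3, 4, 5] -- item 0 never reaches the epoch loop

With a real DataLoader and worker prefetching, the same effect shows up as skipped batches or a stalled `queue.get(timeout=5)` after resuming, which is what the test in PATCH 1/5 exercises.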
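Editor's note: the following is a condensed restatement of the flag lifecycle introduced in PATCH 5/5, written as a simplified toy class rather than Lightning source. It shows the intended behaviour: the resume flag is raised in `load_state_dict` and lowered in `on_iteration_done`, so `is_resuming` is only true between restoring loop state and finishing the first iteration afterwards.

    class MiniLoop:
        """Toy model of the checkpoint-resumption flag added to `_Loop`."""

        def __init__(self) -> None:
            self._restarting = False
            self._loaded_from_state_dict = False
            self._resuming_from_checkpoint = False

        @property
        def is_resuming(self) -> bool:
            """Whether training is resuming from a checkpoint."""
            return self._resuming_from_checkpoint

        def load_state_dict(self, state_dict: dict) -> None:
            # Restoring state marks the loop as both restarting and resuming.
            self._restarting = True
            self._loaded_from_state_dict = True
            self._resuming_from_checkpoint = True

        def on_iteration_done(self) -> None:
            # Transient flags are cleared once the first post-restore iteration ends.
            self._restarting = False
            self._loaded_from_state_dict = False
            self._resuming_from_checkpoint = False


    loop = MiniLoop()
    loop.load_state_dict({})
    assert loop.is_resuming      # epoch loop skips the extra `iter()` while this holds
    loop.on_iteration_done()
    assert not loop.is_resuming  # later epochs call `iter()` normally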