When checkpointing with a step interval on a validation metric, the checkpointing is done before the validation computation step #20919

@Yann-CV

Description

Bug description

Using ModelCheckpoint with every_n_train_steps defined and a metric logged at the end of validation makes the checkpoint callback save the wrong model.

While the intent is to save the model with the best validation metric, the checkpointing looks up the metric before the metric is actually logged.

For the first epoch, a warning about the metric not being found is shown: `ModelCheckpoint(monitor='auroc') could not find the monitored key in the returned metrics: ['epoch', 'step']. HINT: Did you call log('auroc', value) in the LightningModule?`

Then, at the next checkpointing step, the saved model is not the one corresponding to the best metric, but the model state at the moment of checkpointing (i.e., the model before the next validation run).
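The ordering can be made visible with a small callback. The sketch below (a hypothetical HookOrderLogger, not part of the report) only prints when the step-interval save point and the metric logging point occur, assuming the standard Lightning Callback hooks:

from lightning.pytorch.callbacks import Callback

class HookOrderLogger(Callback):
    """Hypothetical helper: print hook ordering to illustrate the issue."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # ModelCheckpoint configured with every_n_train_steps saves from this hook,
        # using whatever monitored value is available at that point.
        print(f"train batch end: step={trainer.global_step}, metrics={dict(trainer.callback_metrics)}")

    def on_validation_epoch_end(self, trainer, pl_module):
        # The monitored metric ('auroc' in the repro below) is only logged here,
        # i.e. after the step-interval save above has already run for this step.
        print(f"validation epoch end: epoch={trainer.current_epoch}, metrics={dict(trainer.callback_metrics)}")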

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

from tempfile import TemporaryDirectory

import mlflow
import torch
from lightning import Trainer, LightningModule
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger
from torch.utils.data import Dataset, DataLoader


class FakeDataset(Dataset):
    def __init__(self):
        self.data = [torch.randn(3) for _ in range(4)]
        self.labels = [torch.randint(0, 2, (1,)) for _ in range(4)]

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

class SimpleModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
        self.validation_aurocs = [
            1 - torch.tensor(i/10)
            for i in range(1, 3)  # Simulating 2 validation epochs
        ]
        self.validation_aurocs += [torch.tensor(1.0).to(dtype=torch.float32)]  # make sure the last epoch is the best one

    def training_step(self, batch, batch_idx):
        out = self.layer(batch[0])
        return torch.nn.functional.binary_cross_entropy_with_logits(out, batch[1].float())

    def validation_step(self, batch, batch_idx):
        return

    def on_validation_epoch_end(self) -> None:
        self.log(
            "auroc",
            self.validation_aurocs[self.current_epoch],
            logger=True,
        )

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


def create_and_set_experiment(name: str) -> mlflow.entities.Experiment:
    existing_experiment = mlflow.get_experiment_by_name(name)
    if existing_experiment is None:
        mlflow.create_experiment(name)
    return mlflow.set_experiment(experiment_name=name)


mlflow.set_tracking_uri("/storage/ml/mlruns")
mlflow_experiment = create_and_set_experiment(
    name="lightning_checkpoint_issue"
)

with TemporaryDirectory() as tmpdir:
    with mlflow.start_run(experiment_id=mlflow_experiment.experiment_id, run_name="log_issue") as active_run:
        dataset = FakeDataset()
        train_dataloader = DataLoader(dataset, batch_size=2)
        val_dataloader = DataLoader(dataset, batch_size=2)
        model = SimpleModule()
        trainer = Trainer(
            max_epochs=3,
            accelerator="gpu",
            devices=1,
            logger=MLFlowLogger(
                run_id=active_run.info.run_id,
                tracking_uri="/storage/ml/mlruns",
                log_model=True,
            ),
            callbacks=[
                ModelCheckpoint(
                    save_top_k=1,
                    monitor="auroc",
                    dirpath=tmpdir,
                    mode="max",
                    save_last=False,
                    every_n_train_steps=2,
                    train_time_interval=None,
                    every_n_epochs=0,
                    save_on_train_epoch_end=False,
                    save_weights_only=True,
                )
            ],
            check_val_every_n_epoch=1,
            log_every_n_steps=1,
            deterministic="warn",
            num_sanity_val_steps=0,
        )
        trainer.fit(model, train_dataloader, val_dataloader)


In this example, the logged model should be `epoch=2-step=6` but it is `epoch=1-step=4`.
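One way to confirm which checkpoint was actually kept (a sketch, assuming the trainer variable from the script above is still in scope) is to inspect the ModelCheckpoint state after trainer.fit(...) returns:

# After trainer.fit(model, train_dataloader, val_dataloader):
checkpoint_callback = trainer.checkpoint_callback  # the ModelCheckpoint instance passed in callbacks
print(checkpoint_callback.best_model_path)   # expected to end with epoch=2-step=6.ckpt, but ends with epoch=1-step=4.ckpt
print(checkpoint_callback.best_model_score)  # the 'auroc' value recorded for the kept checkpoint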

Error messages and logs

Environment

Current environment
  • CUDA:
    - GPU:
      - NVIDIA GeForce RTX 4060 Ti
      - NVIDIA GeForce GTX 1080 Ti
    - available: True
    - version: 12.6
  • Lightning:
    - lightning: 2.5.0
    - lightning-utilities: 0.14.3
    - pytorch-lightning: 2.5.1.post0
    - torch: 2.7.1
    - torchmetrics: 1.7.3
  • Packages:
    - aiohappyeyeballs: 2.6.1
    - aiohttp: 3.12.13
    - aiosignal: 1.3.2
    - alembic: 1.16.2
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.9.0
    - async-timeout: 5.0.1
    - attrs: 25.3.0
    - autocommand: 2.2.2
    - backports.tarfile: 1.2.0
    - blinker: 1.9.0
    - cachetools: 5.5.2
    - certifi: 2025.6.15
    - charset-normalizer: 3.4.2
    - click: 8.2.1
    - cloudpickle: 3.1.1
    - contourpy: 1.3.2
    - cycler: 0.12.1
    - databricks-sdk: 0.57.0
    - docker: 7.1.0
    - exceptiongroup: 1.3.0
    - fastapi: 0.115.13
    - filelock: 3.18.0
    - flask: 3.1.1
    - fonttools: 4.58.4
    - frozenlist: 1.7.0
    - fsspec: 2025.5.1
    - gitdb: 4.0.12
    - gitpython: 3.1.44
    - google-auth: 2.40.3
    - graphene: 3.4.3
    - graphql-core: 3.2.6
    - graphql-relay: 3.2.0
    - greenlet: 3.2.3
    - gunicorn: 23.0.0
    - h11: 0.16.0
    - idna: 3.10
    - importlib-metadata: 8.7.0
    - inflect: 7.3.1
    - itsdangerous: 2.2.0
    - jaraco.collections: 5.1.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jinja2: 3.1.6
    - joblib: 1.5.1
    - kiwisolver: 1.4.8
    - lightning: 2.5.0
    - lightning-utilities: 0.14.3
    - mako: 1.3.10
    - markupsafe: 3.0.2
    - matplotlib: 3.10.3
    - mlflow: 3.1.0
    - mlflow-skinny: 3.1.0
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.5.0
    - networkx: 3.4.2
    - numpy: 2.2.6
    - nvidia-cublas-cu12: 12.6.4.1
    - nvidia-cuda-cupti-cu12: 12.6.80
    - nvidia-cuda-nvrtc-cu12: 12.6.77
    - nvidia-cuda-runtime-cu12: 12.6.77
    - nvidia-cudnn-cu12: 9.5.1.17
    - nvidia-cufft-cu12: 11.3.0.4
    - nvidia-cufile-cu12: 1.11.1.6
    - nvidia-curand-cu12: 10.3.7.77
    - nvidia-cusolver-cu12: 11.7.1.2
    - nvidia-cusparse-cu12: 12.5.4.2
    - nvidia-cusparselt-cu12: 0.6.3
    - nvidia-nccl-cu12: 2.26.2
    - nvidia-nvjitlink-cu12: 12.6.85
    - nvidia-nvtx-cu12: 12.6.77
    - omegaconf: 2.3.0
    - opentelemetry-api: 1.34.1
    - opentelemetry-sdk: 1.34.1
    - opentelemetry-semantic-conventions: 0.55b1
    - packaging: 24.2
    - pandas: 2.3.0
    - pillow: 11.2.1
    - platformdirs: 4.2.2
    - propcache: 0.3.2
    - protobuf: 6.31.1
    - pyarrow: 20.0.0
    - pyasn1: 0.6.1
    - pyasn1-modules: 0.4.2
    - pydantic: 2.11.7
    - pydantic-core: 2.33.2
    - pyparsing: 3.2.3
    - python-dateutil: 2.9.0.post0
    - pytorch-lightning: 2.5.1.post0
    - pytz: 2025.2
    - pyyaml: 6.0.2
    - requests: 2.32.4
    - rsa: 4.9.1
    - scikit-learn: 1.7.0
    - scipy: 1.15.3
    - setuptools: 80.9.0
    - six: 1.17.0
    - smmap: 5.0.2
    - sniffio: 1.3.1
    - sqlalchemy: 2.0.41
    - sqlparse: 0.5.3
    - starlette: 0.46.2
    - sympy: 1.14.0
    - tabulate: 0.9.0
    - threadpoolctl: 3.6.0
    - tomli: 2.2.1
    - torch: 2.7.1
    - torchmetrics: 1.7.3
    - tqdm: 4.67.1
    - triton: 3.3.1
    - typeguard: 4.3.0
    - typing-extensions: 4.14.0
    - typing-inspection: 0.4.1
    - tzdata: 2025.2
    - urllib3: 2.5.0
    - uvicorn: 0.34.3
    - werkzeug: 3.1.3
    - wheel: 0.45.1
    - yarl: 1.20.1
    - zipp: 3.23.0
  • System:
    - OS: Linux
    - architecture:
      - 64bit
      - ELF
    - processor: x86_64
    - python: 3.10.14
    - release: 6.8.0-59-generic
    - version: #61~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Apr 15 17:03:15 UTC 2

More info

No response

cc @lantiga
