Description
Bug description
Using `ModelCheckpoint` with `every_n_train_steps` set, together with a metric logged at the end of the validation epoch, makes the checkpoint callback save the wrong model.
While the intent is to keep the model with the best validation metric, the checkpoint callback reads the monitored metric before that metric has actually been logged.
For the first epoch, a warning that the metric could not be found is shown (``ModelCheckpoint(monitor='auroc') could not find the monitored key in the returned metrics: ['epoch', 'step']. HINT: Did you call `log('auroc', value)` in the `LightningModule`?``).
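Then, at the next checkpoint, the saved model is not the one that achieved the best metric but the model as it is at the moment the checkpoint is taken: the score comes from the previous validation run, while the weights have already been trained further (i.e. the model before the next validation run).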
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
```python
from tempfile import TemporaryDirectory

import mlflow
import torch
from lightning import Trainer, LightningModule
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger
from torch.utils.data import Dataset, DataLoader


class FakeDataset(Dataset):
    def __init__(self):
        self.data = [torch.randn(3) for _ in range(4)]
        self.labels = [torch.randint(0, 2, (1,)) for _ in range(4)]

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


class SimpleModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
        self.validation_aurocs = [
            1 - torch.tensor(i / 10)
            for i in range(1, 3)  # Simulating 2 validation epochs
        ]
        # make sure the last epoch is the best one
        self.validation_aurocs += [torch.tensor(1.0).to(dtype=torch.float32)]

    def training_step(self, batch, batch_idx):
        out = self.layer(batch[0])
        return torch.nn.functional.binary_cross_entropy_with_logits(out, batch[1].float())

    def validation_step(self, batch, batch_idx):
        return

    def on_validation_epoch_end(self) -> None:
        self.log(
            "auroc",
            self.validation_aurocs[self.current_epoch],
            logger=True,
        )

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


def create_and_set_experiment(name: str) -> mlflow.entities.Experiment:
    existing_experiment = mlflow.get_experiment_by_name(name)
    if existing_experiment is None:
        mlflow.create_experiment(name)
    return mlflow.set_experiment(experiment_name=name)


mlflow.set_tracking_uri("/storage/ml/mlruns")
mlflow_experiment = create_and_set_experiment(
    name="lightning_checkpoint_issue"
)

with TemporaryDirectory() as tmpdir:
    with mlflow.start_run(experiment_id=mlflow_experiment.experiment_id, run_name="log_issue") as active_run:
        dataset = FakeDataset()
        train_dataloader = DataLoader(dataset, batch_size=2)
        val_dataloader = DataLoader(dataset, batch_size=2)
        model = SimpleModule()
        trainer = Trainer(
            max_epochs=3,
            accelerator="gpu",
            devices=1,
            logger=MLFlowLogger(
                run_id=active_run.info.run_id,
                tracking_uri="/storage/ml/mlruns",
                log_model=True,
            ),
            callbacks=[
                ModelCheckpoint(
                    save_top_k=1,
                    monitor="auroc",
                    dirpath=tmpdir,
                    mode="max",
                    save_last=False,
                    every_n_train_steps=2,
                    train_time_interval=None,
                    every_n_epochs=0,
                    save_on_train_epoch_end=False,
                    save_weights_only=True,
                )
            ],
            check_val_every_n_epoch=1,
            log_every_n_steps=1,
            deterministic="warn",
            num_sanity_val_steps=0,
        )
        trainer.fit(model, train_dataloader, val_dataloader)
```
In this example, the logged model should be `epoch=2-step=6`, but it is `epoch=1-step=4`.
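To make the ordering concrete, here is a small Lightning-free simulation of the repro's timeline (4 samples, `batch_size=2`, `max_epochs=3`, `every_n_train_steps=2`, AUROC values 0.9, 0.8 and 1.0). It is only a sketch of the ordering described above, assuming the step-based checkpoint fires right after a training batch and before that epoch's validation logs `auroc`; it does not reproduce Lightning's actual `ModelCheckpoint` code. The `simulate` function and its `check_after_validation` flag are hypothetical names used for illustration.

```python
from __future__ import annotations

# Lightning-free simulation of the repro's timeline.
# Assumption: this models only the hook ordering described in the report
# (step-based checkpoint at the end of a training batch, validation afterwards);
# it is NOT Lightning's implementation. `simulate` is a hypothetical helper.


def simulate(check_after_validation: bool) -> tuple[str | None, list[str]]:
    num_samples, batch_size, max_epochs = 4, 2, 3
    every_n_train_steps = 2
    aurocs = [0.9, 0.8, 1.0]  # value logged in on_validation_epoch_end, per epoch
    steps_per_epoch = num_samples // batch_size  # 2 training batches per epoch

    logged: dict[str, float] = {}  # metrics visible to the checkpoint logic
    best_value: float | None = None
    best_name: str | None = None
    warnings: list[str] = []
    global_step = 0

    def maybe_save(epoch: int) -> None:
        nonlocal best_value, best_name
        if "auroc" not in logged:
            warnings.append(f"step={global_step}: monitored key 'auroc' not found")
            return
        value = logged["auroc"]
        # mode="max", save_top_k=1: keep only the highest score seen so far
        if best_value is None or value > best_value:
            best_value, best_name = value, f"epoch={epoch}-step={global_step}"

    for epoch in range(max_epochs):
        for _ in range(steps_per_epoch):
            global_step += 1
            if not check_after_validation and global_step % every_n_train_steps == 0:
                maybe_save(epoch)  # fires BEFORE this epoch's validation
        logged["auroc"] = aurocs[epoch]  # validation runs after the training batches
        if check_after_validation:
            maybe_save(epoch)  # fires once this epoch's metric is available
    return best_name, warnings


print(simulate(check_after_validation=False))
# → ('epoch=1-step=4', ["step=2: monitored key 'auroc' not found"])
print(simulate(check_after_validation=True))
# → ('epoch=2-step=6', [])
```

With the step-based ordering, the simulation reproduces both the missing-key warning at step 2 and the `epoch=1-step=4` checkpoint, which carries the 0.9 score from epoch 0's validation. Checking after validation instead yields the expected `epoch=2-step=6`. My assumption (not verified here) is that this second ordering corresponds to `ModelCheckpoint(every_n_epochs=1, save_on_train_epoch_end=False)`, which saves at the end of validation rather than on training steps.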
Error messages and logs
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA GeForce RTX 4060 Ti
- NVIDIA GeForce GTX 1080 Ti
- available: True
- version: 12.6
- Lightning:
- lightning: 2.5.0
- lightning-utilities: 0.14.3
- pytorch-lightning: 2.5.1.post0
- torch: 2.7.1
- torchmetrics: 1.7.3
- Packages:
- aiohappyeyeballs: 2.6.1
- aiohttp: 3.12.13
- aiosignal: 1.3.2
- alembic: 1.16.2
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.9.0
- async-timeout: 5.0.1
- attrs: 25.3.0
- autocommand: 2.2.2
- backports.tarfile: 1.2.0
- blinker: 1.9.0
- cachetools: 5.5.2
- certifi: 2025.6.15
- charset-normalizer: 3.4.2
- click: 8.2.1
- cloudpickle: 3.1.1
- contourpy: 1.3.2
- cycler: 0.12.1
- databricks-sdk: 0.57.0
- docker: 7.1.0
- exceptiongroup: 1.3.0
- fastapi: 0.115.13
- filelock: 3.18.0
- flask: 3.1.1
- fonttools: 4.58.4
- frozenlist: 1.7.0
- fsspec: 2025.5.1
- gitdb: 4.0.12
- gitpython: 3.1.44
- google-auth: 2.40.3
- graphene: 3.4.3
- graphql-core: 3.2.6
- graphql-relay: 3.2.0
- greenlet: 3.2.3
- gunicorn: 23.0.0
- h11: 0.16.0
- idna: 3.10
- importlib-metadata: 8.7.0
- inflect: 7.3.1
- itsdangerous: 2.2.0
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jinja2: 3.1.6
- joblib: 1.5.1
- kiwisolver: 1.4.8
- lightning: 2.5.0
- lightning-utilities: 0.14.3
- mako: 1.3.10
- markupsafe: 3.0.2
- matplotlib: 3.10.3
- mlflow: 3.1.0
- mlflow-skinny: 3.1.0
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.5.0
- networkx: 3.4.2
- numpy: 2.2.6
- nvidia-cublas-cu12: 12.6.4.1
- nvidia-cuda-cupti-cu12: 12.6.80
- nvidia-cuda-nvrtc-cu12: 12.6.77
- nvidia-cuda-runtime-cu12: 12.6.77
- nvidia-cudnn-cu12: 9.5.1.17
- nvidia-cufft-cu12: 11.3.0.4
- nvidia-cufile-cu12: 1.11.1.6
- nvidia-curand-cu12: 10.3.7.77
- nvidia-cusolver-cu12: 11.7.1.2
- nvidia-cusparse-cu12: 12.5.4.2
- nvidia-cusparselt-cu12: 0.6.3
- nvidia-nccl-cu12: 2.26.2
- nvidia-nvjitlink-cu12: 12.6.85
- nvidia-nvtx-cu12: 12.6.77
- omegaconf: 2.3.0
- opentelemetry-api: 1.34.1
- opentelemetry-sdk: 1.34.1
- opentelemetry-semantic-conventions: 0.55b1
- packaging: 24.2
- pandas: 2.3.0
- pillow: 11.2.1
- platformdirs: 4.2.2
- propcache: 0.3.2
- protobuf: 6.31.1
- pyarrow: 20.0.0
- pyasn1: 0.6.1
- pyasn1-modules: 0.4.2
- pydantic: 2.11.7
- pydantic-core: 2.33.2
- pyparsing: 3.2.3
- python-dateutil: 2.9.0.post0
- pytorch-lightning: 2.5.1.post0
- pytz: 2025.2
- pyyaml: 6.0.2
- requests: 2.32.4
- rsa: 4.9.1
- scikit-learn: 1.7.0
- scipy: 1.15.3
- setuptools: 80.9.0
- six: 1.17.0
- smmap: 5.0.2
- sniffio: 1.3.1
- sqlalchemy: 2.0.41
- sqlparse: 0.5.3
- starlette: 0.46.2
- sympy: 1.14.0
- tabulate: 0.9.0
- threadpoolctl: 3.6.0
- tomli: 2.2.1
- torch: 2.7.1
- torchmetrics: 1.7.3
- tqdm: 4.67.1
- triton: 3.3.1
- typeguard: 4.3.0
- typing-extensions: 4.14.0
- typing-inspection: 0.4.1
- tzdata: 2025.2
- urllib3: 2.5.0
- uvicorn: 0.34.3
- werkzeug: 3.1.3
- wheel: 0.45.1
- yarl: 1.20.1
- zipp: 3.23.0
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.14
- release: 6.8.0-59-generic
- version: #61~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Apr 15 17:03:15 UTC 2
More info
No response
cc @lantiga