Error while merging hparams when using LightningCLI and YAML #20182

Open
cgebbe opened this issue Aug 9, 2024 · 5 comments · May be fixed by #20221
Labels: bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.4.x


cgebbe commented Aug 9, 2024

Bug description

The minimal example below throws the error RuntimeError: Error while merging hparams: the keys ['_class_path'] are present in both the LightningModule's and LightningDataModule's hparams but have different values.

I thought this was supposed to work. I would really appreciate tips for a workaround (one that also works with checkpointing) or a fix.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

Run the code below with, e.g., `python main.py --config config.yaml`.

```yaml
# config.yaml
data:
  class_path: MNISTDataModule

model:
  class_path: LitAutoEncoder
```
```python
# main.py
import os
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
from lightning.pytorch import cli
from torch.utils.data import random_split, DataLoader
import torch
from torchvision import transforms


class LitAutoEncoder(L.LightningModule):
    """From https://lightning.ai/docs/pytorch/stable/starter/introduction.html"""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, dim), nn.ReLU(), nn.Linear(dim, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, dim), nn.ReLU(), nn.Linear(dim, 28 * 28)
        )

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)


class MNISTDataModule(L.LightningDataModule):
    """From https://lightning.ai/docs/pytorch/stable/data/datamodule.html"""

    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.save_hyperparameters()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

        if stage == "predict":
            self.mnist_predict = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)


cli.LightningCLI()
```

Error messages and logs

RuntimeError: Error while merging hparams: the keys ['_class_path'] are present in both the LightningModule's and LightningDataModule's hparams but have different values.
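A simplified, pure-Python sketch of the kind of consistency check behind this message may help show why it fires. It assumes, based on the error text and the discussion below, that LightningCLI records each object's class path under a `_class_path` key (plus an `_instantiator` key) in the saved hparams. The helpers `find_conflicting_keys` and `merge_hparams` and all example values are hypothetical illustrations, not Lightning's actual code.

```python
from typing import Any


def find_conflicting_keys(module_hparams: dict[str, Any], datamodule_hparams: dict[str, Any]) -> list[str]:
    """Hypothetical helper: keys present in both dicts whose values differ."""
    shared = module_hparams.keys() & datamodule_hparams.keys()
    return sorted(k for k in shared if module_hparams[k] != datamodule_hparams[k])


def merge_hparams(module_hparams: dict[str, Any], datamodule_hparams: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: merge both dicts, refusing to overwrite differing values."""
    conflicts = find_conflicting_keys(module_hparams, datamodule_hparams)
    if conflicts:
        raise RuntimeError(
            f"Error while merging hparams: the keys {conflicts} are present in both "
            "the LightningModule's and LightningDataModule's hparams but have different values."
        )
    return {**module_hparams, **datamodule_hparams}


# Assumed hparams for the example above. Both objects get an internal _class_path,
# but with different values; _instantiator (assumed value) is identical, so it does not clash.
module_hparams = {
    "_class_path": "__main__.LitAutoEncoder",
    "_instantiator": "lightning.pytorch.cli.instantiate_module",
    "dim": 64,
}
datamodule_hparams = {
    "_class_path": "__main__.MNISTDataModule",
    "_instantiator": "lightning.pytorch.cli.instantiate_module",
    "data_dir": "./",
}

try:
    merge_hparams(module_hparams, datamodule_hparams)
except RuntimeError as err:
    print(err)
```

Running it prints the same message as in the report: only `_class_path` is flagged, because it is the one shared key whose value differs between the two objects.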

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A10G
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.4.0
    - lightning-utilities: 0.11.4
    - pytorch-lightning: 2.3.3
    - torch: 2.3.1
    - torchdata: 0.7.1
    - torchmetrics: 1.0.3
    - torchsummary: 1.5.1
    - torchvision: 0.18.1
  • Packages:
    - absl-py: 2.1.0
    - affine: 2.4.0
    - aiobotocore: 2.13.0
    - aiohttp: 3.9.5
    - aioitertools: 0.11.0
    - aiosignal: 1.3.1
    - albucore: 0.0.12
    - albumentations: 1.4.11
    - altair: 5.3.0
    - annotated-types: 0.7.0
    - ansi2html: 1.9.1
    - ansicolors: 1.1.8
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.4.0
    - appdirs: 1.4.4
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - asciitree: 0.3.3
    - assertpy: 1.1
    - asttokens: 2.4.1
    - astunparse: 1.6.3
    - async-lru: 2.0.4
    - attrs: 23.2.0
    - automation-api-gateway-client: 12.0.165
    - az-annotation-io: 0.18.8
    - az-cp-aws-utils: 1.2.5
    - az-cp-datadictionary-definitions: 0.2024.119.post1
    - az-cp-drclib: 0.1.0
    - az-cp-holoviews-compressed-rgb: 0.0.1rc3
    - az-cp-imagekit-ventana-bif: 0.2.2
    - az-cp-logging: 0.15.0
    - az-cp-ooportal: 1.4.7
    - az-cp-pathviz: 0.0.1rc7
    - az-cp-pita: 1.10.1
    - az-cp-predictino-container: 3.10.3
    - az-drc2polygons: 2.2.2
    - az-git-utils: 0.14.0
    - babel: 2.15.0
    - beautifulsoup4: 4.12.3
    - bleach: 6.1.0
    - blessed: 1.20.0
    - blinker: 1.8.2
    - bokeh: 3.4.1
    - boto3: 1.34.106
    - botocore: 1.34.106
    - bpython: 0.24
    - braceexpand: 0.1.7
    - cachetools: 5.3.3
    - cerberus: 1.3.5
    - certifi: 2024.2.2
    - cffi: 1.16.0
    - cfgv: 3.4.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - click-plugins: 1.1.1
    - cligj: 0.7.2
    - cloudpickle: 3.0.0
    - cmake: 3.25.0
    - colorama: 0.4.6
    - colorcet: 3.1.0
    - comm: 0.2.2
    - contourpy: 1.2.1
    - coreg-tools: 0.3.0
    - cubinlinker-cu11: 0.3.0.post2
    - cuda-python: 11.8.3
    - cudf-cu11: 24.6.1
    - cuml-cu11: 24.6.1
    - cupy-cuda11x: 13.2.0
    - curtsies: 0.4.2
    - cwcwidth: 0.1.9
    - cycler: 0.12.1
    - dash: 2.17.0
    - dash-core-components: 2.0.0
    - dash-html-components: 2.0.0
    - dash-table: 5.0.0
    - dask: 2024.5.1
    - dask-cuda: 24.6.0
    - dask-cudf-cu11: 24.6.1
    - dask-expr: 1.1.1
    - datashader: 0.16.2
    - debugpy: 1.8.1
    - decorator: 5.1.1
    - definiens-autocli: 4.5.3
    - definiens-ia-algorithms: 3.39.1
    - definiens-imagekit: 5.6.6
    - definiens-parallel: 1.3.0
    - defusedxml: 0.7.1
    - deprecated: 1.2.14
    - distlib: 0.3.8
    - distributed: 2024.5.1
    - distributed-ucxx-cu11: 0.38.0
    - dm-tree: 0.1.8
    - dnspython: 2.6.1
    - docker-pycreds: 0.4.0
    - docstring-parser: 0.16
    - ec2-metadata: 2.13.0
    - editor: 1.6.6
    - entrypoints: 0.4
    - eval-type-backport: 0.2.0
    - executing: 2.0.1
    - fancycompleter: 0.9.1
    - fasteners: 0.19
    - fastjsonschema: 2.19.1
    - fastrlock: 0.8.2
    - filelock: 3.14.0
    - fiona: 1.9.6
    - flask: 3.0.3
    - flatbuffers: 24.3.25
    - fonttools: 4.52.4
    - fqdn: 1.5.1
    - frozendict: 2.4.4
    - frozenlist: 1.4.1
    - fsspec: 2024.5.0
    - fvcore: 0.1.5.post20221221
    - gast: 0.6.0
    - geopandas: 1.0.1
    - geopolars: 0.1.0a4
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - google-auth: 2.32.0
    - google-auth-oauthlib: 1.0.0
    - google-pasta: 0.2.0
    - greenlet: 3.0.3
    - grpcio: 1.64.1
    - h11: 0.14.0
    - h5py: 3.11.0
    - holoviews: 1.19.0
    - httpcore: 1.0.5
    - httpx: 0.27.0
    - huggingface-hub: 0.24.0
    - hvplot: 0.10.0
    - identify: 2.5.36
    - idna: 3.7
    - imagecodecs: 2024.1.1
    - imageio: 2.34.1
    - importlib-metadata: 7.1.0
    - iniconfig: 2.0.0
    - inquirer: 3.2.4
    - ioda: 0.19.1
    - ioda-readout-service-client: 0.0.5
    - ioda-result-service-client: 0.6.1
    - iopath: 0.1.10
    - ipp: 2021.4.0
    - ipykernel: 6.29.4
    - ipython: 8.24.0
    - ipytree: 0.2.2
    - ipywidgets: 8.1.3
    - isoduration: 20.11.0
    - itables: 2.1.0
    - itsdangerous: 2.2.0
    - jax: 0.4.30
    - jaxlib: 0.4.30
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - json5: 0.9.25
    - jsonargparse: 4.29.0
    - jsonpointer: 2.4
    - jsonschema: 4.22.0
    - jsonschema-specifications: 2023.12.1
    - jupyter: 1.0.0
    - jupyter-bokeh: 4.0.5
    - jupyter-client: 8.6.2
    - jupyter-console: 6.6.3
    - jupyter-core: 5.7.2
    - jupyter-dash: 0.4.2
    - jupyter-events: 0.10.0
    - jupyter-lsp: 2.2.5
    - jupyter-server: 2.14.1
    - jupyter-server-terminals: 0.5.3
    - jupyterlab: 4.2.1
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.27.2
    - jupyterlab-vim: 4.1.3
    - jupyterlab-widgets: 3.0.11
    - kaleido: 0.2.1
    - keras: 2.12.0
    - kiwisolver: 1.4.5
    - lazy-loader: 0.4
    - leb128: 1.0.7
    - libclang: 18.1.1
    - libucx-cu11: 1.15.0.post1
    - lightning: 2.4.0
    - lightning-utilities: 0.11.4
    - linkify-it-py: 2.0.3
    - lit: 15.0.7
    - litdata: 0.2.16
    - llvmlite: 0.42.0
    - locket: 1.0.0
    - loguru: 0.7.2
    - lxml: 4.9.4
    - markdown: 3.6
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.9.0
    - matplotlib-inline: 0.1.7
    - mdit-py-plugins: 0.4.1
    - mdurl: 0.1.2
    - mistune: 3.0.2
    - ml-dtypes: 0.4.0
    - mpmath: 1.3.0
    - msgpack: 1.0.8
    - multidict: 6.0.5
    - multipledispatch: 1.0.0
    - namex: 0.0.8
    - nbclient: 0.10.0
    - nbconvert: 7.16.4
    - nbformat: 5.10.4
    - nest-asyncio: 1.6.0
    - networkx: 3.3
    - newrelic: 8.11.0
    - nodeenv: 1.9.0
    - notebook: 7.2.0
    - notebook-cgebbe: 0.1.0
    - notebook-shim: 0.2.4
    - nrai-wrapper: 2.19.339839
    - numba: 0.59.1
    - numcodecs: 0.12.1
    - numpy: 1.26.4
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-dali-cuda120: 1.39.0
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvimgcodec-cu12: 0.2.0.7
    - nvidia-nvjitlink-cu12: 12.5.82
    - nvidia-nvtx-cu12: 12.1.105
    - nvtx: 0.2.10
    - oauthlib: 3.2.2
    - ome-types: 0.5.1.post1
    - omegaconf: 2.3.0
    - opencv-python-headless: 4.9.0.80
    - openslide-python: 1.3.1
    - opt-einsum: 3.3.0
    - optree: 0.12.1
    - outcome: 1.3.0.post0
    - overrides: 7.7.0
    - packaging: 24.0
    - pandas: 2.1.4
    - pandocfilters: 1.5.1
    - panel: 1.4.4
    - papermill: 2.6.0
    - param: 2.1.0
    - parso: 0.8.4
    - partd: 1.4.2
    - pdbpp: 0.10.3
    - pexpect: 4.9.0
    - pillow: 10.3.0
    - pip: 24.1.2
    - pivottablejs: 0.9.0
    - platformdirs: 4.2.2
    - plotly: 5.22.0
    - pluggy: 1.5.0
    - polars: 0.20.30
    - portal-test-utils: 8.100.0
    - portalocker: 2.10.1
    - pre-commit: 3.7.1
    - prometheus-client: 0.20.0
    - prompt-toolkit: 3.0.45
    - protobuf: 3.20.0
    - psutil: 5.9.8
    - psycopg2-binary: 2.9.9
    - ptpython: 3.0.27
    - ptxcompiler-cu11: 0.8.1.post1
    - ptyprocess: 0.7.0
    - pudb: 2024.1
    - pure-eval: 0.2.2
    - pyarrow: 16.1.0
    - pyasn1: 0.6.0
    - pyasn1-modules: 0.4.0
    - pycparser: 2.22
    - pyct: 0.5.0
    - pydantic: 2.7.3
    - pydantic-compat: 0.1.2
    - pydantic-core: 2.18.4
    - pydeck: 0.9.1
    - pygments: 2.18.0
    - pylibraft-cu11: 24.6.0
    - pymongo: 4.8.0
    - pynvml: 11.4.1
    - pyogrio: 0.9.0
    - pyotp: 2.9.0
    - pyparsing: 3.1.2
    - pypeln: 0.4.9
    - pyportal: 2.33.2
    - pyproj: 3.6.1
    - pyrepl: 0.9.0
    - pysocks: 1.7.1
    - pytest: 8.2.2
    - python-dateutil: 2.9.0.post0
    - python-dotenv: 1.0.1
    - python-gitlab: 3.15.0
    - python-json-logger: 2.0.7
    - python-magic: 0.4.27
    - python-on-whales: 0.71.0
    - pytorch-lightning: 2.3.3
    - pytz: 2024.1
    - pyviz-comms: 3.0.2
    - pyxdg: 0.28
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - qtconsole: 5.5.2
    - qtpy: 2.4.1
    - raft-dask-cu11: 24.6.0
    - rai: 2.4.1
    - rapids-dask-dependency: 24.6.0
    - rasterio: 1.3.10
    - ray: 2.31.0
    - rdi: 2.19.339839
    - readchar: 4.1.0
    - referencing: 0.35.1
    - requests: 2.32.2
    - requests-oauthlib: 2.0.0
    - requests-toolbelt: 1.0.0
    - retrying: 1.3.4
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.7.1
    - rmm-cu11: 24.6.0
    - rpds-py: 0.18.1
    - rsa: 4.9
    - rtree: 1.2.0
    - ruamel.yaml: 0.18.6
    - ruamel.yaml.clib: 0.2.8
    - ruff: 0.5.1
    - runs: 1.2.2
    - s3cmd: 2.4.0
    - s3fs: 2024.5.0
    - s3transfer: 0.10.1
    - safetensors: 0.4.3
    - scikit-image: 0.23.2
    - scikit-learn: 1.5.1
    - scipy: 1.13.1
    - seaborn: 0.13.2
    - segment-anything: 1.0
    - selenium: 4.22.0
    - semantic-segmentation: 0.1.0
    - send2trash: 1.8.3
    - sentry-sdk: 2.6.0
    - setproctitle: 1.3.3
    - setuptools: 63.4.3
    - shapely: 2.0.4
    - shellingham: 1.5.4
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - snuggs: 1.4.7
    - sortedcontainers: 2.4.0
    - soupsieve: 2.5
    - spatialpandas: 0.4.10
    - sqlalchemy: 2.0.31
    - stack-data: 0.6.3
    - stopit: 1.1.2
    - streamlit: 1.36.0
    - submitit: 1.5.1
    - sympy: 1.12
    - tabulate: 0.9.0
    - tblib: 3.0.0
    - tenacity: 8.3.0
    - tensorboard: 2.12.3
    - tensorboard-data-server: 0.7.2
    - tensorflow: 2.12.1
    - tensorflow-estimator: 2.12.0
    - tensorflow-io: 0.37.1
    - tensorflow-io-gcs-filesystem: 0.37.1
    - termcolor: 2.4.0
    - terminado: 0.18.1
    - threadpoolctl: 3.5.0
    - tifffile: 2024.5.22
    - timm: 1.0.7
    - tinycss2: 1.3.0
    - toml: 0.10.2
    - tomli: 2.0.1
    - toolz: 0.12.1
    - torch: 2.3.1
    - torchdata: 0.7.1
    - torchmetrics: 1.0.3
    - torchsummary: 1.5.1
    - torchvision: 0.18.1
    - tornado: 6.4
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - treelite: 4.1.2
    - trio: 0.25.1
    - trio-websocket: 0.11.1
    - triton: 2.3.1
    - typer: 0.12.3
    - types-python-dateutil: 2.9.0.20240316
    - typing-extensions: 4.12.2
    - typing-utils: 0.1.0
    - tzdata: 2024.1
    - uc-micro-py: 1.0.3
    - ucx-py-cu11: 0.38.0
    - ucxx-cu11: 0.38.0
    - uri-template: 1.3.0
    - urllib3: 2.0.7
    - urwid: 2.6.12
    - urwid-readline: 0.14
    - ventanamripy: 0.3
    - virtualenv: 20.26.2
    - wandb: 0.17.2
    - watchdog: 4.0.1
    - wcwidth: 0.2.13
    - webcolors: 1.13
    - webdataset: 0.2.93
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - werkzeug: 3.0.3
    - wheel: 0.43.0
    - widgetsnbextension: 4.0.11
    - wmctrl: 0.5
    - wrapt: 1.14.1
    - wsproto: 1.2.0
    - xarray: 2024.6.0
    - xformers: 0.0.27
    - xmod: 1.8.1
    - xsdata: 24.3.1
    - xyzservices: 2024.6.0
    - yacs: 0.1.8
    - yarl: 1.9.4
    - zarr: 2.18.2
    - zict: 3.0.0
    - zipp: 3.19.1
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.11.9
    - release: 5.15.0-1066-aws
    - version: #72~20.04.1-Ubuntu SMP Thu Jul 18 10:41:27 UTC 2024

More info

Relevant existing issues

cgebbe added the labels bug and needs triage on Aug 9, 2024
cgebbe (author) commented Aug 9, 2024

FYI: I have worked around it for now by adding the following to my LightningModule. It is a bit brittle, but it seems to work.

```python
from pathlib import Path

import lightning as L
import torch.distributed as dist
import yaml


class LitAutoEncoder(L.LightningModule):
    # ... __init__, training_step, configure_optimizers as before ...

    def setup(self, stage: str):
        self._save_config()

    def _save_config(self):
        # Problem: not all hyperparameters get logged.
        #   - LightningCLI saves config.yaml only locally; it does not reach W&B.
        #   - self.save_hyperparameters() does not work well with YAML configs, see
        #     https://github.com/Lightning-AI/pytorch-lightning/issues/20182
        # Solution: read the locally saved config and store it as hparams.
        #   Disable every other `self.save_hyperparameters()` call.
        # NOTE: this logs the RESOLVED config, i.e. YAML plus CLI arguments.
        if dist.is_initialized() and dist.get_rank() != 0:
            # Only save the config on rank 0.
            return

        if self.trainer.fast_dev_run:
            # In fast_dev_run mode, loggers are replaced by a DummyLogger.
            return

        # Assumes config.yaml is written to the logger's save_dir (the case for e.g. WandbLogger).
        config_yaml_path = Path(self.logger.save_dir) / "config.yaml"
        assert config_yaml_path.exists()
        with open(config_yaml_path) as f:
            dct = yaml.safe_load(f)
        self.save_hyperparameters(dct)
        print("=== Saved config.yaml as hyperparameter")
```
        print("=== Saved config.yaml as hyperparameter")

tshu-w (contributor) commented Aug 16, 2024

cc @mauvilsa @awaelchli

mauvilsa (contributor) commented

I was not aware of this merging of parameters. I guess it always happens when both the model and the data module call `save_hyperparameters`. This should be easy to fix by excluding the special keys `_class_path` and `_instantiator`. It may be better, though, to exclude all keys starting with `_`, since parameter names should not start with an underscore, in the spirit of "Force User Decisions To Best Practices".
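A minimal sketch of the filtering suggested above, assuming the merge step simply ignores every key that starts with `_` before comparing. This is a plain-Python illustration of the idea, not the code in #20221; the function name `merge_public_hparams` and the example values are hypothetical.

```python
from typing import Any


def merge_public_hparams(module_hparams: dict[str, Any], datamodule_hparams: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: merge hparams while ignoring keys that start with '_'."""

    def public(hparams: dict[str, Any]) -> dict[str, Any]:
        # Drop internal bookkeeping keys such as _class_path and _instantiator.
        return {k: v for k, v in hparams.items() if not k.startswith("_")}

    lm, dm = public(module_hparams), public(datamodule_hparams)
    # Only user-facing parameters are checked for conflicting values.
    conflicts = sorted(k for k in lm.keys() & dm.keys() if lm[k] != dm[k])
    if conflicts:
        raise RuntimeError(f"Error while merging hparams: conflicting keys {conflicts}")
    return {**lm, **dm}


merged = merge_public_hparams(
    {"_class_path": "__main__.LitAutoEncoder", "dim": 64},
    {"_class_path": "__main__.MNISTDataModule", "data_dir": "./"},
)
print(merged)  # {'dim': 64, 'data_dir': './'}
```

Whether the internal keys should also be dropped from the logged hparams, or only skipped during the comparison, is a separate design choice; this sketch drops them, while real user parameters that clash (e.g. the same `lr` with different values) still raise.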


DucoG commented Aug 21, 2024

Ran into the same issue.

A simple temporary workaround: in either the LightningModule or the LightningDataModule, call the following instead of the plain `self.save_hyperparameters()`:

```python
self.save_hyperparameters(ignore=['_class_path'])
```

Similar issues: #9492

mauvilsa (contributor) commented

I created pull request #20221 to fix this.

4 participants