Change LightningCLI tests to account for future fix in jsonargparse #20372

Merged · 2 commits · Nov 13, 2024
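The updated assertions accept each class-type hyperparameter in either of two serialized forms: a plain import-path string, or a dict carrying "class_path" (and "init_args"), which the upcoming jsonargparse fix is expected to produce. A minimal sketch of that tolerant check follows; the helper is hypothetical, not part of the test suite:

# Hypothetical helper mirroring the "plain string or class_path dict" checks
# the updated tests apply to the optimizer and scheduler entries.
def matches_class_path(value, expected: str) -> bool:
    """True if `value` names `expected` in either serialized form."""
    if isinstance(value, str):
        return value == expected  # e.g. "torch.optim.Adam"
    if isinstance(value, dict):
        return value.get("class_path") == expected  # e.g. {"class_path": "torch.optim.Adam", ...}
    return False

# Both serializations of the same setting pass:
assert matches_class_path("torch.optim.Adam", "torch.optim.Adam")
assert matches_class_path({"class_path": "torch.optim.Adam", "init_args": {}}, "torch.optim.Adam")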
50 changes: 32 additions & 18 deletions tests/tests_pytorch/test_cli.py
@@ -871,18 +871,27 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
     hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
     assert hparams_path.is_file()
     hparams = yaml.safe_load(hparams_path.read_text())
-    expected = {
-        "_instantiator": "lightning.pytorch.cli.instantiate_module",
-        "optimizer": "torch.optim.Adam",
-        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
-        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
-    }
-    assert hparams == expected
+
+    expected_keys = ["_instantiator", "activation", "optimizer", "scheduler"]
+    expected_instantiator = "lightning.pytorch.cli.instantiate_module"
+    expected_activation = "torch.nn.LeakyReLU"
+    expected_optimizer = "torch.optim.Adam"
+    expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"
+
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler

     checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
     assert checkpoint_path.is_file()
-    ckpt = torch.load(checkpoint_path, weights_only=True)
-    assert ckpt["hyper_parameters"] == expected
+    hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler

     model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path)
     assert isinstance(model, TestModelSaveHparams)
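For reference, this hunk verifies the same expectations against both hparams.yaml and the checkpoint itself. A standalone sketch of that read-back, assuming the default lightning_logs/version_0 logger layout (the path is an illustrative assumption):

# Sketch: read saved hyperparameters back from the log dir and the checkpoint.
from pathlib import Path

import torch
import yaml

log_dir = Path("lightning_logs/version_0")  # assumed default logger layout
if log_dir.exists():
    yaml_hparams = yaml.safe_load((log_dir / "hparams.yaml").read_text())
    ckpt_path = next((log_dir / "checkpoints").glob("*.ckpt"), None)
    # weights_only=True restricts torch.load to tensors and plain containers
    ckpt_hparams = torch.load(ckpt_path, weights_only=True)["hyper_parameters"]
    assert yaml_hparams.keys() == ckpt_hparams.keys()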
@@ -898,18 +907,23 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(cleandir):
     cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
     cli.trainer.fit(cli.model)

-    expected = {
-        "_instantiator": "lightning.pytorch.cli.instantiate_module",
-        "_class_path": f"{__name__}.TestModelSaveHparams",
-        "optimizer": "torch.optim.Adam",
-        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
-        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
-    }
+    expected_keys = ["_class_path", "_instantiator", "activation", "optimizer", "scheduler"]
+    expected_instantiator = "lightning.pytorch.cli.instantiate_module"
+    expected_class_path = f"{__name__}.TestModelSaveHparams"
+    expected_activation = "torch.nn.LeakyReLU"
+    expected_optimizer = "torch.optim.Adam"
+    expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"

     checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
     assert checkpoint_path.is_file()
-    ckpt = torch.load(checkpoint_path, weights_only=True)
-    assert ckpt["hyper_parameters"] == expected
+    hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
+
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["_class_path"] == expected_class_path
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler

     model = LightningModule.load_from_checkpoint(checkpoint_path)
     assert isinstance(model, TestModelSaveHparams)
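In subclass mode the saved hyperparameters additionally carry "_class_path", which is what lets the generic LightningModule.load_from_checkpoint call above restore a concrete TestModelSaveHparams instance rather than the base class. A short sketch of that round trip; the checkpoint path is an illustrative assumption:

# Sketch: with "_instantiator" and "_class_path" stored in hyper_parameters,
# loading through the base class re-instantiates the recorded subclass.
from lightning.pytorch import LightningModule

ckpt = "lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt"  # illustrative
model = LightningModule.load_from_checkpoint(ckpt)
print(type(model).__name__)  # the subclass named by "_class_path"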