@@ -871,18 +871,27 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
     hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
     assert hparams_path.is_file()
     hparams = yaml.safe_load(hparams_path.read_text())
-    expected = {
-        "_instantiator": "lightning.pytorch.cli.instantiate_module",
-        "optimizer": "torch.optim.Adam",
-        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
-        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
-    }
-    assert hparams == expected
+
+    expected_keys = ["_instantiator", "activation", "optimizer", "scheduler"]
+    expected_instantiator = "lightning.pytorch.cli.instantiate_module"
+    expected_activation = "torch.nn.LeakyReLU"
+    expected_optimizer = "torch.optim.Adam"
+    expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"
+
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
 
     checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
     assert checkpoint_path.is_file()
-    ckpt = torch.load(checkpoint_path, weights_only=True)
-    assert ckpt["hyper_parameters"] == expected
+    hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
 
     model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path)
     assert isinstance(model, TestModelSaveHparams)
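The per-key assertions above deliberately accept the optimizer and scheduler entries in either of two shapes: a bare class-path string or a {"class_path": ..., "init_args": ...} dict. A minimal normalizing helper, sketched under that same assumption (class_path_of is a hypothetical name, not part of the test suite):

    def class_path_of(entry):
        # Return the class path whether the entry is a plain string
        # or a {"class_path": ..., "init_args": ...} dict.
        return entry if isinstance(entry, str) else entry["class_path"]

    assert class_path_of(hparams["optimizer"]) == "torch.optim.Adam"
    assert class_path_of(hparams["scheduler"]) == "torch.optim.lr_scheduler.ConstantLR"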
@@ -898,18 +907,23 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(cleandir):
     cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
     cli.trainer.fit(cli.model)
 
-    expected = {
-        "_instantiator": "lightning.pytorch.cli.instantiate_module",
-        "_class_path": f"{__name__}.TestModelSaveHparams",
-        "optimizer": "torch.optim.Adam",
-        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
-        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
-    }
+    expected_keys = ["_class_path", "_instantiator", "activation", "optimizer", "scheduler"]
+    expected_instantiator = "lightning.pytorch.cli.instantiate_module"
+    expected_class_path = f"{__name__}.TestModelSaveHparams"
+    expected_activation = "torch.nn.LeakyReLU"
+    expected_optimizer = "torch.optim.Adam"
+    expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"
 
     checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
     assert checkpoint_path.is_file()
-    ckpt = torch.load(checkpoint_path, weights_only=True)
-    assert ckpt["hyper_parameters"] == expected
+    hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
+
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["_class_path"] == expected_class_path
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
 
     model = LightningModule.load_from_checkpoint(checkpoint_path)
     assert isinstance(model, TestModelSaveHparams)
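In subclass mode, the stored "_class_path" entry is presumably what allows the base-class call LightningModule.load_from_checkpoint(checkpoint_path) to return the concrete TestModelSaveHparams instance. A simplified, self-contained sketch of that kind of class-path resolution (resolve_class is a hypothetical helper, not Lightning's actual implementation):

    import importlib

    def resolve_class(class_path: str):
        # Split "pkg.module.Class" into module path and class name, then import.
        module_name, _, class_name = class_path.rpartition(".")
        return getattr(importlib.import_module(module_name), class_name)

    cls = resolve_class("torch.optim.Adam")
    assert cls.__name__ == "Adam"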