How to ensure objects saved as model attributes are saved in the checkpoint file? #8841
-
Say I have a Lightning model that stores some objects as attributes. Is it possible to somehow ensure that these model attributes get saved in the checkpoint file and properly restored when loading the model from checkpoint? Thanks in advance for your help!
Replies: 1 comment
-
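Dear @KirillShmilovich. You could use the LightningModule `on_save_checkpoint` and `on_load_checkpoint` hooks. Both receive the checkpoint dictionary, which you modify or read in place:

```python
from pytorch_lightning import LightningModule

class MyLightningModel(LightningModule):
    def on_save_checkpoint(self, checkpoint):
        # Add the object to the dict that Lightning writes to the .ckpt file
        checkpoint["my_object"] = self.my_object

    def on_load_checkpoint(self, checkpoint):
        self.my_object = checkpoint["my_object"]
```

However, pickling objects isn't always the best approach. A slightly better approach is to store the object's state dict and rebuild the object from it on load:

```python
class MyLightningModel(LightningModule):
    def on_save_checkpoint(self, checkpoint):
        checkpoint["my_object_state_dict"] = self.my_object.state_dict()

    def on_load_checkpoint(self, checkpoint):
        # my_object_cls is a placeholder for your object's class
        self.my_object = my_object_cls.from_state_dict(checkpoint["my_object_state_dict"])
```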
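The state-dict pattern above can be tried without Lightning installed. The sketch below is a minimal illustration under stated assumptions: `Vocabulary` is a hypothetical stand-in for `my_object_cls`, `save_hook`/`load_hook` are hypothetical functions that mimic what the two hooks do to the checkpoint dict, and `pickle` stands in for `torch.save`/`torch.load`, which pickle the checkpoint dict when writing and reading the `.ckpt` file.

```python
import pickle
from typing import Any, Dict, List


class Vocabulary:
    """Hypothetical helper object kept as a model attribute."""

    def __init__(self, tokens: List[str]) -> None:
        self.tokens = list(tokens)
        self.index = {tok: i for i, tok in enumerate(self.tokens)}

    def state_dict(self) -> Dict[str, Any]:
        # Only built-in types are stored, so unpickling the checkpoint
        # does not require this class to be importable.
        return {"tokens": list(self.tokens)}

    @classmethod
    def from_state_dict(cls, state: Dict[str, Any]) -> "Vocabulary":
        # Rebuild the object (including derived fields like `index`).
        return cls(state["tokens"])


def save_hook(vocab: Vocabulary, checkpoint: Dict[str, Any]) -> None:
    # Mimics on_save_checkpoint: mutate the checkpoint dict in place.
    checkpoint["my_object_state_dict"] = vocab.state_dict()


def load_hook(checkpoint: Dict[str, Any]) -> Vocabulary:
    # Mimics on_load_checkpoint: rebuild the object from the stored state.
    return Vocabulary.from_state_dict(checkpoint["my_object_state_dict"])


checkpoint: Dict[str, Any] = {}
save_hook(Vocabulary(["<pad>", "hello", "world"]), checkpoint)
blob = pickle.dumps(checkpoint)  # stands in for writing the .ckpt file
restored = load_hook(pickle.loads(blob))
print(restored.index["world"])  # prints 2
```

Storing plain data rather than the object itself is one common reason to prefer this pattern: a pickled custom object can fail to load if its class is later renamed, moved, or changed.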