Dear @KirillShmilovich.

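You could use the LightningModule hooks on_save_checkpoint and on_load_checkpoint. Lightning passes the checkpoint dictionary to both hooks, so you add your own entries to it when saving and read them back when loading: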

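from pytorch_lightning import LightningModule


class MyLightningModel(LightningModule):

    def on_save_checkpoint(self, checkpoint):
        # Lightning passes the checkpoint dict; add entries to it in place (no return value).
        checkpoint["my_object"] = self.my_object

    def on_load_checkpoint(self, checkpoint):
        # Called with the loaded checkpoint dict; restore your entries from it.
        self.my_object = checkpoint["my_object"]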
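To make the in-place contract concrete, here is a minimal stdlib-only sketch that mimics the save/load round trip without Lightning installed. FakeModule, Tokenizer, save and load are hypothetical stand-ins, and the "trainer" logic is a simplification, not Lightning's actual implementation. pickle is used because torch.save pickles the checkpoint dict by default.

```python
import pickle
from typing import Any, Dict


class Tokenizer:
    """Hypothetical non-tensor object we want to persist with the model."""

    def __init__(self, vocab: Dict[str, int]) -> None:
        self.vocab = vocab


class FakeModule:
    """Stand-in for a LightningModule; only the two checkpoint hooks are modeled."""

    def __init__(self) -> None:
        self.my_object = None

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        checkpoint["my_object"] = self.my_object

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.my_object = checkpoint["my_object"]


def save(module: FakeModule) -> bytes:
    # Simplified trainer behavior: build the dict, let the hook mutate it,
    # then serialize the whole dict. Any hook return value would be ignored.
    checkpoint: Dict[str, Any] = {"state_dict": {}}
    module.on_save_checkpoint(checkpoint)
    return pickle.dumps(checkpoint)


def load(module: FakeModule, blob: bytes) -> None:
    # Deserialize, then hand the dict to the hook so it can restore its entries.
    checkpoint = pickle.loads(blob)
    module.on_load_checkpoint(checkpoint)


src = FakeModule()
src.my_object = Tokenizer({"a": 0, "b": 1})
blob = save(src)

dst = FakeModule()
load(dst, blob)
print(dst.my_object.vocab)  # {'a': 0, 'b': 1}
```

Note that the restored object is a new Tokenizer instance rebuilt by pickle, which only works if the Tokenizer class can still be found under the same name at load time.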

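However, pickling arbitrary objects isn't always the best approach: the checkpoint then refers to the object's class by name, so loading can break if that class is later renamed, moved, or changed. A slightly better approach is to save only the object's state dict and rebuild the object from it: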

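class MyLightningModel(LightningModule):

    def on_save_checkpoint(self, checkpoint):
        # Store only the object's state (plain data), not the object itself.
        checkpoint["my_object_state_dict"] = self.my_object.state_dict()

    def on_load_checkpoint(self, checkpoint):
        # my_object_cls is the object's class, assumed to provide state_dict()
        # and a from_state_dict() constructor (your own code, not a Lightning API).
        self.my_object = my_object_cls.from_state_dict(checkpoint["my_object_state_dict"])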
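For illustration, here is a sketch of what such an object might look like. RunningStats is a hypothetical class, and the JSON round trip merely stands in for writing the checkpoint file; the point is that state_dict() returns only builtin types, so the saved payload no longer depends on the class itself.

```python
import json
from typing import Any, Dict


class RunningStats:
    """Hypothetical object tracking a running mean (not part of Lightning)."""

    def __init__(self, count: int = 0, total: float = 0.0) -> None:
        self.count = count
        self.total = total

    def update(self, value: float) -> None:
        self.count += 1
        self.total += value

    @property
    def mean(self) -> float:
        return self.total / self.count if self.count else 0.0

    def state_dict(self) -> Dict[str, Any]:
        # Only builtin types: no reference to this class ends up in the checkpoint.
        return {"count": self.count, "total": self.total}

    @classmethod
    def from_state_dict(cls, state: Dict[str, Any]) -> "RunningStats":
        # Rebuild through the constructor, so the class may change internally
        # as long as it still understands these keys.
        return cls(count=state["count"], total=state["total"])


stats = RunningStats()
for v in (1.0, 2.0, 6.0):
    stats.update(v)

checkpoint = {"my_object_state_dict": stats.state_dict()}
# Plain data even survives a non-pickle format such as JSON.
payload = json.loads(json.dumps(checkpoint))
restored = RunningStats.from_state_dict(payload["my_object_state_dict"])
print(restored.count, restored.mean)  # 3 3.0
```

A classmethod constructor is one option; another is the PyTorch-style pattern of constructing the object first and then calling a load_state_dict method on it, as torch.nn.Module does.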

Answer selected by KirillShmilovich