diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py
index 771c3a639..21c124f29 100644
--- a/flair/embeddings/token.py
+++ b/flair/embeddings/token.py
@@ -428,14 +428,14 @@ def to_params(self) -> Dict[str, Any]:
             "embedding_length": self.__embedding_length,
         }
 
-    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+    def state_dict(self, *args, **kwargs):
         # when loading the old versions from pickle, the embeddings might not be added as pytorch module.
         # we do this delayed, when the weights are collected (e.g. for saving), as doing this earlier might
         # lead to issues while loading (trying to load weights that weren't stored as python weights and therefore
         # not finding them)
         if list(self.modules()) == [self]:
             self.embedding = self.embedding
-        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
+        return super().state_dict(*args, **kwargs)
 
 
 @register_embeddings