From 93831ba38dfad2b72aac7608fe80f2f84ec01d6e Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 19 Feb 2023 15:57:30 +0900 Subject: [PATCH] Show version mismatch warning --- d3rlpy/base.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/d3rlpy/base.py b/d3rlpy/base.py index 16169f37..21f85718 100644 --- a/d3rlpy/base.py +++ b/d3rlpy/base.py @@ -8,6 +8,7 @@ import torch from gym.spaces import Discrete +from ._version import __version__ from .constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace from .dataset import DatasetInfo, ReplayBuffer, Shape from .logger import LOG, D3RLPyLogger @@ -144,13 +145,23 @@ def dump_learnable(algo: "LearnableBase", fname: str) -> None: action_size=algo.impl.action_size, config=algo.config, ) - obj = {"torch": torch_bytes.getvalue(), "config": config.serialize()} + obj = { + "torch": torch_bytes.getvalue(), + "config": config.serialize(), + "version": __version__, + } pickle.dump(obj, f) def load_learnable(fname: str, device: DeviceArg = None) -> "LearnableBase": with open(fname, "rb") as f: obj = pickle.load(f) + if obj["version"] != __version__: + LOG.warning( + "There might be incompatibility because of version mismatch.", + current_version=__version__, + saved_version=obj["version"], + ) config = LearnableConfigWithShape.deserialize(obj["config"]) algo = config.create(device) assert algo.impl