Skip to content

Commit

Permalink
Show version mismatch warning
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Feb 19, 2023
1 parent 932280b commit 93831ba
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion d3rlpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from gym.spaces import Discrete

from ._version import __version__
from .constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from .dataset import DatasetInfo, ReplayBuffer, Shape
from .logger import LOG, D3RLPyLogger
Expand Down Expand Up @@ -144,13 +145,23 @@ def dump_learnable(algo: "LearnableBase", fname: str) -> None:
action_size=algo.impl.action_size,
config=algo.config,
)
obj = {"torch": torch_bytes.getvalue(), "config": config.serialize()}
obj = {
"torch": torch_bytes.getvalue(),
"config": config.serialize(),
"version": __version__,
}
pickle.dump(obj, f)


def load_learnable(fname: str, device: DeviceArg = None) -> "LearnableBase":
with open(fname, "rb") as f:
obj = pickle.load(f)
if obj["version"] != __version__:
LOG.warning(
"There might be incompatibility because of version mismatch.",
current_version=__version__,
saved_version=obj["version"],
)
config = LearnableConfigWithShape.deserialize(obj["config"])
algo = config.create(device)
assert algo.impl
Expand Down

0 comments on commit 93831ba

Please sign in to comment.