diff --git a/docs/guide/save_format.rst b/docs/guide/save_format.rst index 8bd9aa8ea..0917a1505 100644 --- a/docs/guide/save_format.rst +++ b/docs/guide/save_format.rst @@ -24,6 +24,11 @@ A zip-archived JSON dump, PyTorch state dictionaries and PyTorch variables. The is stored as a JSON file, model parameters and optimizers are serialized with ``torch.save()`` function and these files are stored under a single .zip archive. +Note that if you use unsafe objects in your torch model, ``torch.load()`` will raise an unpickling error. You can +use the "weights_only" argument to adjust whether or not to load unsafe objects using Pickle, but it will issue +a warning if set to False. (e.g., if the learning rate schedule contains the NumPy scalar ``np.pi``, loading will +raise an error unless the "weights_only" argument is set to False) + Any objects that are not JSON serializable are serialized with cloudpickle and stored as base64-encoded string in the JSON file, along with some information that was stored in the serialization. This allows inspecting stored objects without deserializing the object itself. 
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c1560201c..1d069b3ae 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -44,6 +44,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Log success rate ``rollout/success_rate`` when available for on policy algorithms (@corentinlger) +- Model loading now calls ``torch.load()`` with ``weights_only=True`` by default; this can be overridden using the ``weights_only`` boolean argument in the ``load()`` method in sb3, which will be passed to ``torch.load()`` (@markscsmith) Bug Fixes: ^^^^^^^^^^ @@ -1593,4 +1594,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow @rushitnshah +@marekm4 @stagoverflow @rushitnshah @markscsmith diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index e6c7d3cfc..ad250f5c4 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -575,6 +575,7 @@ def set_parameters( load_path_or_dict: Union[str, TensorDict], exact_match: bool = True, device: Union[th.device, str] = "auto", + weights_only: bool = True, ) -> None: """ Load parameters from a given zip-file or a nested dictionary containing parameters for @@ -587,12 +588,15 @@ module and each of their parameters, otherwise raises an Exception. If set to False, this can be used to update only specific parameters. :param device: Device on which the code should run. + :param weights_only: Set torch weights_only for passthrough into load function. + WARNING: use weights_only=True to avoid possible arbitrary code execution! 
+ See https://pytorch.org/docs/stable/generated/torch.load.html """ params = {} if isinstance(load_path_or_dict, dict): params = load_path_or_dict else: - _, params, _ = load_from_zip_file(load_path_or_dict, device=device) + _, params, _ = load_from_zip_file(load_path_or_dict, device=device, weights_only=weights_only) # Keep track which objects were updated. # `_get_torch_save_params` returns [params, other_pytorch_variables]. @@ -647,6 +651,7 @@ def load( # noqa: C901 custom_objects: Optional[Dict[str, Any]] = None, print_system_info: bool = False, force_reset: bool = True, + weights_only: bool = True, **kwargs, ) -> SelfBaseAlgorithm: """ @@ -670,6 +675,9 @@ def load( # noqa: C901 :param force_reset: Force call to ``reset()`` before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597 + :param weights_only: Set torch weights_only for passthrough into load function. + WARNING: use weights_only=True to avoid possible arbitrary code execution! + See https://pytorch.org/docs/stable/generated/torch.load.html :param kwargs: extra arguments to change the model when loading :return: new model instance with loaded parameters """ @@ -678,10 +686,7 @@ def load( # noqa: C901 get_system_info() data, params, pytorch_variables = load_from_zip_file( - path, - device=device, - custom_objects=custom_objects, - print_system_info=print_system_info, + path, device=device, custom_objects=custom_objects, print_system_info=print_system_info, weights_only=weights_only ) assert data is not None, "No data found in the saved file" diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 2d8652006..2601c358f 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -380,6 +380,7 @@ def load_from_zip_file( device: Union[th.device, str] = "auto", verbose: int = 0, print_system_info: bool = False, + weights_only: bool = True, ) -> Tuple[Optional[Dict[str, Any]], TensorDict, 
Optional[TensorDict]]: """ Load model data from a .zip archive @@ -397,6 +398,9 @@ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages :param print_system_info: Whether to print or not the system info about the saved model. + :param weights_only: Set torch weights_only for passthrough into load function. + WARNING: use weights_only=True to avoid possible arbitrary code execution! + See https://pytorch.org/docs/stable/generated/torch.load.html :return: Class parameters, model state_dicts (aka "params", dict of state_dict) and dict of pytorch variables """ @@ -426,7 +430,10 @@ "The model was saved with SB3 <= 1.2.0 and thus cannot print system information.", UserWarning, ) - + if weights_only is False: + warnings.warn( + "Unpickling unsafe objects! Loading full state_dict. See pytorch docs on torch.load for more info." + ) if "data" in namelist and load_data: # Load class parameters that are stored # with either JSON or pickle (not PyTorch variables). @@ -447,7 +454,7 @@ file_content.seek(0) # Load the parameters with the right ``map_location``. 
# Remove ".pth" ending with splitext - th_object = th.load(file_content, map_location=device, weights_only=True) + th_object = th.load(file_content, map_location=device, weights_only=weights_only) # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138 if file_path == "pytorch_variables.pth" or file_path == "tensors.pth": # PyTorch variables (not state_dicts) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index e7123e984..4ca5ed563 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,8 +1,10 @@ import base64 import io import json +import math import os import pathlib +import pickle import tempfile import warnings import zipfile @@ -739,6 +741,51 @@ def test_load_invalid_object(tmp_path): assert len(record) == 0 +def test_load_torch_weights_only(tmp_path): + # Test loading only the torch weights + path = str(tmp_path / "ppo_pendulum.zip") + model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0) + model.save(path) + # Loading with weights_only=False should emit exactly one safety warning + with warnings.catch_warnings(record=True) as record: + model.load(path, weights_only=False) + assert len(record) == 1 + + # Load only the weights from a valid model + with warnings.catch_warnings(record=True) as record: + model.load(path, weights_only=True) + assert len(record) == 0 + + model = PPO( + policy="MlpPolicy", + env="Pendulum-v1", + learning_rate=math.sin(1), + ) + model.save(path) + with warnings.catch_warnings(record=True) as record: + model.load(path, weights_only=True) + assert len(record) == 0 + + # Causes pickle error due to numpy scalars in the learning rate schedule: + # _pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` + # will likely succeed, but it can result in arbitrary code execution.Do it only if you get the file from a + # trusted source. 
WeightsUnpickler error: Unsupported class numpy.core.multiarray.scalar + + model = PPO( + policy="MlpPolicy", + env="Pendulum-v1", + learning_rate=lambda _: np.sin(1), + ) + model.save(path) + + with pytest.raises(pickle.UnpicklingError) as record: + model.load(path, weights_only=True) + + with warnings.catch_warnings(record=True) as record: + model.load(path, weights_only=False) + assert len(record) == 1 + + def test_dqn_target_update_interval(tmp_path): # `target_update_interval` should not change when reloading the model. See GH Issue #1373. env = make_vec_env(env_id="CartPole-v1", n_envs=2)