diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index 004adae90..d0eb0856c 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -136,6 +136,8 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
         :return:
         """
         if copy:
+            if hasattr(th, "backends") and th.backends.mps.is_built():
+                return th.tensor(array, dtype=th.float32, device=self.device)
             return th.tensor(array, device=self.device)
         return th.as_tensor(array, device=self.device)
 
diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py
index 4d99313ea..d0d48df0c 100644
--- a/stable_baselines3/common/envs/bit_flipping_env.py
+++ b/stable_baselines3/common/envs/bit_flipping_env.py
@@ -81,7 +81,7 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
             state = state.astype(np.int32)
             # The internal state is the binary representation of the
             # observed one
-            return int(sum(state[i] * 2**i for i in range(len(state))))
+            return int(sum(int(state[i]) * 2**i for i in range(len(state))))
 
         if self.image_obs_space:
             size = np.prod(self.image_shape)
diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py
index a5a5476b9..562ff132a 100644
--- a/stable_baselines3/common/utils.py
+++ b/stable_baselines3/common/utils.py
@@ -486,6 +486,8 @@ def obs_as_tensor(obs: Union[np.ndarray, dict[str, np.ndarray]], device: th.devi
     if isinstance(obs, np.ndarray):
         return th.as_tensor(obs, device=device)
     elif isinstance(obs, dict):
+        if hasattr(th, "backends") and th.backends.mps.is_built():
+            return {key: th.as_tensor(_obs, dtype=th.float32, device=device) for (key, _obs) in obs.items()}
         return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
     else:
         raise Exception(f"Unrecognized type of observation {type(obs)}")
@@ -526,6 +528,7 @@ def get_available_accelerator() -> str:
     """
     if hasattr(th, "backends") and th.backends.mps.is_built():
         # MacOS Metal GPU
+        th.set_default_dtype(th.float32)
         return "mps"
     elif th.cuda.is_available():
         return "cuda"
diff --git a/tests/test_spaces.py b/tests/test_spaces.py
index cd38e1ecd..102c0ef8f 100644
--- a/tests/test_spaces.py
+++ b/tests/test_spaces.py
@@ -4,6 +4,7 @@
 import gymnasium as gym
 import numpy as np
 import pytest
+import torch as th
 from gymnasium import spaces
 from gymnasium.spaces.space import Space
 
@@ -151,6 +152,8 @@ def test_discrete_obs_space(model_class, env):
     ],
 )
 def test_float64_action_space(model_class, obs_space, action_space):
+    if hasattr(th, "backends") and th.backends.mps.is_built():
+        pytest.skip("MPS framework doesn't support float64")
     env = DummyEnv(obs_space, action_space)
     env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
     if isinstance(env.observation_space, spaces.Dict):
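
The changes above all apply the same guard: when PyTorch was built with the MPS (Apple Metal) backend, NumPy data is cast to `float32` before being turned into a tensor, since MPS does not support float64 (hence the test skip). Below is a minimal, standalone sketch of that guard outside the patched classes; the helper name `to_float32_tensor` and the device-selection code are illustrative, not part of this patch or of the Stable-Baselines3 API.

```python
import numpy as np
import torch as th


def to_float32_tensor(array: np.ndarray, device: th.device) -> th.Tensor:
    # Same guard as in the patch: only force float32 when PyTorch was built with MPS,
    # because MPS cannot hold float64 tensors.
    if hasattr(th, "backends") and th.backends.mps.is_built():
        return th.tensor(array, dtype=th.float32, device=device)
    # Non-MPS builds keep the original dtype (float64 stays float64 on CPU/CUDA).
    return th.tensor(array, device=device)


if __name__ == "__main__":
    obs = np.zeros(4, dtype=np.float64)  # float64 input, as in the float64 action-space test
    device = th.device("mps") if th.backends.mps.is_available() else th.device("cpu")
    print(to_float32_tensor(obs, device).dtype)  # torch.float32 on MPS builds
```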