Skip to content

Commit

Permalink
[Feature] Gymnasium 1.0 compatibility
Browse files Browse the repository at this point in the history
ghstack-source-id: 305b15271bd5b0cd7d9b55a9f8b2079bbc40950f
Pull Request resolved: #2473
  • Loading branch information
vmoens committed Oct 9, 2024
1 parent fac1f7a commit 08c00b1
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 28 deletions.
4 changes: 2 additions & 2 deletions .github/unittest/linux_libs/scripts_gym/batch_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ do
done

# For this version "gym[accept-rom-license]" is required.
for GYM_VERSION in '0.27' '0.28'
for GYM_VERSION in '0.27' '0.28' '0.29'
do
# Create a copy of the conda env and work with this
conda deactivate
Expand All @@ -140,7 +140,7 @@ conda deactivate
conda create --prefix ./cloned_env --clone ./env -y
conda activate ./cloned_env

pip3 install 'gymnasium[accept-rom-license,ale-py,atari]<1.0' mo-gymnasium gymnasium-robotics -U
pip3 install 'gymnasium[accept-rom-license,ale-py,atari]' mo-gymnasium gymnasium-robotics -U

$DIR/run_test.sh

Expand Down
95 changes: 70 additions & 25 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,9 +1622,9 @@ def register_gym(
nondeterministic: bool = False,
max_episode_steps: int | None = None,
order_enforce: bool = True,
autoreset: bool = False,
autoreset: bool | None = None,
disable_env_checker: bool = False,
apply_api_compatibility: bool = False,
apply_api_compatibility: bool | None = None,
**kwargs,
):
"""Registers an environment in gym(nasium).
Expand Down Expand Up @@ -1811,9 +1811,9 @@ def _register_gym(
nondeterministic: bool = False,
max_episode_steps: int | None = None,
order_enforce: bool = True,
autoreset: bool = False,
autoreset: bool | None = None,
disable_env_checker: bool = False,
apply_api_compatibility: bool = False,
apply_api_compatibility: bool | None = None,
**kwargs,
):
import gym
Expand All @@ -1836,9 +1836,9 @@ def _register_gym(
nondeterministic=nondeterministic,
max_episode_steps=max_episode_steps,
order_enforce=order_enforce,
autoreset=autoreset,
autoreset=bool(autoreset),
disable_env_checker=disable_env_checker,
apply_api_compatibility=apply_api_compatibility,
apply_api_compatibility=bool(apply_api_compatibility),
)

@implement_for("gym", "0.25", "0.26", class_method=True)
Expand All @@ -1853,14 +1853,14 @@ def _register_gym( # noqa: F811
nondeterministic: bool = False,
max_episode_steps: int | None = None,
order_enforce: bool = True,
autoreset: bool = False,
autoreset: bool | None = None,
disable_env_checker: bool = False,
apply_api_compatibility: bool = False,
apply_api_compatibility: bool | None = None,
**kwargs,
):
import gym

if apply_api_compatibility is not False:
if apply_api_compatibility is not None:
raise TypeError(
cls._GYM_UNRECOGNIZED_KWARG.format(
"apply_api_compatibility", gym.__version__
Expand Down Expand Up @@ -1901,14 +1901,14 @@ def _register_gym( # noqa: F811
nondeterministic: bool = False,
max_episode_steps: int | None = None,
order_enforce: bool = True,
autoreset: bool = False,
autoreset: bool | None = None,
disable_env_checker: bool = False,
apply_api_compatibility: bool = False,
apply_api_compatibility: bool | None = None,
**kwargs,
):
import gym

if apply_api_compatibility is not False:
if apply_api_compatibility is not None:
raise TypeError(
cls._GYM_UNRECOGNIZED_KWARG.format(
"apply_api_compatibility", gym.__version__
Expand Down Expand Up @@ -1954,14 +1954,14 @@ def _register_gym( # noqa: F811
nondeterministic: bool = False,
max_episode_steps: int | None = None,
order_enforce: bool = True,
autoreset: bool = False,
autoreset: bool | None = None,
disable_env_checker: bool = False,
apply_api_compatibility: bool = False,
apply_api_compatibility: bool | None = None,
**kwargs,
):
import gym

if apply_api_compatibility is not False:
if apply_api_compatibility is not None:
raise TypeError(
cls._GYM_UNRECOGNIZED_KWARG.format(
"apply_api_compatibility", gym.__version__
Expand All @@ -1973,7 +1973,7 @@ def _register_gym( # noqa: F811
"disable_env_checker", gym.__version__
)
)
if autoreset is not False:
if autoreset is not None:
raise TypeError(
cls._GYM_UNRECOGNIZED_KWARG.format("autoreset", gym.__version__)
)
Expand Down Expand Up @@ -2010,9 +2010,9 @@ def _register_gym( # noqa: F811
nondeterministic: bool = False,
max_episode_steps: int | None = None,
order_enforce: bool = True,
autoreset: bool = False,
autoreset: bool | None = None,
disable_env_checker: bool = False,
apply_api_compatibility: bool = False,
apply_api_compatibility: bool | None = None,
**kwargs,
):
import gym
Expand All @@ -2028,11 +2028,11 @@ def _register_gym( # noqa: F811
"disable_env_checker", gym.__version__
)
)
if autoreset is not False:
if autoreset is not None:
raise TypeError(
cls._GYM_UNRECOGNIZED_KWARG.format("autoreset", gym.__version__)
)
if apply_api_compatibility is not False:
if apply_api_compatibility is not None:
raise TypeError(
cls._GYM_UNRECOGNIZED_KWARG.format(
"apply_api_compatibility", gym.__version__
Expand All @@ -2056,7 +2056,7 @@ def _register_gym( # noqa: F811
max_episode_steps=max_episode_steps,
)

@implement_for("gymnasium", class_method=True)
@implement_for("gymnasium", None, "1.0", class_method=True)
def _register_gym( # noqa: F811
cls,
id,
Expand All @@ -2068,9 +2068,9 @@ def _register_gym( # noqa: F811
nondeterministic: bool = False,
max_episode_steps: int | None = None,
order_enforce: bool = True,
autoreset: bool = False,
autoreset: bool | None = None,
disable_env_checker: bool = False,
apply_api_compatibility: bool = False,
apply_api_compatibility: bool | None = None,
**kwargs,
):
import gymnasium
Expand All @@ -2094,9 +2094,54 @@ def _register_gym( # noqa: F811
nondeterministic=nondeterministic,
max_episode_steps=max_episode_steps,
order_enforce=order_enforce,
autoreset=autoreset,
autoreset=bool(autoreset),
disable_env_checker=disable_env_checker,
apply_api_compatibility=bool(apply_api_compatibility),
)

@implement_for("gymnasium", "1.0", class_method=True)
def _register_gym( # noqa: F811
cls,
id,
entry_point: Callable | None = None,
transform: "Transform" | None = None, # noqa: F821
info_keys: List[NestedKey] | None = None,
to_numpy: bool = False,
reward_threshold: float | None = None,
nondeterministic: bool = False,
max_episode_steps: int | None = None,
order_enforce: bool = True,
autoreset: bool | None = False,
disable_env_checker: bool = False,
apply_api_compatibility: bool | None = None,
**kwargs,
):
import gymnasium
from torchrl.envs.libs._gym_utils import _TorchRLGymnasiumWrapper

if entry_point is None:
entry_point = cls

entry_point = functools.partial(
_TorchRLGymnasiumWrapper,
entry_point=entry_point,
info_keys=info_keys,
to_numpy=to_numpy,
transform=transform,
**kwargs,
)
if autoreset is not None:
raise TypeError("autoreset is only compatible with gymnasium<1.0.")
if apply_api_compatibility is not None:
raise TypeError("apply_api_compatibility is only compatible with gymnasium<1.0.")
return gymnasium.register(
id=id,
entry_point=entry_point,
reward_threshold=reward_threshold,
nondeterministic=nondeterministic,
max_episode_steps=max_episode_steps,
order_enforce=order_enforce,
disable_env_checker=disable_env_checker,
apply_api_compatibility=apply_api_compatibility,
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
Expand Down
6 changes: 5 additions & 1 deletion torchrl/envs/libs/_gym_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ def step(self, action): # noqa: F811
return out

@implement_for("gymnasium")
def reset(self): # noqa: F811
def reset(self, seed=None, options=None): # noqa: F811
if seed is not None:
self.torchrl_env.set_seed(seed)
if options is not None:
raise TypeError("options is not supported in torchrl envs.")
self._tensordict = self.torchrl_env.reset()
observation = self._tensordict
if self.info_keys:
Expand Down

0 comments on commit 08c00b1

Please sign in to comment.