Skip to content

Commit d8707dc

Browse files
⬆️ Update to gymnasium
Refactor to centralize gym tests
1 parent c22c082 commit d8707dc

15 files changed

+233
-209
lines changed

pyproject.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ keywords = [
2323

2424

2525
[project.optional-dependencies]
26-
gym = ["gym >= 0.26"]
26+
gym = [
27+
"gymnasium>=1.0.0",
28+
]
2729
gui = ["pygame >= 2.1.0", "pygame-menu >= 4.3.8"]
2830
planning = ["unified_planning[aries,enhsp] >= 1.1.0", "up-enhsp>=0.0.25"]
2931
htmlvis = ["pyvis<=0.3.1"]

src/hcraft/env.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@
271271
"""
272272

273273
import collections
274-
from typing import TYPE_CHECKING, Dict, List, Optional, Union
274+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
275275

276276
import numpy as np
277277

@@ -293,7 +293,7 @@
293293

294294
# Gym is an optional dependency.
295295
try:
296-
import gym
296+
import gymnasium as gym
297297

298298
DiscreteSpace = gym.spaces.Discrete
299299
BoxSpace = gym.spaces.Box
@@ -398,7 +398,9 @@ def action_masks(self) -> np.ndarray:
398398
"""Return boolean mask of valid actions."""
399399
return np.array([t.is_valid(self.state) for t in self.world.transformations])
400400

401-
def step(self, action: int):
401+
def step(
402+
self, action: int | str | np.ndarray
403+
) -> Tuple[np.ndarray, float, bool, bool, dict]:
402404
"""Perform one step in the environment given the index of a wanted transformation.
403405
404406
If the selected transformation can be performed, the state is updated and
@@ -407,6 +409,13 @@ def step(self, action: int):
407409
408410
"""
409411

412+
if isinstance(action, np.ndarray):
413+
if not action.size == 1:
414+
raise TypeError(
415+
"Actions should be integers corresponding the a transformation index"
416+
f", got array with multiple elements:\n{action}."
417+
)
418+
action = action.flatten()[0]
410419
try:
411420
action = int(action)
412421
except (TypeError, ValueError) as e:
@@ -433,7 +442,13 @@ def step(self, action: int):
433442

434443
self.current_score += reward
435444
self.cumulated_score += reward
436-
return self._step_output(reward, terminated, truncated)
445+
return (
446+
self.state.observation,
447+
reward,
448+
terminated,
449+
truncated,
450+
self.infos(),
451+
)
437452

438453
def render(self, mode: Optional[str] = None, **_kwargs) -> Union[str, np.ndarray]:
439454
"""Render the observation of the agent in a format depending on `render_mode`."""
@@ -451,7 +466,7 @@ def reset(
451466
*,
452467
seed: Optional[int] = None,
453468
options: Optional[dict] = None,
454-
) -> np.ndarray:
469+
) -> Tuple[np.ndarray,]:
455470
"""Resets the state of the environment.
456471
457472
Returns:
@@ -472,7 +487,7 @@ def reset(
472487

473488
self.state.reset()
474489
self.purpose.reset()
475-
return self.state.observation
490+
return self.state.observation, self.infos()
476491

477492
def close(self):
478493
"""Closes the environment."""
@@ -540,19 +555,14 @@ def planning_problem(self, **kwargs) -> HcraftPlanningProblem:
540555
"""
541556
return HcraftPlanningProblem(self.state, self.name, self.purpose, **kwargs)
542557

543-
def _step_output(self, reward: float, terminated: bool, truncated: bool):
558+
def infos(self) -> dict:
544559
infos = {
545560
"action_is_legal": self.action_masks(),
546561
"score": self.current_score,
547562
"score_average": self.cumulated_score / self.episodes,
548563
}
549564
infos.update(self._tasks_infos())
550-
return (
551-
self.state.observation,
552-
reward,
553-
terminated or truncated,
554-
infos,
555-
)
565+
return infos
556566

557567
def _tasks_infos(self):
558568
infos = {}

src/hcraft/examples/light_recursive.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
# gym is an optional dependency
3131
try:
32-
import gym
32+
import gymnasium as gym
3333

3434
gym.register(
3535
id="LightRecursiveHcraft-v1",
@@ -41,7 +41,6 @@
4141

4242

4343
class LightRecursiveHcraftEnv(HcraftEnv):
44-
4544
"""LightRecursive environment."""
4645

4746
def __init__(self, n_items: int = 6, n_required_previous: int = 2, **kwargs):

src/hcraft/examples/minecraft/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
.. include:: ../../../../docs/images/requirements_graphs/MineHcraft.html
88
</div>
99
"""
10+
1011
from typing import Optional
1112

1213
import hcraft.examples.minecraft.items as items
@@ -21,7 +22,7 @@
2122

2223
# gym is an optional dependency
2324
try:
24-
import gym
25+
import gymnasium as gym
2526

2627
ENV_PATH = "hcraft.examples.minecraft.env:MineHcraftEnv"
2728

src/hcraft/examples/minicraft/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""# MiniHCraft environments
22
3-
List of environments representing abstractions from
3+
List of environments representing abstractions from
44
[minigrid environments](https://minigrid.farama.org/environments/minigrid/).
55
66
See submodules for each individual environment:
@@ -77,7 +77,7 @@
7777
MINICRAFT_GYM_ENVS = []
7878

7979
try:
80-
import gym
80+
import gymnasium as gym
8181

8282
ENV_PATH = "hcraft.examples.minicraft"
8383

src/hcraft/examples/random_simple/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# gym is an optional dependency
66
try:
7-
import gym
7+
import gymnasium as gym
88

99
gym.register(
1010
id="RandomHcraft-v1",

src/hcraft/examples/recursive.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" # Recursive HierarchyCraft Environments
1+
"""# Recursive HierarchyCraft Environments
22
33
The goal of the environment is to get the last item.
44
But each item requires all the previous items,
@@ -29,7 +29,7 @@
2929

3030
# gym is an optional dependency
3131
try:
32-
import gym
32+
import gymnasium as gym
3333

3434
gym.register(
3535
id="RecursiveHcraft-v1",
@@ -41,7 +41,6 @@
4141

4242

4343
class RecursiveHcraftEnv(HcraftEnv):
44-
4544
"""RecursiveHcraft Environment"""
4645

4746
def __init__(self, n_items: int = 6, **kwargs):

src/hcraft/examples/tower.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from hcraft.task import GetItemTask
4242

4343
try:
44-
import gym
44+
import gymnasium as gym
4545

4646
gym.register(
4747
id="TowerHcraft-v1",
@@ -53,7 +53,6 @@
5353

5454

5555
class TowerHcraftEnv(HcraftEnv):
56-
5756
"""Tower, a tower-structured hierarchical Environment.
5857
5958
Item of given layer requires all items of the previous.

src/hcraft/examples/treasure/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
# gym is an optional dependency
1616
try:
17-
import gym
17+
import gymnasium as gym
1818

1919
gym.register(
2020
id="Treasure-v1",

tests/examples/minecraft/test_gym_make.py

-134
This file was deleted.

tests/examples/test_gym_make.py

Whitespace-only changes.

tests/examples/test_random.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gymnasium
12
import pytest
23
import pytest_check as check
34

@@ -6,7 +7,6 @@
67

78

89
class TestRandomHcraft:
9-
1010
"""Test the RandomHcraft environment"""
1111

1212
@pytest.fixture(autouse=True)
@@ -15,16 +15,6 @@ def setup_method(self):
1515
self.n_items_per_n_inputs = {0: 1, 1: 5, 2: 10, 4: 1}
1616
self.n_items = sum(self.n_items_per_n_inputs.values())
1717

18-
def test_gym_make(self):
19-
gym = pytest.importorskip("gym")
20-
env: RandomHcraftEnv = gym.make(
21-
"RandomHcraft-v1",
22-
n_items_per_n_inputs=self.n_items_per_n_inputs,
23-
seed=42,
24-
)
25-
check.equal(len(env.world.items), self.n_items)
26-
check.equal(env.seed, 42)
27-
2818
def test_same_seed_same_requirements_graph(self):
2919
env = RandomHcraftEnv(self.n_items_per_n_inputs, seed=42)
3020
env2 = RandomHcraftEnv(self.n_items_per_n_inputs, seed=42)

0 commit comments

Comments
 (0)