271
271
"""
272
272
273
273
import collections
274
- from typing import TYPE_CHECKING , Dict , List , Optional , Union
274
+ from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , Union
275
275
276
276
import numpy as np
277
277
293
293
294
294
# Gym is an optional dependency.
295
295
try :
296
- import gym
296
+ import gymnasium as gym
297
297
298
298
DiscreteSpace = gym .spaces .Discrete
299
299
BoxSpace = gym .spaces .Box
@@ -398,7 +398,9 @@ def action_masks(self) -> np.ndarray:
398
398
"""Return boolean mask of valid actions."""
399
399
return np .array ([t .is_valid (self .state ) for t in self .world .transformations ])
400
400
401
- def step (self , action : int ):
401
+ def step (
402
+ self , action : int | str | np .ndarray
403
+ ) -> Tuple [np .ndarray , float , bool , bool , dict ]:
402
404
"""Perform one step in the environment given the index of a wanted transformation.
403
405
404
406
If the selected transformation can be performed, the state is updated and
@@ -407,6 +409,13 @@ def step(self, action: int):
407
409
408
410
"""
409
411
412
+ if isinstance (action , np .ndarray ):
413
+ if not action .size == 1 :
414
+ raise TypeError (
415
+ "Actions should be integers corresponding the a transformation index"
416
+ f", got array with multiple elements:\n { action } ."
417
+ )
418
+ action = action .flatten ()[0 ]
410
419
try :
411
420
action = int (action )
412
421
except (TypeError , ValueError ) as e :
@@ -433,7 +442,13 @@ def step(self, action: int):
433
442
434
443
self .current_score += reward
435
444
self .cumulated_score += reward
436
- return self ._step_output (reward , terminated , truncated )
445
+ return (
446
+ self .state .observation ,
447
+ reward ,
448
+ terminated ,
449
+ truncated ,
450
+ self .infos (),
451
+ )
437
452
438
453
def render (self , mode : Optional [str ] = None , ** _kwargs ) -> Union [str , np .ndarray ]:
439
454
"""Render the observation of the agent in a format depending on `render_mode`."""
@@ -451,7 +466,7 @@ def reset(
451
466
* ,
452
467
seed : Optional [int ] = None ,
453
468
options : Optional [dict ] = None ,
454
- ) -> np .ndarray :
469
+ ) -> Tuple [ np .ndarray ,] :
455
470
"""Resets the state of the environment.
456
471
457
472
Returns:
@@ -472,7 +487,7 @@ def reset(
472
487
473
488
self .state .reset ()
474
489
self .purpose .reset ()
475
- return self .state .observation
490
+ return self .state .observation , self . infos ()
476
491
477
492
def close (self ):
478
493
"""Closes the environment."""
@@ -540,19 +555,14 @@ def planning_problem(self, **kwargs) -> HcraftPlanningProblem:
540
555
"""
541
556
return HcraftPlanningProblem (self .state , self .name , self .purpose , ** kwargs )
542
557
543
- def _step_output (self , reward : float , terminated : bool , truncated : bool ) :
558
+ def infos (self ) -> dict :
544
559
infos = {
545
560
"action_is_legal" : self .action_masks (),
546
561
"score" : self .current_score ,
547
562
"score_average" : self .cumulated_score / self .episodes ,
548
563
}
549
564
infos .update (self ._tasks_infos ())
550
- return (
551
- self .state .observation ,
552
- reward ,
553
- terminated or truncated ,
554
- infos ,
555
- )
565
+ return infos
556
566
557
567
def _tasks_infos (self ):
558
568
infos = {}
0 commit comments