import numpy as np
from collections import namedtuple

# Actions: Right, Left, Down, Up
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)]
TR = namedtuple("TR", ["p", "sp", "r", "done"])  # transition record


class GridWorld4x4:
    """4x4 GridWorld with step cost on every move, including terminal entry.

    Convention:
    - States are indexed 0..15 in row-major order; i2s maps index -> (i, j).
    - The goal is absorbing: once there, any action keeps you there with reward 0.
    - Each attempted move costs `step_reward` (e.g., -1), including the final
      move into the goal.
    - Bumping into a wall leaves you in place and still incurs the step cost.
    - Transition kernel P is indexed by state index:
      P[s][a] -> [TR(p, sp, r, done)].
    """

    def __init__(self, step_reward: float = -1.0, goal=(0, 3)):
        """Build the grid, the index mappings, and the transition kernel.

        Args:
            step_reward: cost charged for every attempted move (default -1.0).
            goal: (row, col) of the absorbing goal cell; must lie on the grid.

        Raises:
            ValueError: if `goal` is not a cell of the 4x4 grid.
        """
        self.n = 4
        self.goal = tuple(goal)
        self.step_reward = float(step_reward)

        # Enumerate states as (i, j) tuples; provide index mappings.
        self.S = [(i, j) for i in range(self.n) for j in range(self.n)]
        self.s2i = {s: k for k, s in enumerate(self.S)}
        self.i2s = {k: s for s, k in self.s2i.items()}

        # Validate early with a clear error instead of letting an off-grid
        # goal surface as an opaque KeyError from the s2i lookup below.
        if self.goal not in self.s2i:
            raise ValueError(
                f"goal must be a cell of the {self.n}x{self.n} grid, got {goal!r}"
            )

        self.A = list(range(len(ACTIONS)))  # 0..3
        self._goal_idx = self.s2i[self.goal]

        # Transition kernel with integer state indices.
        self.P = {s_idx: {a: [] for a in self.A} for s_idx in range(self.num_states)}
        self._build_P()

    # --- properties expected by tests ---
    @property
    def num_states(self) -> int:
        return len(self.S)

    @property
    def num_actions(self) -> int:
        return len(self.A)

    @property
    def terminal(self) -> int:
        """Index of the terminal (goal) state."""
        return self._goal_idx

    # --- helpers ---
    def _in_bounds(self, i, j) -> bool:
        return 0 <= i < self.n and 0 <= j < self.n

    def _next_state_tuple(self, s_tuple, a):
        """Deterministic successor of `s_tuple` under action `a`; wall bumps stay."""
        di, dj = ACTIONS[a]
        i, j = s_tuple
        ni, nj = i + di, j + dj
        if not self._in_bounds(ni, nj):
            return (i, j)  # bump into wall: stay
        return (ni, nj)

    def _build_P(self):
        """Populate the deterministic transition kernel self.P in place."""
        goal_idx = self._goal_idx

        for s_idx in range(self.num_states):
            s_tuple = self.i2s[s_idx]

            if s_idx == goal_idx:
                # Absorbing terminal: reward 0 thereafter.
                for a in self.A:
                    self.P[s_idx][a] = [TR(1.0, goal_idx, 0.0, True)]
                continue

            for a in self.A:
                next_tuple = self._next_state_tuple(s_tuple, a)
                sp_idx = self.s2i[next_tuple]
                done = (sp_idx == goal_idx)
                # Charge step cost even when entering terminal.
                r = self.step_reward
                self.P[s_idx][a] = [TR(1.0, sp_idx, r, done)]

    # Optional index-based step (used by some examples/tests)
    def step_idx(self, s_idx: int, a: int):
        """Sample one transition from state index `s_idx` under action `a`.

        Returns:
            (next_state_index, reward, done)
        """
        trans = self.P[s_idx][a]
        if len(trans) == 1:
            # Deterministic case: skip the RNG entirely.
            tr = trans[0]
        else:
            probs = [t.p for t in trans]
            idx = np.random.choice(len(trans), p=probs)
            tr = trans[idx]
        return tr.sp, tr.r, tr.done