# ch5_monte_carlo/examples/mc_control_es_gridworld.py
# Monte Carlo control with Exploring Starts (ES) on a 4x4 GridWorld.
# Robust: does not rely on the original env.P shape; normalizes env.P to
# list-of-3-tuples.
44
55from __future__ import annotations
66import numpy as np
77
# Public API of this module (consumed by the chapter's test suite).
__all__ = ["mc_es_control", "generate_episode_es", "ACTIONS"]
99
10- # Tests expect ACTIONS to be action *indices* usable as env.P[s_idx][a] keys.
11- ACTIONS = [0 , 1 , 2 , 3 ] # exported for tests
12- DIRECTIONS = [(0 , 1 ), (0 , - 1 ), (1 , 0 ), (- 1 , 0 )] # R, L, D, U (internal geometry)
10+ # Tests expect actions as integer indices (for env.P[s][a] lookup)
11+ ACTIONS = [0 , 1 , 2 , 3 ] # exported for tests
12+ DIRECTIONS = [(0 , 1 ), (0 , - 1 ), (1 , 0 ), (- 1 , 0 )] # R, L, D, U (internal geometry)
13+
14+ # ---------------- utilities ----------------
1315
1416def _goal (env ): return getattr (env , "goal" , (0 , 3 ))
1517def _n (env ): return getattr (env , "n" , int (round (len (env .S ) ** 0.5 )))
16- def _step_reward (env ): return float (getattr (env , "step_reward" , - 1.0 ))
18+ def _sr (env ): return float (getattr (env , "step_reward" , - 1.0 ))
1719
1820def _is_terminal (env , s ) -> bool :
1921 if hasattr (env , "is_terminal" ):
2022 return bool (env .is_terminal (s ))
2123 st = s if isinstance (s , tuple ) else env .i2s [int (s )]
2224 return st == _goal (env )
2325
def _step_geom(env, s, a_idx: int):
    """Deterministic geometry step used to build transitions.

    Moves from *s* (cell tuple or state index) by DIRECTIONS[a_idx],
    staying in place when the move would leave the n-by-n grid.
    Returns ``(next_cell, reward)``; reward is 0.0 on reaching the goal,
    otherwise the env's step reward.
    """
    cell = s if isinstance(s, tuple) else env.i2s[int(s)]
    i, j = cell
    di, dj = DIRECTIONS[a_idx]
    # NOTE(review): the next two lines were elided by the diff context in the
    # source being reviewed; reconstructed from the bounds check below —
    # confirm against the original file.
    n = _n(env)
    ni, nj = i + di, j + dj
    if not (0 <= ni < n and 0 <= nj < n):
        ni, nj = i, j  # off-grid move: bounce back, stay in place
    sp = (ni, nj)
    r = 0.0 if sp == _goal(env) else _sr(env)
    return sp, r
3838
39+ def _step (env , s , a_idx : int ):
40+ """Use env.step if available; else geometry."""
41+ if hasattr (env , "step" ):
42+ return env .step (s , a_idx )
43+ return _step_geom (env , s , a_idx )
44+
3945def _greedy_action (q_row : np .ndarray ) -> int :
4046 return int (np .argmax (q_row ))
4147
def _ensure_triple_envP(env):
    """
    Rebuild ``env.P`` as a list-of-lists of lists of triples:
        env.P[s_idx][a_idx] == [(1.0, sp_idx, r)]
    Transitions are deterministic, built via geometry, so tests that
    iterate 'for (p, sp, r) in env.P[s][a]' work unchanged.
    """
    n_actions = len(env.A)

    def _row(s):
        # One [(prob, next_index, reward)] list per action; prob always 1.0.
        entries = []
        for a_idx in range(n_actions):
            sp, r = _step_geom(env, s, a_idx)
            entries.append([(1.0, env.s2i[sp], float(r))])
        return entries

    env.P = [_row(s) for s in env.S]  # in-place normalization
63+
# ---------------- core ES logic ----------------
65+
4266def generate_episode_es (env , Q : np .ndarray , gamma : float , max_steps : int = 10_000 ):
4367 """
44- Exploring starts: start from random non-terminal state & random action,
68+ Exploring starts: start random non-terminal state & random action,
4569 then follow greedy policy w.r.t. Q.
46- Returns aligned (states, actions, returns) of length T = # actions.
70+ Returns (states, actions, returns) aligned to T = number of actions.
4771 """
4872 rng = np .random .default_rng ()
4973 non_terminal = [s for s in env .S if not _is_terminal (env , s )]
5074 s = non_terminal [rng .integers (len (non_terminal ))]
51- a = int (rng .integers (len (env .A ))) # int action index
75+ a = int (rng .integers (len (env .A ))) # action index
5276
5377 states = [s ]
5478 actions = [a ]
@@ -66,7 +90,7 @@ def generate_episode_es(env, Q: np.ndarray, gamma: float, max_steps: int = 10_00
6690 actions .append (a )
6791 steps += 1
6892
69- # Compute returns over T = len( actions); guard rewards indexing just in case.
93+ # returns over number of actions
7094 T = len (actions )
7195 G = 0.0
7296 returns = np .zeros (T , dtype = float )
@@ -82,6 +106,9 @@ def mc_es_control(env, episodes: int = 1500, gamma: float | None = None, seed: i
82106 if gamma is None :
83107 gamma = float (getattr (env , "gamma" , 1.0 ))
84108
109+ # Make env.P match tests' expected structure
110+ _ensure_triple_envP (env )
111+
85112 S , A = len (env .S ), len (env .A )
86113 Q = np .zeros ((S , A ), dtype = float )
87114 N = np .zeros ((S , A ), dtype = float ) # first-visit counts
@@ -99,7 +126,6 @@ def mc_es_control(env, episodes: int = 1500, gamma: float | None = None, seed: i
99126 N [s_idx , a ] += 1.0
100127 Q [s_idx , a ] += (G - Q [s_idx , a ]) / N [s_idx , a ]
101128
102- # deterministic greedy policy over action indices
103129 pi = np .zeros ((S , A ), dtype = float )
104130 pi [np .arange (S ), np .argmax (Q , axis = 1 )] = 1.0
105131 return Q , pi