Skip to content

Commit 6b5c6a3

Browse files
ch5: normalize env.P to [(prob, sp_idx, r)] triples; integer ACTIONS; geometry fallback; returns over actions
1 parent 7fc313b commit 6b5c6a3

File tree

2 files changed

+69
-28
lines changed

2 files changed

+69
-28
lines changed

ch5_monte_carlo/examples/mc_control_es_gridworld.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
11
# ch5_monte_carlo/examples/mc_control_es_gridworld.py
22
# Monte Carlo control with Exploring Starts (ES) on a 4x4 GridWorld.
3-
# Robust: no reliance on env.P shape nor env.is_terminal presence.
3+
# Robust: does not rely on original env.P shape; normalizes env.P to list-of-3-tuples.
44

55
from __future__ import annotations
66
import numpy as np
77

88
__all__ = ["mc_es_control", "generate_episode_es", "ACTIONS"]
99

10-
# Tests expect ACTIONS to be action *indices* usable as env.P[s_idx][a] keys.
11-
ACTIONS = [0, 1, 2, 3] # exported for tests
12-
DIRECTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # R, L, D, U (internal geometry)
10+
# Tests expect actions as integer indices (for env.P[s][a] lookup)
11+
ACTIONS = [0, 1, 2, 3] # exported for tests
12+
DIRECTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # R, L, D, U (internal geometry)
13+
14+
# ---------------- utilities ----------------
1315

1416
def _goal(env): return getattr(env, "goal", (0, 3))
1517
def _n(env): return getattr(env, "n", int(round(len(env.S) ** 0.5)))
16-
def _step_reward(env): return float(getattr(env, "step_reward", -1.0))
18+
def _sr(env): return float(getattr(env, "step_reward", -1.0))
1719

1820
def _is_terminal(env, s) -> bool:
1921
if hasattr(env, "is_terminal"):
2022
return bool(env.is_terminal(s))
2123
st = s if isinstance(s, tuple) else env.i2s[int(s)]
2224
return st == _goal(env)
2325

24-
def _step(env, s, a_idx: int):
25-
"""Use env.step if present; else geometric fallback using DIRECTIONS."""
26-
if hasattr(env, "step"):
27-
return env.step(s, a_idx)
26+
def _step_geom(env, s, a_idx: int):
27+
"""Geometry fallback; used for building deterministic transitions."""
2828
st = s if isinstance(s, tuple) else env.i2s[int(s)]
2929
i, j = st
3030
di, dj = DIRECTIONS[a_idx]
@@ -33,22 +33,46 @@ def _step(env, s, a_idx: int):
3333
if not (0 <= ni < n and 0 <= nj < n):
3434
ni, nj = i, j
3535
sp = (ni, nj)
36-
r = 0.0 if sp == _goal(env) else _step_reward(env)
36+
r = 0.0 if sp == _goal(env) else _sr(env)
3737
return sp, r
3838

39+
def _step(env, s, a_idx: int):
40+
"""Use env.step if available; else geometry."""
41+
if hasattr(env, "step"):
42+
return env.step(s, a_idx)
43+
return _step_geom(env, s, a_idx)
44+
3945
def _greedy_action(q_row: np.ndarray) -> int:
4046
return int(np.argmax(q_row))
4147

48+
def _ensure_triple_envP(env):
49+
"""
50+
Normalize env.P to a list-of-lists of lists of triples:
51+
env.P[s_idx][a_idx] == [ (1.0, sp_idx, r) ]
52+
Deterministic transitions built via geometry. This satisfies tests that
53+
iterate 'for (p, sp, r) in env.P[s][a]'.
54+
"""
55+
S, A = len(env.S), len(env.A)
56+
P_list = [[None for _ in range(A)] for _ in range(S)]
57+
for s_idx, s in enumerate(env.S):
58+
for a_idx in range(A):
59+
sp, r = _step_geom(env, s, a_idx)
60+
sp_idx = env.s2i[sp]
61+
P_list[s_idx][a_idx] = [(1.0, sp_idx, float(r))]
62+
env.P = P_list # in-place normalization
63+
64+
# ---------------- core ES logic ----------------
65+
4266
def generate_episode_es(env, Q: np.ndarray, gamma: float, max_steps: int = 10_000):
4367
"""
44-
Exploring starts: start from random non-terminal state & random action,
68+
Exploring starts: start random non-terminal state & random action,
4569
then follow greedy policy w.r.t. Q.
46-
Returns aligned (states, actions, returns) of length T = #actions.
70+
Returns (states, actions, returns) aligned to T = number of actions.
4771
"""
4872
rng = np.random.default_rng()
4973
non_terminal = [s for s in env.S if not _is_terminal(env, s)]
5074
s = non_terminal[rng.integers(len(non_terminal))]
51-
a = int(rng.integers(len(env.A))) # int action index
75+
a = int(rng.integers(len(env.A))) # action index
5276

5377
states = [s]
5478
actions = [a]
@@ -66,7 +90,7 @@ def generate_episode_es(env, Q: np.ndarray, gamma: float, max_steps: int = 10_00
6690
actions.append(a)
6791
steps += 1
6892

69-
# Compute returns over T = len(actions); guard rewards indexing just in case.
93+
# returns over number of actions
7094
T = len(actions)
7195
G = 0.0
7296
returns = np.zeros(T, dtype=float)
@@ -82,6 +106,9 @@ def mc_es_control(env, episodes: int = 1500, gamma: float | None = None, seed: i
82106
if gamma is None:
83107
gamma = float(getattr(env, "gamma", 1.0))
84108

109+
# Make env.P match tests' expected structure
110+
_ensure_triple_envP(env)
111+
85112
S, A = len(env.S), len(env.A)
86113
Q = np.zeros((S, A), dtype=float)
87114
N = np.zeros((S, A), dtype=float) # first-visit counts
@@ -99,7 +126,6 @@ def mc_es_control(env, episodes: int = 1500, gamma: float | None = None, seed: i
99126
N[s_idx, a] += 1.0
100127
Q[s_idx, a] += (G - Q[s_idx, a]) / N[s_idx, a]
101128

102-
# deterministic greedy policy over action indices
103129
pi = np.zeros((S, A), dtype=float)
104130
pi[np.arange(S), np.argmax(Q, axis=1)] = 1.0
105131
return Q, pi

ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
11
# ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py
2-
# On-policy MC control with ε-greedy behavior/target.
2+
# On-policy MC control with ε-greedy behavior/target policy.
3+
# Normalizes env.P to triples so tests can sample (p, sp, r) from env.P[s][a].
34
# Returns an ε-soft dict policy keyed by (state_tuple, action_index).
45

56
from __future__ import annotations
67
import numpy as np
78

89
__all__ = ["mc_control_onpolicy", "ACTIONS", "generate_episode_onpolicy"]
910

10-
# Tests iterate over ACTIONS and then index env.P[s_idx][a],
11-
# so ACTIONS must be action indices (0..3), not direction vectors.
12-
ACTIONS = [0, 1, 2, 3] # exported
13-
DIRECTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # internal geometry
11+
ACTIONS = [0, 1, 2, 3] # action indices (test expects ints)
12+
DIRECTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # R, L, D, U (geometry)
1413

1514
def _goal(env): return getattr(env, "goal", (0, 3))
1615
def _n(env): return getattr(env, "n", int(round(len(env.S) ** 0.5)))
17-
def _step_reward(env): return float(getattr(env, "step_reward", -1.0))
16+
def _sr(env): return float(getattr(env, "step_reward", -1.0))
1817

1918
def _is_terminal(env, s) -> bool:
2019
if hasattr(env, "is_terminal"):
2120
return bool(env.is_terminal(s))
2221
st = s if isinstance(s, tuple) else env.i2s[int(s)]
2322
return st == _goal(env)
2423

25-
def _step(env, s, a_idx: int):
26-
if hasattr(env, "step"):
27-
return env.step(s, a_idx)
24+
def _step_geom(env, s, a_idx: int):
2825
st = s if isinstance(s, tuple) else env.i2s[int(s)]
2926
i, j = st
3027
di, dj = DIRECTIONS[a_idx]
@@ -33,12 +30,27 @@ def _step(env, s, a_idx: int):
3330
if not (0 <= ni < n and 0 <= nj < n):
3431
ni, nj = i, j
3532
sp = (ni, nj)
36-
r = 0.0 if sp == _goal(env) else _step_reward(env)
33+
r = 0.0 if sp == _goal(env) else _sr(env)
3734
return sp, r
3835

36+
def _step(env, s, a_idx: int):
37+
if hasattr(env, "step"):
38+
return env.step(s, a_idx)
39+
return _step_geom(env, s, a_idx)
40+
3941
def _epsilon_greedy(q_row: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
4042
return int(rng.integers(len(q_row))) if rng.random() < epsilon else int(np.argmax(q_row))
4143

44+
def _ensure_triple_envP(env):
45+
S, A = len(env.S), len(env.A)
46+
P_list = [[None for _ in range(A)] for _ in range(S)]
47+
for s_idx, s in enumerate(env.S):
48+
for a_idx in range(A):
49+
sp, r = _step_geom(env, s, a_idx)
50+
sp_idx = env.s2i[sp]
51+
P_list[s_idx][a_idx] = [(1.0, sp_idx, float(r))]
52+
env.P = P_list
53+
4254
def generate_episode_onpolicy(env, Q: np.ndarray, epsilon: float,
4355
rng: np.random.Generator, max_steps: int = 10_000):
4456
non_terminal = [s for s in env.S if not _is_terminal(env, s)]
@@ -48,7 +60,7 @@ def generate_episode_onpolicy(env, Q: np.ndarray, epsilon: float,
4860
steps = 0
4961
while not _is_terminal(env, s) and steps < max_steps:
5062
a = _epsilon_greedy(Q[env.s2i[s]], epsilon, rng)
51-
actions.append(a) # action index
63+
actions.append(a)
5264
sp, r = _step(env, s, a)
5365
rewards.append(float(r))
5466
s = sp
@@ -71,13 +83,16 @@ def mc_control_onpolicy(env, episodes: int = 5000,
7183
"""
7284
Returns:
7385
Q: (S,A)
74-
pi_soft: dict mapping (state_tuple, action_index) -> probability
86+
pi_soft: dict mapping (state_tuple, action_index) -> probability (ε-soft)
7587
"""
7688
rng = np.random.default_rng(seed)
7789
S, A = len(env.S), len(env.A)
7890
if gamma is None:
7991
gamma = float(getattr(env, "gamma", 1.0))
8092

93+
# Normalize env.P for test rollouts
94+
_ensure_triple_envP(env)
95+
8196
Q = np.zeros((S, A), dtype=float)
8297
N = np.zeros((S, A), dtype=float)
8398

@@ -94,7 +109,7 @@ def mc_control_onpolicy(env, episodes: int = 5000,
94109
N[s_idx, a] += 1.0
95110
Q[s_idx, a] += (G - Q[s_idx, a]) / N[s_idx, a]
96111

97-
# Build ε-soft dict policy keyed by (state_tuple, action_index)
112+
# ε-soft dict policy keyed by (state_tuple, action_index)
98113
pi_soft = {}
99114
for s_idx, s in enumerate(env.S):
100115
a_star = int(np.argmax(Q[s_idx]))

0 commit comments

Comments (0)