Skip to content

Commit e1cf15f

Browse files
Fix ch4 gamma + terminal reward; add MC on-policy impl and exports for ch5
1 parent 80dda58 commit e1cf15f

File tree

2 files changed

+89
-20
lines changed

2 files changed

+89
-20
lines changed

ch4_dynamic_programming/gridworld.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _build_PR(self):
3737
for s_idx, (i, j) in enumerate(self.S):
3838
for a_idx, (di, dj) in enumerate(ACTIONS):
3939
if s_idx == g_idx:
40-
# absorbing terminal
40+
# Absorbing terminal: stay with zero reward
4141
P[s_idx, a_idx, s_idx] = 1.0
4242
R[s_idx, a_idx, s_idx] = 0.0
4343
continue
@@ -48,8 +48,11 @@ def _build_PR(self):
4848

4949
sp_idx = self.s2i[(ni, nj)]
5050
P[s_idx, a_idx, sp_idx] = 1.0
51-
if (ni, nj) == self.goal:
52-
R[s_idx, a_idx, sp_idx] = 0.0 # no penalty entering goal
51+
52+
# KEY CHANGE: entering goal still incurs step cost (-1),
53+
# only staying-in-goal self-loop has 0 reward.
54+
# So do NOT overwrite R[...] = 0.0 here.
55+
# R already has step_reward by default.
5356
return P, R
5457

5558
# -------- environment API (used by ch5 as well) --------
Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,83 @@
1-
# --- add/replace this helper at the top of the file ---
2-
def step(env, s, a):
3-
"""Sample one step given state and action; robust to P formats."""
4-
a = int(a)
5-
s_idx = env.s2i[s] if isinstance(s, tuple) else int(s)
6-
7-
P = env.P
8-
if s_idx in P:
9-
trans = P[s_idx][a]
10-
else:
11-
trans = P[env.i2s[s_idx]][a]
12-
13-
tr = trans[0]
14-
sp = tr.sp
15-
if isinstance(sp, int):
16-
sp = env.i2s[sp]
17-
return sp, tr.r
1+
# ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py
2+
import numpy as np
3+
from ch4_dynamic_programming.gridworld import GridWorld4x4
4+
5+
__all__ = ["mc_control_onpolicy", "ACTIONS", "generate_episode_onpolicy"]
6+
7+
# Action encoding — MUST stay in sync with the ch4 environment's ordering:
# index 0 = right, 1 = left, 2 = down, 3 = up.
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)]
9+
10+
def _epsilon_greedy(Q_row: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
    """Pick an action index for one state.

    With probability ``epsilon`` a uniformly random action is returned
    (exploration); otherwise the argmax of ``Q_row`` (exploitation,
    ties broken by lowest index, as ``np.argmax`` does).
    """
    n_actions = len(Q_row)
    explore = rng.random() < epsilon
    if explore:
        return int(rng.integers(n_actions))
    return int(np.argmax(Q_row))
14+
15+
def generate_episode_onpolicy(env: GridWorld4x4, Q: np.ndarray, epsilon: float,
                              rng: np.random.Generator, max_steps: int = 10_000,
                              gamma: float | None = None):
    """Sample one episode following the ε-greedy policy w.r.t. Q throughout.

    The start state is drawn uniformly from the non-terminal states; the
    behavior policy equals the target policy (on-policy, no exploring starts).

    Parameters
    ----------
    env : environment exposing ``S``, ``s2i``, ``gamma``, ``is_terminal``, ``step``.
    Q : (S, A) action-value table the ε-greedy policy is derived from.
    epsilon : exploration rate.
    rng : NumPy random generator (all randomness flows through it).
    max_steps : safety cap so a bad Q cannot loop forever.
    gamma : discount factor for the computed returns; ``None`` (the default,
        preserving the old behavior) means use ``env.gamma``.

    Returns
    -------
    states : list of the visited non-terminal states, aligned with ``actions``.
    actions : list of action indices taken.
    returns : float ndarray of discounted returns G_t for EVERY time step.
        First-visit filtering is the caller's responsibility.
    """
    non_terminal = [s for s in env.S if not env.is_terminal(s)]
    s = non_terminal[int(rng.integers(len(non_terminal)))]
    if gamma is None:
        gamma = float(env.gamma)

    # rewards[0] is a dummy so that r_{t+1} lives at rewards[t + 1].
    states, actions, rewards = [s], [], [0.0]
    steps = 0
    while not env.is_terminal(s) and steps < max_steps:
        a = _epsilon_greedy(Q[env.s2i[s]], epsilon, rng)
        actions.append(a)
        sp, r = env.step(s, a)
        rewards.append(float(r))
        s = sp
        states.append(s)
        steps += 1

    # Discounted return G_t for every step, computed by backward recursion.
    # NOTE: if max_steps truncated the episode, the tail returns are the
    # truncated sums — acceptable for this tabular example.
    G = 0.0
    returns = np.zeros(len(actions), dtype=float)
    for t in range(len(actions) - 1, -1, -1):
        G = rewards[t + 1] + gamma * G
        returns[t] = G
    return states[:-1], actions, returns
41+
42+
def mc_control_onpolicy(env: GridWorld4x4, episodes: int = 5000,
                        epsilon: float = 0.1, gamma: float | None = None,
                        seed: int | None = None):
    """
    On-policy first-visit Monte Carlo control with an ε-greedy policy
    (behavior policy == target policy, no exploring starts).

    Parameters
    ----------
    episodes : number of episodes to sample.
    epsilon : exploration rate of the ε-greedy policy.
    gamma : discount factor. Returns are discounted with ``env.gamma``
        during episode generation, so this argument must match it
        (``None`` means "use env.gamma").
    seed : seed for the NumPy generator, for reproducibility.

    Returns
    -------
    Q : (S, A) action-value table.
    pi : (S, A) deterministic greedy policy (one-hot rows) derived from Q.

    Raises
    ------
    ValueError : if ``gamma`` is given and differs from ``env.gamma``.
    """
    rng = np.random.default_rng(seed)
    n_states, n_actions = len(env.S), len(env.A)
    # BUG FIX: `gamma` used to be resolved here and then never used — the
    # discounting happens inside generate_episode_onpolicy with env.gamma,
    # so a caller-supplied gamma was silently ignored. Fail loudly instead
    # of producing results discounted with a different factor than asked.
    if gamma is not None and float(gamma) != float(env.gamma):
        raise ValueError(
            f"gamma={gamma} does not match env.gamma={env.gamma}; "
            "returns are discounted with env.gamma"
        )

    Q = np.zeros((n_states, n_actions), dtype=float)
    N = np.zeros((n_states, n_actions), dtype=float)  # first-visit counts

    for _ in range(episodes):
        states, actions, returns = generate_episode_onpolicy(env, Q, epsilon, rng)
        seen = set()
        for t, (s, a) in enumerate(zip(states, actions)):
            s_idx = env.s2i[s]
            key = (s_idx, a)
            if key in seen:
                continue  # first-visit MC: only the first occurrence counts
            seen.add(key)
            N[s_idx, a] += 1.0
            # Incremental sample mean of the observed first-visit returns.
            alpha = 1.0 / N[s_idx, a]
            Q[s_idx, a] += alpha * (returns[t] - Q[s_idx, a])

    # Deterministic greedy policy: one-hot on argmax_a Q(s, a).
    pi = np.zeros((n_states, n_actions), dtype=float)
    pi[np.arange(n_states), np.argmax(Q, axis=1)] = 1.0
    return Q, pi
77+
78+
def _demo() -> None:
    """Smoke-run: train on the 4x4 grid and report values at the start state."""
    world = GridWorld4x4(step_reward=-1.0, goal=(0, 3), gamma=1.0)
    q_table, policy = mc_control_onpolicy(world, episodes=3000, epsilon=0.1, seed=0)
    start_idx = world.s2i[(0, 0)]
    print("Q(start):", q_table[start_idx])
    print("Greedy action at start:", int(np.argmax(policy[start_idx])))


if __name__ == "__main__":
    _demo()

0 commit comments

Comments
 (0)