
Commit 32bcd92

Add Monte Carlo Control examples (MC-ES, On-policy ε-soft) with GridWorld + tests
1 parent d2687f9 commit 32bcd92

4 files changed: +229 −0 lines
Lines changed: 73 additions & 0 deletions

# ch5_monte_carlo/examples/mc_control_es_gridworld.py
import numpy as np
from collections import defaultdict
from ch2_rl_formulation.gridworld import GridWorld4x4  # your existing env

ACTIONS = [0, 1, 2, 3]  # R, L, D, U (consistent with your env)

def random_start(env: GridWorld4x4):
    """Pick a uniformly random (state, action) pair: the exploring start."""
    s = env.S[np.random.randint(len(env.S))]
    a = np.random.choice(ACTIONS)
    return s, a

def step(env: GridWorld4x4, s, a):
    """Sample one transition; env exposes P[s_idx][a] -> list of (prob, s', r)."""
    s_idx = env.s2i[s]
    trans = env.P[s_idx][a]
    probs = [p for (p, _, _) in trans]
    i = np.random.choice(len(trans), p=probs)
    _, sp_idx, r = trans[i]
    return env.i2s[sp_idx], r

def generate_episode_es(env, pi):
    """Generate one episode: exploring start, then follow the current policy pi."""
    s, a = random_start(env)
    episode = [(s, a)]
    rewards = []
    while s != env.goal:
        sp, r = step(env, s, a)
        rewards.append(r)
        s = sp
        if s == env.goal:
            break
        # follow the current policy after the start
        a = pi[s]
        episode.append((s, a))
    return episode, rewards  # rewards[t] is the reward following episode[t]

def mc_es_control(env, episodes=5000, gamma=0.9):
    """Monte Carlo control with exploring starts (first-visit updates)."""
    Q = defaultdict(float)
    N = defaultdict(int)
    # start with an arbitrary deterministic policy
    pi = {s: np.random.choice(ACTIONS) for s in env.S}
    for _ in range(episodes):
        ep, rewards = generate_episode_es(env, pi)
        G = 0.0
        visited = set()
        # process the episode backwards, accumulating the return
        for t in range(len(ep) - 1, -1, -1):
            s, a = ep[t]
            if t < len(rewards):  # guard: a start-in-goal episode has no rewards
                G = gamma * G + rewards[t]
            if (s, a) not in visited:
                N[(s, a)] += 1
                Q[(s, a)] += (G - Q[(s, a)]) / N[(s, a)]  # incremental mean
                visited.add((s, a))
                # greedy policy improvement at the visited state
                pi[s] = max(ACTIONS, key=lambda act: Q[(s, act)])
    return Q, pi

if __name__ == "__main__":
    np.random.seed(0)
    env = GridWorld4x4(step_reward=0.0, goal=(0, 3))
    Q, pi = mc_es_control(env, episodes=3000, gamma=0.9)
    # print the learned greedy policy as a grid of arrows
    arrows = {0: "→", 1: "←", 2: "↓", 3: "↑"}
    for i in range(env.n):
        row = []
        for j in range(env.n):
            s = (i, j)
            row.append(" G " if s == env.goal else f" {arrows[pi[s]]} ")
        print("".join(row))
Lines changed: 79 additions & 0 deletions

# ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py

import numpy as np
from collections import defaultdict
from ch2_rl_formulation.gridworld import GridWorld4x4  # reuse Chapter 2 env

ACTIONS = [0, 1, 2, 3]  # R, L, D, U

def step(env: GridWorld4x4, s, a):
    """Sample one step given state and action from env's transition model."""
    s_idx = env.s2i[s]
    trans = env.P[s_idx][a]
    probs = [p for (p, _, _) in trans]
    i = np.random.choice(len(trans), p=probs)
    _, sp_idx, r = trans[i]
    return env.i2s[sp_idx], r

def generate_episode(env, pi, epsilon=0.1):
    """Generate one episode following the epsilon-soft policy pi."""
    s = env.S[np.random.randint(len(env.S))]
    episode, rewards = [], []
    while s != env.goal:
        # epsilon-greedy: with prob. epsilon explore uniformly, otherwise
        # take the action with the highest probability under pi
        if np.random.rand() < epsilon:
            a = np.random.choice(ACTIONS)
        else:
            a = max(ACTIONS, key=lambda act: pi[(s, act)])
        sp, r = step(env, s, a)
        episode.append((s, a))
        rewards.append(r)
        s = sp
    return episode, rewards

def mc_control_onpolicy(env, episodes=5000, gamma=0.9, epsilon=0.1):
    """On-policy first-visit MC control with epsilon-soft policies."""
    Q = defaultdict(float)
    N = defaultdict(int)
    pi = {(s, a): 1.0 / len(ACTIONS) for s in env.S for a in ACTIONS}  # uniform start

    for _ in range(episodes):
        ep, rewards = generate_episode(env, pi, epsilon)
        G, visited = 0.0, set()
        # backward return computation
        for t in range(len(ep) - 1, -1, -1):
            s, a = ep[t]
            G = gamma * G + rewards[t]
            if (s, a) not in visited:
                N[(s, a)] += 1
                Q[(s, a)] += (G - Q[(s, a)]) / N[(s, a)]
                visited.add((s, a))
                # policy improvement: make pi epsilon-greedy w.r.t. Q
                best_a = max(ACTIONS, key=lambda act: Q[(s, act)])
                for act in ACTIONS:
                    if act == best_a:
                        pi[(s, act)] = 1 - epsilon + epsilon / len(ACTIONS)
                    else:
                        pi[(s, act)] = epsilon / len(ACTIONS)
    return Q, pi

if __name__ == "__main__":
    np.random.seed(1)
    env = GridWorld4x4(step_reward=0.0, goal=(0, 3))
    Q, pi = mc_control_onpolicy(env, episodes=3000, gamma=0.9, epsilon=0.1)

    # Print the learned greedy policy (arrows)
    arrows = {0: "→", 1: "←", 2: "↓", 3: "↑"}
    for i in range(env.n):
        row = []
        for j in range(env.n):
            s = (i, j)
            if s == env.goal:
                row.append(" G ")
            else:
                # choose the most probable action
                best_a = max(ACTIONS, key=lambda a: pi[(s, a)])
                row.append(f" {arrows[best_a]} ")
        print("".join(row))

ch5_monte_carlo/tests/__init__.py

Whitespace-only changes.
Lines changed: 77 additions & 0 deletions

# ch5_monte_carlo/tests/test_mc_control.py
import numpy as np

from ch2_rl_formulation.gridworld import GridWorld4x4
from ch5_monte_carlo.examples.mc_control_es_gridworld import mc_es_control
from ch5_monte_carlo.examples.mc_control_onpolicy_gridworld import mc_control_onpolicy, ACTIONS

def sample_next_state(env: GridWorld4x4, s, a):
    """Sample s' from the env's transition model (mirrors step() in the examples)."""
    trans = env.P[env.s2i[s]][a]
    probs = [p for (p, _, _) in trans]
    i = np.random.choice(len(trans), p=probs)
    _, sp_idx, _ = trans[i]
    return env.i2s[sp_idx]

def rollout_greedy_es(env: GridWorld4x4, pi, max_steps=64):
    """Roll out the deterministic greedy policy pi (state -> action)."""
    s = env.S[np.random.randint(len(env.S))]
    steps = 0
    while s != env.goal and steps < max_steps:
        s = sample_next_state(env, s, pi[s])
        steps += 1
    return s == env.goal, steps

def greedy_from_soft(pi_soft, s):
    """Pick argmax_a pi(a|s) from a dict keyed by (s, a)."""
    return max(ACTIONS, key=lambda a: pi_soft[(s, a)])

def rollout_greedy_from_soft(env: GridWorld4x4, pi_soft, max_steps=64):
    """Roll out the greedy action implied by epsilon-soft policy probabilities."""
    s = env.S[np.random.randint(len(env.S))]
    steps = 0
    while s != env.goal and steps < max_steps:
        s = sample_next_state(env, s, greedy_from_soft(pi_soft, s))
        steps += 1
    return s == env.goal, steps

def success_rate(trial_fn, trials=100):
    succ = 0
    total_steps = 0
    for _ in range(trials):
        ok, steps = trial_fn()
        succ += int(ok)
        total_steps += steps
    return succ / trials, total_steps / trials

def test_mc_es_gridworld_reaches_goal():
    np.random.seed(0)
    env = GridWorld4x4(step_reward=0.0, goal=(0, 3))

    # Train MC-ES (keep episodes modest so CI stays fast)
    Q, pi = mc_es_control(env, episodes=1500, gamma=0.9)

    sr, avg_steps = success_rate(lambda: rollout_greedy_es(env, pi), trials=100)

    # Expect a high success rate and reasonable path lengths
    assert sr >= 0.9, f"MC-ES success rate too low: {sr:.2f}"
    assert avg_steps <= 25, f"MC-ES average steps too high: {avg_steps:.1f}"

def test_mc_onpolicy_gridworld_reaches_goal():
    np.random.seed(1)
    env = GridWorld4x4(step_reward=0.0, goal=(0, 3))

    # Train on-policy MC with epsilon-soft behavior
    Q, pi_soft = mc_control_onpolicy(env, episodes=2000, gamma=0.9, epsilon=0.1)

    sr, avg_steps = success_rate(lambda: rollout_greedy_from_soft(env, pi_soft), trials=100)

    # Expect robust success, with a slightly looser bound than MC-ES
    assert sr >= 0.85, f"On-policy MC success rate too low: {sr:.2f}"
    assert avg_steps <= 28, f"On-policy MC average steps too high: {avg_steps:.1f}"
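
The examples and tests assume the Chapter 2 GridWorld4x4 exposes S (a list of (row, col) states), s2i / i2s (state/index maps), goal, n, and a tabular model P[s_idx][a] -> list of (prob, next_state_idx, reward). For readers without that module, a minimal hypothetical stand-in is sketched below; the deterministic moves, wall bumping, and the +1 reward on entering the goal are assumptions, not the Chapter 2 implementation.

# StubGridWorld4x4 is a hypothetical stand-in, NOT the Chapter 2 env:
# it only reproduces the interface the examples and tests rely on.
class StubGridWorld4x4:
    MOVES = {0: (0, 1), 1: (0, -1), 2: (1, 0), 3: (-1, 0)}  # R, L, D, U

    def __init__(self, step_reward=0.0, goal=(0, 3), goal_reward=1.0):
        self.n = 4
        self.goal = goal
        self.S = [(i, j) for i in range(self.n) for j in range(self.n)]
        self.s2i = {s: k for k, s in enumerate(self.S)}
        self.i2s = {k: s for k, s in enumerate(self.S)}
        # P[s_idx][a] -> list of (prob, next_state_idx, reward); deterministic here
        self.P = {}
        for s in self.S:
            self.P[self.s2i[s]] = {}
            for a, (di, dj) in self.MOVES.items():
                ni, nj = s[0] + di, s[1] + dj
                if not (0 <= ni < self.n and 0 <= nj < self.n):
                    ni, nj = s  # bumping a wall leaves the agent in place
                r = goal_reward if (ni, nj) == self.goal else step_reward
                self.P[self.s2i[s]][a] = [(1.0, self.s2i[(ni, nj)], r)]

With such an env on the import path, the suite runs with e.g. python -m pytest ch5_monte_carlo/tests -q; note the 0.9 / 0.85 success thresholds are empirical and depend on the env's reward scheme.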
