Skip to content

Commit c2717b2

Browse files
added_book_examples
1 parent a398a28 commit c2717b2

17 files changed

+269
-232
lines changed

ch2_rl_formulation/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Chapter 2 — The RL Problem Formulation
2+
3+
Implements: MDP formalism, Bellman expectation & optimality, gridworld, greedy/ε-greedy, value iteration.
4+
Includes worked numeric examples (expected results 5.23 and 4.58), demos, visualizations, and tests.
5+
6+
## Quickstart
7+
```bash
8+
python -m ch2_rl_formulation.examples.numeric_checks
9+
python -m ch2_rl_formulation.examples.gridworld_demo
10+
python -m ch2_rl_formulation.examples.plot_value_and_policy
11+
```
12+
13+
## Layout
14+
- `gridworld.py` — 4×4 deterministic GridWorld.
15+
- `evaluation.py` — policy evaluation (deterministic & stochastic), `q_from_v`, `greedy_from_q`.
16+
- `policies.py` — deterministic & ε-greedy helpers.
17+
- `value_iteration.py` — value iteration, extract greedy policy.
18+
- `visualize.py` — single-plot, matplotlib-based visuals (no explicit colors).
19+
- `examples/` — boxed examples and demos.
20+
- `tests/` — sanity checks tied to the chapter.

ch2_rl_formulation/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
from .gridworld import GridWorld4x4
2-
from .evaluation import policy_evaluation, q_from_v
3-
from .policies import greedy_toward_goal_policy, greedy_from_q, epsilon_greedy
1+

ch2_rl_formulation/evaluation.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,54 @@
1+
from __future__ import annotations
12
import numpy as np
3+
from .gridworld import GridWorld4x4
24

3-
def policy_evaluation(S, A, P, R, pi, gamma=1.0, theta=1e-10):
4-
"""
5-
Tabular policy evaluation for general R(s,a,s').
6-
Inputs:
7-
- S: list of states
8-
- A: list of actions
9-
- P: [|S|, |A|, |S'|] transition probabilities
10-
- R: [|S|, |A|, |S'|] rewards
11-
- pi: [|S|, |A|] policy (row-stochastic; can be deterministic one-hot)
12-
- gamma: discount factor
13-
- theta: convergence threshold (max delta)
14-
Returns:
15-
- V: np.ndarray of shape [|S|]
16-
"""
17-
nS, nA, nSp = P.shape
18-
assert nS == len(S) and nA == len(A) and nSp == nS
19-
assert pi.shape == (nS, nA)
5+
def policy_evaluation(env: GridWorld4x4,
                      pi_actions: np.ndarray,
                      gamma: float = 0.9,
                      theta: float = 1e-8,
                      max_iter: int = 10000) -> np.ndarray:
    """Evaluate a deterministic policy with in-place Bellman-expectation sweeps.

    Args:
        env: environment exposing ``num_states`` and a model ``P[s][a]`` —
            a list of transitions with fields ``p``, ``r``, ``sp``.
        pi_actions: shape ``[num_states]``; ``pi_actions[s]`` is the action
            the policy takes in state ``s``.
        gamma: discount factor.
        theta: stop once the largest per-state change in a sweep drops below this.
        max_iter: hard cap on the number of sweeps.

    Returns:
        State-value array ``V`` of shape ``[num_states]``.
    """
    num_states = env.num_states
    values = np.zeros(num_states, dtype=float)
    for _ in range(max_iter):
        biggest_change = 0.0
        for state in range(num_states):
            action = int(pi_actions[state])
            backup = 0.0
            for tr in env.P[state][action]:
                backup += tr.p * (tr.r + gamma * values[tr.sp])
            biggest_change = max(biggest_change, abs(backup - values[state]))
            # Gauss-Seidel style: fresh values are reused within the same sweep.
            values[state] = backup
        if biggest_change < theta:
            return values
    return values
2022

21-
V = np.zeros(nS, dtype=float)
22-
while True:
23+
def policy_evaluation_stochastic(env: GridWorld4x4,
                                 pi_probs: np.ndarray,
                                 gamma: float = 0.9,
                                 theta: float = 1e-8,
                                 max_iter: int = 10000) -> np.ndarray:
    """Evaluate a stochastic policy by iterating the Bellman expectation backup.

    Args:
        env: environment exposing ``num_states``, ``num_actions`` and a model
            ``P[s][a]`` (list of transitions with fields ``p``, ``r``, ``sp``).
        pi_probs: shape ``[num_states, num_actions]``; row ``s`` gives the
            action probabilities of the policy in state ``s``.
        gamma: discount factor.
        theta: convergence threshold on the largest per-state change.
        max_iter: hard cap on the number of sweeps.

    Returns:
        State-value array ``V`` of shape ``[num_states]``.
    """
    n_states, n_actions = env.num_states, env.num_actions
    values = np.zeros(n_states, dtype=float)
    for _ in range(max_iter):
        worst = 0.0
        for state in range(n_states):
            expected = 0.0
            for action in range(n_actions):
                weight = pi_probs[state, action]
                if weight == 0.0:
                    continue  # zero-probability actions contribute nothing
                expected += weight * sum(t.p * (t.r + gamma * values[t.sp])
                                         for t in env.P[state][action])
            worst = max(worst, abs(expected - values[state]))
            values[state] = expected  # in-place (Gauss-Seidel) update
        if worst < theta:
            break
    return values
3844

39-
def q_from_v(S, A, P, R, V, gamma=1.0):
40-
nS, nA, _ = P.shape
41-
Q = np.zeros((nS, nA), dtype=float)
42-
for s in range(nS):
43-
for a in range(nA):
44-
Q[s, a] = np.sum(P[s, a, :] * (R[s, a, :] + gamma * V))
45+
def q_from_v(env: GridWorld4x4, V: np.ndarray, gamma: float = 0.9) -> np.ndarray:
    """Derive the action-value table Q(s, a) = E[r + gamma * V(s')] from V.

    Args:
        env: environment exposing ``num_states``, ``num_actions`` and a model
            ``P[s][a]`` (list of transitions with fields ``p``, ``r``, ``sp``).
        V: state values, shape ``[num_states]``.
        gamma: discount factor.

    Returns:
        ``Q`` of shape ``[num_states, num_actions]``.
    """
    n_states, n_actions = env.num_states, env.num_actions
    Q = np.zeros((n_states, n_actions), dtype=float)
    for state in range(n_states):
        for action in range(n_actions):
            backup = 0.0
            for t in env.P[state][action]:
                backup += t.p * (t.r + gamma * V[t.sp])
            Q[state, action] = backup
    return Q
52+
53+
def greedy_from_q(Q: np.ndarray) -> np.ndarray:
    """Return, for each state (row of Q), the index of its highest-valued action."""
    return Q.argmax(axis=1)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,11 @@
11
import numpy as np
2-
from ch2_rl_formulation.gridworld import GridWorld4x4
3-
from ch2_rl_formulation.policies import greedy_toward_goal_policy
4-
from ch2_rl_formulation.evaluation import policy_evaluation, q_from_v
5-
6-
def main():
7-
env = GridWorld4x4(step_reward=-1, goal_reward=0, goal=(0,3))
8-
S, A = env.states(), env.actions()
9-
P, R = env.P_tensor(), env.R_tensor()
10-
pi = greedy_toward_goal_policy(env)
11-
V = policy_evaluation(S, A, P, R, pi, gamma=1.0, theta=1e-12)
12-
Vgrid = np.array(V).reshape(4,4)
13-
print("V_pi grid (goal top-right):\n", Vgrid)
14-
expected = np.array([
15-
[-4, -3, -2, 0],
16-
[-5, -4, -3, -1],
17-
[-6, -5, -4, -2],
18-
[-7, -6, -5, -3],
19-
], dtype=float)
20-
assert np.allclose(Vgrid, expected, atol=1e-12)
21-
Q = q_from_v(S, A, P, R, V, gamma=1.0)
22-
s_bl = env.state_index(3,0)
23-
a = {name: env.action_index(name) for name in A}
24-
print("\nQ at bottom-left (row=3,col=0):")
25-
for name in A:
26-
print(f"{name:>5}: {Q[s_bl, a[name]]:6.2f}")
27-
assert abs(Q[s_bl, a["up"]] - (-7)) < 1e-12
28-
assert abs(Q[s_bl, a["right"]]- (-7)) < 1e-12
29-
assert abs(Q[s_bl, a["left"]] - (-8)) < 1e-12
30-
assert abs(Q[s_bl, a["down"]] - (-8)) < 1e-12
31-
print("\nAll checks PASS.")
2+
from ..gridworld import GridWorld4x4
from ..value_iteration import value_iteration


if __name__ == "__main__":
    # Solve the 4x4 gridworld, then display optimal values and the greedy policy.
    world = GridWorld4x4(step_reward=-1.0, goal=(0, 3))
    V_star, pi_star = value_iteration(world, gamma=0.9, theta=1e-10)
    print("Optimal V* (gamma=0.9):")
    print(np.round(V_star.reshape(4, 4), 2))
    print("\nGreedy π* (0:R,1:L,2:D,3:U):")
    print(pi_star.reshape(4, 4))
Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,33 @@
1-
def approx(x, y, tol=1e-6):
2-
return abs(x - y) < tol
1+
import numpy as np
2+
from ..gridworld import GridWorld4x4
3+
from ..evaluation import policy_evaluation, q_from_v, greedy_from_q
34

4-
def main():
5-
ok = True
6-
g0 = 1 + 0.9*2 + (0.9**2)*3
7-
print("G0 =", g0); ok &= approx(g0, 5.23)
8-
v = -1 + 0.9*(-1) + (0.9**2)*(-1) + (0.9**3)*10
9-
print("v =", v); ok &= approx(v, 4.58)
10-
v_pe = 2 + 0.9*4
11-
print("v_pe =", v_pe); ok &= approx(v_pe, 5.6, 1e-12)
12-
q_pe = 1 + 0.9*3
13-
print("q_pe =", q_pe); ok &= approx(q_pe, 3.7, 1e-12)
14-
vopt = max(2 + 0.9*5, 1 + 0.9*8)
15-
print("v* =", vopt); ok &= approx(vopt, 8.2, 1e-12)
16-
qopt = 2 + 0.9*6
17-
print("q* =", qopt); ok &= approx(qopt, 7.4, 1e-12)
18-
v4 = sum((0.9**k)*(-1) for k in range(4)) + (0.9**4)*10
19-
print("v*(4 steps) =", v4); ok &= abs(v4 - 3.122) < 1e-3
20-
print("\nALL NUMERIC EXAMPLES:", "PASS" if ok else "FAIL")
21-
if not ok: raise SystemExit(1)
5+
def discounted_return_example():
    """Discounted return G_0 for rewards (1, 2, 3) with gamma = 0.9 (equals 5.23)."""
    gamma = 0.9
    rewards = (1, 2, 3)
    return sum(gamma ** k * reward for k, reward in enumerate(rewards))
8+
9+
def state_value_example():
    """Value of a 4-step episode (-1, -1, -1, +10) discounted at 0.9 (equals 4.58)."""
    gamma = 0.9
    value = 0.0
    for step, reward in enumerate((-1, -1, -1, 10)):
        value += gamma ** step * reward
    return value
12+
13+
def gridworld_vq_under_fixed_policy(gamma: float = 1.0):
    """Evaluate the 'go right, then up' policy on the 4x4 gridworld.

    Builds a deterministic policy that moves right (action 0) until the agent
    is in the last column, then moves up (action 3); evaluates it, derives Q
    from V, and extracts the greedy policy from Q.

    Args:
        gamma: discount factor used for both policy evaluation and Q.

    Returns:
        Tuple ``(V, Q, pi_greedy)``: V reshaped to 4x4, the full Q table of
        shape [16, 4], and the greedy-from-Q action indices reshaped to 4x4.
    """
    env = GridWorld4x4(step_reward=-1.0, goal=(0, 3))
    pi = np.zeros(env.num_states, dtype=int)
    for s in range(env.num_states):
        if s == env.terminal:
            pi[s] = 0  # action at the absorbing goal is irrelevant
            continue
        # Only the column matters for this policy; the row is unused.
        _, j = env.i2s[s]
        pi[s] = 0 if j < 3 else 3  # 0 = right until last column, then 3 = up
    V = policy_evaluation(env, pi, gamma=gamma, theta=1e-10)
    Q = q_from_v(env, V, gamma=gamma)
    pi_greedy = greedy_from_q(Q)
    return V.reshape(4, 4), Q, pi_greedy.reshape(4, 4)
2225

2326
if __name__ == "__main__":
    # Reproduce the chapter's boxed numeric examples, then the gridworld check.
    g0 = discounted_return_example()
    print("G0 example (should be 5.23):", round(g0, 2))
    v_pi = state_value_example()
    print("v_pi example (should be 4.58):", round(v_pi, 2))
    values, q_table, greedy = gridworld_vq_under_fixed_policy(gamma=1.0)
    print("\nGridworld V under greedy-to-goal policy (gamma=1):")
    print(np.array_str(np.round(values, 0), precision=0))
    print("\nGreedy-from-Q policy indices (0:R,1:L,2:D,3:U):")
    print(greedy)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import numpy as np

from ..gridworld import GridWorld4x4
from ..value_iteration import value_iteration
from ..visualize import plot_value_grid, plot_greedy_policy


if __name__ == "__main__":
    # Solve the gridworld, then render the value grid and the greedy policy.
    world = GridWorld4x4(step_reward=-1.0, goal=(0, 3))
    v_star, pi_star = value_iteration(world, gamma=0.9, theta=1e-10)
    plot_value_grid(v_star, title="Optimal V* (gamma=0.9)")
    plot_greedy_policy(pi_star, title="Greedy Policy π*")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

ch2_rl_formulation/gridworld.py

Lines changed: 49 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,89 +1,56 @@
1+
from __future__ import annotations
2+
from dataclasses import dataclass
3+
from typing import Dict, Tuple, List
14
import numpy as np
2-
from typing import List, Tuple
35

4-
class GridWorld4x4:
5-
"""
6-
4x4 deterministic GridWorld for Chapter 2.
7-
- Step reward: -1
8-
- Terminal goal: (0,3) top-right, value 0, no outgoing actions
9-
- Actions: up, right, down, left
10-
- Deterministic transitions; hitting a wall = self-transition
11-
"""
6+
# Action deltas in (row, col) order; index 0 = right, 1 = left, 2 = down, 3 = up.
ACTIONS: List[Tuple[int, int]] = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # R, L, D, U


@dataclass(frozen=True)
class Transition:
    """One (s, a) -> s' outcome with its reward r and probability p."""
    s: int
    a: int
    sp: int
    r: float
    p: float


class GridWorld4x4:
    """Deterministic 4x4 gridworld with an absorbing goal (reward 0)."""

    def __init__(self, step_reward: float = -1.0, goal: Tuple[int, int] = (0, 3)):
        self.n = 4
        self.step_reward = float(step_reward)
        self.goal = tuple(goal)

        # Enumerate cells row-major and build both index maps.
        self.S = [(row, col) for row in range(self.n) for col in range(self.n)]
        self.s2i = {cell: idx for idx, cell in enumerate(self.S)}  # (i,j) -> idx
        self.i2s = dict(enumerate(self.S))                         # idx -> (i,j)
        self.A = list(range(len(ACTIONS)))

        self.terminal = self.s2i[self.goal]
        self.num_states = len(self.S)
        self.num_actions = len(self.A)

        # Full tabular model: P[s][a] is a list of Transition outcomes.
        self.P: Dict[int, Dict[int, List[Transition]]] = self._build_P()

    def _in_bounds(self, i: int, j: int) -> bool:
        """True iff cell (i, j) lies inside the n x n grid."""
        return 0 <= i < self.n and 0 <= j < self.n

    def _step_det(self, s_idx: int, a: int) -> Tuple[int, float]:
        """Deterministic dynamics: (next-state index, reward) for (s_idx, a)."""
        if s_idx == self.terminal:
            return s_idx, 0.0  # absorbing goal: stay put, no reward
        row, col = self.i2s[s_idx]
        d_row, d_col = ACTIONS[a]
        nxt = (row + d_row, col + d_col)
        if not self._in_bounds(*nxt):
            nxt = (row, col)  # bumping a wall leaves the agent in place
        sp_idx = self.s2i[nxt]
        # Entering the goal costs nothing; every other move pays step_reward.
        reward = 0.0 if sp_idx == self.terminal else self.step_reward
        return sp_idx, reward

    def _build_P(self) -> Dict[int, Dict[int, List[Transition]]]:
        """Tabulate the model: one unit-probability Transition per (s, a)."""
        table: Dict[int, Dict[int, List[Transition]]] = {}
        for s in range(self.num_states):
            per_action: Dict[int, List[Transition]] = {}
            for a in self.A:
                sp, r = self._step_det(s, a)
                per_action[a] = [Transition(s=s, a=a, sp=sp, r=r, p=1.0)]
            table[s] = per_action
        return table

0 commit comments

Comments
 (0)