Skip to content

Commit 44da106

Browse files
Fix GridWorld: index-based P and terminal property
1 parent 64083a6 commit 44da106

File tree

1 file changed

+58
-33
lines changed

1 file changed

+58
-33
lines changed

ch2_rl_formulation/gridworld.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,39 @@
11
import numpy as np
from collections import namedtuple

# The four deterministic moves as (row_delta, col_delta) offsets.
# Index order: 0 = Right, 1 = Left, 2 = Down, 3 = Up.
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)]

# A single transition record: probability, successor state index,
# reward, and terminal flag.
TR = namedtuple("TR", "p sp r done")
47

58
class GridWorld4x4:
    """GridWorld with a step cost on every move, including terminal entry.

    Defaults to the classic 4x4 grid; pass `n` for a larger square grid
    (the class name is kept for backward compatibility).

    Convention:
    - States are indexed 0..n*n-1 in row-major order; i2s maps index->(i,j).
    - The goal is absorbing: once there, any action keeps you there with reward 0.
    - Each attempted move costs `step_reward` (e.g., -1), including the final
      move into the goal.
    - Bumping into a wall leaves you in place and still incurs the step cost.
    - Transition kernel P is indexed by state index:
      P[s][a] -> [TR(p, sp, r, done)].
    """

    def __init__(self, step_reward: float = -1.0, goal=(0, 3), n: int = 4):
        """Build the MDP tables.

        Args:
            step_reward: Reward charged for every attempted move from a
                non-terminal state (typically negative).
            goal: (row, col) of the absorbing goal cell; must lie on the grid.
            n: Side length of the square grid (default 4).

        Raises:
            ValueError: If `n` < 1 or `goal` falls outside the n x n grid.
        """
        if n < 1:
            raise ValueError(f"grid side must be >= 1, got {n}")
        self.n = n
        self.goal = tuple(goal)
        self.step_reward = float(step_reward)

        # Enumerate states as (i,j) tuples; provide index mappings.
        self.S = [(i, j) for i in range(self.n) for j in range(self.n)]
        self.s2i = {s: k for k, s in enumerate(self.S)}
        self.i2s = {k: s for s, k in self.s2i.items()}

        # Fail loudly on a bad goal instead of a cryptic KeyError below.
        if self.goal not in self.s2i:
            raise ValueError(f"goal {self.goal} is outside the {n}x{n} grid")

        self.A = list(range(len(ACTIONS)))  # 0..3
        self._goal_idx = self.s2i[self.goal]

        # Transition kernel with integer state indices
        self.P = {s_idx: {a: [] for a in self.A} for s_idx in range(self.num_states)}
        self._build_P()

    # --- properties expected by tests ---
    @property
    def num_states(self) -> int:
        """Total number of grid cells (n * n)."""
        return len(self.S)

    @property
    def num_actions(self) -> int:
        """Number of available actions (4: right, left, down, up)."""
        return len(self.A)

    @property
    def terminal(self) -> int:
        """Index of the terminal (goal) state."""
        return self._goal_idx

    # --- helpers ---
    def _in_bounds(self, i, j) -> bool:
        """True iff cell (i, j) lies on the grid."""
        return 0 <= i < self.n and 0 <= j < self.n

    def _next_state_tuple(self, s_tuple, a):
        """Deterministic successor of cell `s_tuple` under action index `a`."""
        di, dj = ACTIONS[a]
        i, j = s_tuple
        ni, nj = i + di, j + dj
        if not self._in_bounds(ni, nj):
            return (i, j)  # bump into wall: stay
        return (ni, nj)

    def _build_P(self):
        """Populate P[s][a] with one deterministic TR per (state, action)."""
        goal_idx = self._goal_idx

        for s_idx in range(self.num_states):
            s_tuple = self.i2s[s_idx]

            if s_idx == goal_idx:
                # Absorbing terminal: reward 0 thereafter
                for a in self.A:
                    self.P[s_idx][a] = [TR(1.0, goal_idx, 0.0, True)]
                continue

            for a in self.A:
                next_tuple = self._next_state_tuple(s_tuple, a)
                sp_idx = self.s2i[next_tuple]
                done = (sp_idx == goal_idx)
                # Charge step cost even when entering terminal
                r = self.step_reward
                self.P[s_idx][a] = [TR(1.0, sp_idx, r, done)]

    # Optional index-based step (used by some examples/tests)
    def step_idx(self, s_idx: int, a: int):
        """Sample one transition from state index `s_idx` under action `a`.

        Returns:
            Tuple (next_state_index, reward, done).
        """
        trans = self.P[s_idx][a]
        if len(trans) == 1:
            # Deterministic case: skip the RNG entirely.
            tr = trans[0]
        else:
            probs = [t.p for t in trans]
            idx = np.random.choice(len(trans), p=probs)
            tr = trans[idx]
        return tr.sp, tr.r, tr.done

0 commit comments

Comments
 (0)