|
| 1 | +from __future__ import annotations |
| 2 | +from dataclasses import dataclass |
| 3 | +from typing import Dict, Tuple, List |
1 | 4 | import numpy as np |
2 | | -from typing import List, Tuple |
3 | 5 |
|
4 | | -class GridWorld4x4: |
5 | | - """ |
6 | | - 4x4 deterministic GridWorld for Chapter 2. |
7 | | - - Step reward: -1 |
8 | | - - Terminal goal: (0,3) top-right, value 0, no outgoing actions |
9 | | - - Actions: up, right, down, left |
10 | | - - Deterministic transitions; hitting a wall = self-transition |
11 | | - """ |
# Action set as (di, dj) row/column offsets, indexed 0..3:
# 0 = Right (0,+1), 1 = Left (0,-1), 2 = Down (+1,0), 3 = Up (-1,0).
ACTIONS: List[Tuple[int, int]] = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # R, L, D, U
12 | 7 |
|
13 | | - ACTIONS = ("up", "right", "down", "left") |
@dataclass(frozen=True)
class Transition:
    """One edge of the MDP transition model: taking action ``a`` in state
    ``s`` moves to ``sp`` with reward ``r`` and probability ``p``.

    Frozen (immutable, hashable) so transitions can safely be shared,
    compared, and stored in the precomputed ``P`` table.
    """
    s: int    # source state index
    a: int    # action index into ACTIONS
    sp: int   # successor state index
    r: float  # immediate reward for this transition
    p: float  # probability of this transition (1.0 — dynamics are deterministic)
14 | 15 |
|
15 | | - def __init__( |
16 | | - self, |
17 | | - step_reward: float = -1.0, |
18 | | - goal_reward: float = 0.0, |
19 | | - deterministic: bool = True, |
20 | | - self_transition_on_wall: bool = True, |
21 | | - goal: Tuple[int, int] = (0, 3), |
22 | | - ): |
23 | | - assert deterministic, "Only deterministic dynamics supported in this minimal build." |
24 | | - self.h, self.w = 4, 4 |
class GridWorld4x4:
    """Deterministic 4x4 gridworld with an absorbing goal state.

    Classic shortest-path gridworld (Sutton & Barto, Example 4.1):
      - every transition out of a non-terminal state yields ``step_reward``
        (-1 by default), *including* the transition that enters the goal,
        so V(s) counts minus the expected steps-to-go;
      - the goal is absorbing: every action self-loops there with reward 0;
      - moving into a wall leaves the agent in place (and still costs
        ``step_reward``).

    Attributes:
        S: list of (row, col) cells in row-major order.
        s2i / i2s: coordinate <-> index maps.
        A: action indices into the module-level ``ACTIONS`` table.
        terminal: state index of the absorbing goal.
        P: P[s][a] -> [Transition]; each list has exactly one entry
           (probability 1.0) because dynamics are deterministic.
    """

    def __init__(self, step_reward: float = -1.0, goal: Tuple[int, int] = (0, 3)):
        """Build the MDP tables.

        Args:
            step_reward: reward for every transition from a non-terminal state.
            goal: (row, col) of the absorbing goal cell.

        Raises:
            ValueError: if ``goal`` lies outside the 4x4 grid.
        """
        self.n = 4
        self.step_reward = float(step_reward)
        self.goal = tuple(goal)
        # Fail fast with a clear message instead of a KeyError from s2i below.
        if not self._in_bounds(*self.goal):
            raise ValueError(f"goal {self.goal} is outside the {self.n}x{self.n} grid")

        # Row-major enumeration of cells plus index <-> coordinate maps.
        self.S = [(i, j) for i in range(self.n) for j in range(self.n)]
        self.s2i = {s: k for k, s in enumerate(self.S)}  # (i,j) -> idx
        self.i2s = {k: s for k, s in enumerate(self.S)}  # idx -> (i,j)
        self.A = list(range(len(ACTIONS)))

        self.terminal = self.s2i[self.goal]
        self.num_states = len(self.S)
        self.num_actions = len(self.A)

        self.P: Dict[int, Dict[int, List[Transition]]] = self._build_P()

    def _in_bounds(self, i: int, j: int) -> bool:
        """True iff (i, j) is a valid cell of the n x n grid."""
        return 0 <= i < self.n and 0 <= j < self.n

    def _step_det(self, s_idx: int, a: int) -> Tuple[int, float]:
        """Deterministic one-step model: return (next_state_idx, reward).

        The step reward applies to *every* transition from a non-terminal
        state — including the one that enters the goal — matching the
        textbook formulation in which V(s) measures steps-to-go.
        (Bug fix: entering the goal previously yielded 0 instead of
        ``step_reward``, shifting all values near the goal by one.)
        """
        if s_idx == self.terminal:
            return s_idx, 0.0  # absorbing goal: self-loop at zero reward
        i, j = self.i2s[s_idx]
        di, dj = ACTIONS[a]
        ni, nj = i + di, j + dj
        if not self._in_bounds(ni, nj):
            ni, nj = i, j  # bounced off a wall: stay in place
        return self.s2i[(ni, nj)], self.step_reward

    def _build_P(self) -> Dict[int, Dict[int, List[Transition]]]:
        """Enumerate the full transition table P[s][a] = [Transition]."""
        P: Dict[int, Dict[int, List[Transition]]] = {}
        for s in range(self.num_states):
            P[s] = {}
            for a in self.A:
                sp, r = self._step_det(s, a)
                P[s][a] = [Transition(s=s, a=a, sp=sp, r=r, p=1.0)]
        return P
0 commit comments