Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hex] Enhance terminal computation #1290

Merged
merged 5 commits into from
Dec 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions pgx/_src/games/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class GameState(NamedTuple):
# .
# [110, 111, 112, ..., 119, 120]]
board: Array = jnp.zeros(11 * 11, jnp.int32) # <0(oppo), 0(empty), 0<(self)
terminated: Array = FALSE

@property
def color(self) -> Array:
Expand Down Expand Up @@ -62,22 +63,22 @@ def legal_action_mask(self, state: GameState) -> Array:
return jnp.append(state.board == 0, state.step_count == 1)

def is_terminal(self, state: GameState) -> Array:
top, bottom = jax.lax.cond(
state.color == 0,
lambda: (state.board[::self.size], state.board[self.size - 1 :: self.size]),
lambda: (state.board[:self.size], state.board[-self.size:]),
)

def check_same_id_exist(_id):
return (_id < 0) & (_id == bottom).any()

return jax.vmap(check_same_id_exist)(top).any()

return state.terminated

# def rewards(self, state: GameState) -> Array:
# ...


def _is_terminal(state: GameState, action: Array, size: int) -> Array:
top, bottom = jax.lax.cond(
state.color == 0,
lambda: (state.board[::size], state.board[size - 1 :: size]),
lambda: (state.board[:size], state.board[-size:]),
)
target_id = state.board[action] # target_id != 0
return (top == target_id).any() & (bottom == target_id).any()


def _step(state: GameState, action: Array, size: int) -> GameState:
set_place_id = action + 1
board = state.board.at[action].set(set_place_id)
Expand All @@ -92,10 +93,12 @@ def merge(i, b):
)

board = jax.lax.fori_loop(0, 6, merge, board)
return state._replace(

state = state._replace(
step_count=state.step_count + 1,
board=board * -1,
)
return state._replace(terminated=_is_terminal(state, action, size))


def _swap(state: GameState, size: int) -> GameState:
Expand Down