Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hex] Extract game specific attributes #1287

Merged
merged 3 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgx/_src/dwg/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def _make_hex_dwg(dwg, state: HexState, config):
GRID_SIZE = config["GRID_SIZE"] / 2 # 六角形の1辺
BOARD_SIZE = int(state._size)
BOARD_SIZE = int(state._x.size)
color_set = config["COLOR_SET"]

# background
Expand Down
6 changes: 3 additions & 3 deletions pgx/_src/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,9 @@ def _set_config_by_state(self, _state: State): # noqa: C901
from pgx._src.dwg.hex import _make_hex_dwg, four_dig

self.config["GRID_SIZE"] = 30
size = int(jnp.array(_state._size).ravel()[0])
self.config["BOARD_WIDTH"] = four_dig(size * 1.5) # type:ignore
self.config["BOARD_HEIGHT"] = four_dig(size * jnp.sqrt(3) / 2) # type:ignore
size = int(jnp.array(_state._x.size).ravel()[0])
self.config["BOARD_WIDTH"] = four_dig(size * 1.5) # type:ignore
self.config["BOARD_HEIGHT"] = four_dig(size * jnp.sqrt(3) / 2) # type:ignore
self._make_dwg_group = _make_hex_dwg # type:ignore
if (self.config["COLOR_THEME"] is None and self.config["COLOR_THEME"] == "dark") or self.config[
"COLOR_THEME"
Expand Down
58 changes: 33 additions & 25 deletions pgx/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp
Expand All @@ -25,6 +26,20 @@
TRUE = jnp.bool_(True)


class GameState(NamedTuple):
size: Array = jnp.int32(11)
# 0(black), 1(white)
turn: Array = jnp.int32(0)
# 11x11 board
# [[ 0, 1, 2, ..., 8, 9, 10],
# [ 11, 12, 13, ..., 19, 20, 21],
# .
# .
# .
# [110, 111, 112, ..., 119, 120]]
board: Array = jnp.zeros(11 * 11, jnp.int32) # <0(oppo), 0(empty), 0<(self)


@dataclass
class State(core.State):
current_player: Array = jnp.int32(0)
Expand All @@ -34,18 +49,7 @@ class State(core.State):
truncated: Array = FALSE
legal_action_mask: Array = jnp.ones(11 * 11 + 1, dtype=jnp.bool_).at[-1].set(FALSE)
_step_count: Array = jnp.int32(0)
# --- Hex specific ---
_size: Array = jnp.int32(11)
# 0(black), 1(white)
_turn: Array = jnp.int32(0)
# 11x11 board
# [[ 0, 1, 2, ..., 8, 9, 10],
# [ 11, 12, 13, ..., 19, 20, 21],
# .
# .
# .
# [110, 111, 112, ..., 119, 120]]
_board: Array = jnp.zeros(11 * 11, jnp.int32) # <0(oppo), 0(empty), 0<(self)
_x: GameState = GameState()

@property
def env_id(self) -> core.EnvId:
Expand Down Expand Up @@ -89,12 +93,12 @@ def num_players(self) -> int:

def _init(rng: PRNGKey, size: int) -> State:
current_player = jnp.int32(jax.random.bernoulli(rng))
return State(_size=size, current_player=current_player) # type:ignore
return State(_x=GameState(size=size), current_player=current_player) # type:ignore


def _step(state: State, action: Array, size: int) -> State:
set_place_id = action + 1
board = state._board.at[action].set(set_place_id)
board = state._x.board.at[action].set(set_place_id)
neighbour = _neighbour(action, size)

def merge(i, b):
Expand All @@ -106,7 +110,7 @@ def merge(i, b):
)

board = jax.lax.fori_loop(0, 6, merge, board)
won = _is_game_end(board, size, state._turn)
won = _is_game_end(board, size, state._x.turn)
reward = jax.lax.cond(
won,
lambda: jnp.float32([-1, -1]).at[state.current_player].set(1),
Expand All @@ -115,8 +119,10 @@ def merge(i, b):

state = state.replace( # type:ignore
current_player=1 - state.current_player,
_turn=1 - state._turn,
_board=board * -1,
_x=GameState(
turn=1 - state._x.turn,
board=board * -1,
),
rewards=reward,
terminated=won,
legal_action_mask=state.legal_action_mask.at[:-1].set(board == 0).at[-1].set(state._step_count == 1),
Expand All @@ -126,31 +132,33 @@ def merge(i, b):


def _swap(state: State, size: int) -> State:
ix = jnp.nonzero(state._board, size=1)[0]
ix = jnp.nonzero(state._x.board, size=1)[0]
row = ix // size
col = ix % size
swapped_ix = col * size + row
set_place_id = swapped_ix + 1
board = state._board.at[ix].set(0).at[swapped_ix].set(set_place_id)
board = state._x.board.at[ix].set(0).at[swapped_ix].set(set_place_id)
return state.replace( # type: ignore
current_player=1 - state.current_player,
_turn=1 - state._turn,
_board=board * -1,
_x=GameState(
turn=1 - state._x.turn,
board=board * -1,
),
legal_action_mask=state.legal_action_mask.at[:-1].set(board == 0).at[-1].set(FALSE),
)


def _observe(state: State, player_id: Array, size) -> Array:
board = jax.lax.select(
player_id == state.current_player,
state._board.reshape((size, size)),
-state._board.reshape((size, size)),
state._x.board.reshape((size, size)),
-state._x.board.reshape((size, size)),
)

my_board = board * 1 > 0
opp_board = board * -1 > 0
ones = jnp.ones_like(my_board)
color = jax.lax.select(player_id == state.current_player, state._turn, 1 - state._turn)
color = jax.lax.select(player_id == state.current_player, state._x.turn, 1 - state._x.turn)
color = color * ones
can_swap = state.legal_action_mask[-1] * ones

Expand Down Expand Up @@ -185,4 +193,4 @@ def check_same_id_exist(_id):


def _get_abs_board(state):
return jax.lax.cond(state._turn == 0, lambda: state._board, lambda: state._board * -1)
return jax.lax.cond(state._x.turn == 0, lambda: state._x.board, lambda: state._x.board * -1)
18 changes: 9 additions & 9 deletions tests/test_hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_merge():
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
# fmt:on
assert jnp.all(state._board == expected)
assert jnp.all(state._x.board == expected)


def test_swap():
Expand All @@ -50,26 +50,26 @@ def test_swap():
assert ~state.legal_action_mask[-1]
state = step(state, 1)
state.save_svg("tests/assets/hex/swap_01.svg")
assert (state._board != 0).sum() == 1
assert state._board[1] == -2
assert (state._x.board != 0).sum() == 1
assert state._x.board[1] == -2
assert state.legal_action_mask[-1]
state = step(state, 121) # swap!
state.save_svg("tests/assets/hex/swap_02.svg")
assert (state._board != 0).sum() == 1
assert state._board[11] == -12
assert (state._x.board != 0).sum() == 1
assert state._x.board[11] == -12
assert ~state.legal_action_mask[-1]

key = jax.random.PRNGKey(0)
state = init(key=key)
state = step(state, 0)
state.save_svg("tests/assets/hex/swap_03.svg")
assert (state._board != 0).sum() == 1
assert state._board[0] == -1
assert (state._x.board != 0).sum() == 1
assert state._x.board[0] == -1
assert state.legal_action_mask[-1]
state = step(state, 121) # swap!
state.save_svg("tests/assets/hex/swap_04.svg")
assert (state._board != 0).sum() == 1
assert state._board[0] == -1
assert (state._x.board != 0).sum() == 1
assert state._x.board[0] == -1
assert ~state.legal_action_mask[-1]

key = jax.random.PRNGKey(0)
Expand Down