Skip to content

Commit

Permalink
[Shogi] Move reward and terminal (#1276)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Nov 4, 2024
1 parent a147ffa commit eece484
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 35 deletions.
2 changes: 1 addition & 1 deletion pgx/_src/dwg/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def _make_shogi_dwg(dwg, state: ShogiState, config): # noqa: C901
if state._x.turn == 1:
if state._x.color == 1:
from pgx._src.games.shogi import _flip

state = ShogiState(_x=_flip(state._x)) # type: ignore
Expand Down
19 changes: 17 additions & 2 deletions pgx/_src/games/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,16 @@ def f(i):


class GameState(NamedTuple):
turn: Array = jnp.int32(0) # 0 or 1
step_count: Array = jnp.int32(0)
color: Array = jnp.int32(0) # 0 or 1
board: Array = INIT_PIECE_BOARD # (81,) flip in turn
hand: Array = jnp.zeros((2, 7), dtype=jnp.int32) # flip in turn
# cache
# Redundant information used only in _is_checked for speeding-up
cache_m2b: Array = -jnp.ones(8, dtype=jnp.int32)
cache_king: Array = jnp.int32(44)
#
legal_action_mask: Array = INIT_LEGAL_ACTION_MASK


class Game:
Expand All @@ -356,6 +359,16 @@ def observe(self, state: GameState) -> Array:

def legal_action_mask(self, state: GameState) -> Array:
# Thin wrapper: hands `state` to the module-level `_legal_action_mask` helper and
# returns its result unchanged. It does not read the `legal_action_mask` field
# stored on `state`; the value comes from the helper call.
return _legal_action_mask(state)

def is_terminal(self, state: GameState) -> Array:
    """Return whether the game is over in `state`.

    The result is true when either of the following holds:
      * no entry of `state.legal_action_mask` is set, or
      * `state.step_count` has reached `MAX_TERMINATION_STEPS`.
    """
    has_any_action = state.legal_action_mask.any()
    reached_step_limit = state.step_count >= MAX_TERMINATION_STEPS
    return ~has_any_action | reached_step_limit

def rewards(self, state: GameState) -> Array:
    """Return the length-2 float32 reward vector for `state`.

    While `state.legal_action_mask` has at least one set entry the reward is
    all zeros (this includes states that only ended via the step limit).
    Otherwise the row of the payoff table selected by `state.color` is
    returned: index `state.color` receives -1 and the other index +1.
    """
    # Row c: payoffs handed out when color c has no legal action left.
    payoff_table = jnp.float32([[-1.0, 1.0], [1.0, -1.0]])
    stuck_payoff = payoff_table[state.color]
    no_payoff = jnp.zeros(2, dtype=jnp.float32)
    can_still_act = state.legal_action_mask.any()
    return jax.lax.select(can_still_act, no_payoff, stuck_payoff)


class Action(NamedTuple):
Expand Down Expand Up @@ -437,7 +450,9 @@ def _step(state: GameState, action: Array) -> GameState:
state = jax.lax.cond(a.is_drop, _step_drop, _step_move, *(state, a))
# flip state
state = _flip(state)
return state._replace(turn=(state.turn + 1) % 2)
state = state._replace(color=(state.color + 1) % 2, step_count=state.step_count + 1)
state = state._replace(legal_action_mask=_legal_action_mask(state))
return state


def _step_move(state: GameState, action: Action) -> GameState:
Expand Down
23 changes: 14 additions & 9 deletions pgx/experimental/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def to_sfen(state):
"""
# NOTE: input must be flipped if white turn
state = state if state._x.turn % 2 == 0 else state.replace(_x=_flip(state._x)) # type: ignore
state = state if state._x.color % 2 == 0 else state.replace(_x=_flip(state._x)) # type: ignore

pb = np.rot90(state._x.board.reshape((9, 9)), k=3)
sfen = ""
Expand Down Expand Up @@ -63,7 +63,7 @@ def to_sfen(state):
else:
sfen += " "
# Turn
if state._x.turn == 0:
if state._x.color == 0:
sfen += "b "
else:
sfen += "w "
Expand All @@ -85,12 +85,12 @@ def to_sfen(state):


@jax.jit
def _from_board(turn, piece_board, hand):
def _from_board(color, piece_board, hand):
"""Mainly for debugging purpose.
terminated, reward, and current_player are not changed"""
state = State(_x=GameState(turn=turn, board=piece_board, hand=hand)) # type: ignore
state = State(_x=GameState(color=color, board=piece_board, hand=hand)) # type: ignore
# fmt: off
state = jax.lax.cond(turn % 2 == 1, lambda: state.replace(_x=_flip(state._x)), lambda: state) # type: ignore
state = jax.lax.cond(color % 2 == 1, lambda: state.replace(_x=_flip(state._x)), lambda: state) # type: ignore
# fmt: on
return state.replace(legal_action_mask=Game().legal_action_mask(state._x)) # type: ignore

Expand All @@ -100,7 +100,7 @@ def from_sfen(sfen):
board_char_dir = ["P", "L", "N", "S", "B", "R", "G", "K", "", "", "", "", "", "", "p", "l", "n", "s", "b", "r", "g", "k"]
hand_char_dir = ["P", "L", "N", "S", "B", "R", "G", "p", "l", "n", "s", "b", "r", "g"]
# fmt: on
board, turn, hand, step_count = sfen.split()
board, color, hand, step_count = sfen.split()
board_ranks = board.split("/")
piece_board = np.zeros(81, dtype=np.int32)
for i in range(9):
Expand Down Expand Up @@ -131,6 +131,11 @@ def from_sfen(sfen):
num_piece = 1
piece_board = np.rot90(piece_board.reshape((9, 9)), k=1).flatten()
hand = np.reshape(s_hand, (2, 7))
turn = 0 if turn == "b" else 1
turn, piece_board, hand, step_count = turn, piece_board, hand, int(step_count) - 1
return _from_board(turn, piece_board, hand).replace(_step_count=np.int32(step_count)) # type: ignore
color = 0 if color == "b" else 1
color, piece_board, hand, step_count = color, piece_board, hand, int(step_count) - 1
state = _from_board(color, piece_board, hand)
state = state.replace( # type: ignore
_step_count=np.int32(step_count),
_x=state._x._replace(step_count=np.int32(step_count)),
)
return state
30 changes: 7 additions & 23 deletions pgx/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pgx.core as core
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey
from pgx._src.games.shogi import MAX_TERMINATION_STEPS, GameState, Game, _observe, INIT_LEGAL_ACTION_MASK
from pgx._src.games.shogi import GameState, Game, _observe, INIT_LEGAL_ACTION_MASK


TRUE = jnp.bool_(True)
Expand All @@ -35,6 +35,7 @@ class State(core.State):
legal_action_mask: Array = INIT_LEGAL_ACTION_MASK # (27 * 81,)
observation: Array = jnp.zeros((119, 9, 9), dtype=jnp.bool_)
_step_count: Array = jnp.int32(0)
_player_order: Array = jnp.array([0, 1], dtype=jnp.int32)
_x: GameState = GameState()

@property
Expand All @@ -50,8 +51,8 @@ def __init__(self):

def _init(self, key: PRNGKey) -> State:
state = State()
current_player = jnp.int32(jax.random.bernoulli(key))
return state.replace(current_player=current_player) # type: ignore
player_order = jnp.array([[0, 1], [1, 0]])[jax.random.bernoulli(key).astype(jnp.int32)]
return state.replace(_player_order=player_order) # type: ignore

def _step(self, state: core.State, action: Array, key) -> State:
del key
Expand All @@ -60,29 +61,12 @@ def _step(self, state: core.State, action: Array, key) -> State:
x = self._game.step(state._x, action)
state = state.replace( # type: ignore
current_player=(state.current_player + 1) % 2,
terminated=self._game.is_terminal(x),
rewards=self._game.rewards(x)[state._player_order],
legal_action_mask=x.legal_action_mask,
_x=x,
)
del x
legal_action_mask = self._game.legal_action_mask(state._x)
terminated = ~legal_action_mask.any()
# fmt: off
reward = jax.lax.select(
terminated,
jnp.ones(2, dtype=jnp.float32).at[state.current_player].set(-1),
jnp.zeros(2, dtype=jnp.float32),
)
# fmt: on
state = state.replace( # type: ignore
legal_action_mask=legal_action_mask,
terminated=terminated,
rewards=reward,
)
state = jax.lax.cond(
(MAX_TERMINATION_STEPS <= state._step_count),
# end with tie
lambda: state.replace(terminated=TRUE), # type: ignore
lambda: state,
)
return state # type: ignore

def _observe(self, state: core.State, player_id: Array) -> Array:
Expand Down

0 comments on commit eece484

Please sign in to comment.