Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Shogi] Move reward and terminal #1276

Merged
merged 7 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgx/_src/dwg/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def _make_shogi_dwg(dwg, state: ShogiState, config): # noqa: C901
if state._x.turn == 1:
if state._x.color == 1:
from pgx._src.games.shogi import _flip

state = ShogiState(_x=_flip(state._x)) # type: ignore
Expand Down
19 changes: 17 additions & 2 deletions pgx/_src/games/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,16 @@ def f(i):


class GameState(NamedTuple):
turn: Array = jnp.int32(0) # 0 or 1
step_count: Array = jnp.int32(0)
color: Array = jnp.int32(0) # 0 or 1
board: Array = INIT_PIECE_BOARD # (81,) flip in turn
hand: Array = jnp.zeros((2, 7), dtype=jnp.int32) # flip in turn
# cache
# Redundant information, used only in _is_checked as a speed-up
cache_m2b: Array = -jnp.ones(8, dtype=jnp.int32)
cache_king: Array = jnp.int32(44)
#
legal_action_mask: Array = INIT_LEGAL_ACTION_MASK


class Game:
Expand All @@ -356,6 +359,16 @@ def observe(self, state: GameState) -> Array:

def legal_action_mask(self, state: GameState) -> Array:
return _legal_action_mask(state)

def is_terminal(self, state: GameState) -> Array:
    """Report whether the game represented by ``state`` is over.

    The game counts as finished when either of these holds:

    * ``state.legal_action_mask`` has no true entry, or
    * ``state.step_count`` has reached ``MAX_TERMINATION_STEPS``.

    Args:
        state: Game state; only ``legal_action_mask`` and ``step_count``
            are read.

    Returns:
        Boolean scalar array, true when the game has ended.
    """
    no_action_left = ~state.legal_action_mask.any()
    step_limit_reached = MAX_TERMINATION_STEPS <= state.step_count
    return no_action_left | step_limit_reached

def rewards(self, state: GameState) -> Array:
    """Return the two-element float32 reward vector for ``state``.

    While ``state.legal_action_mask`` still has a true entry the result is
    ``[0, 0]``. Otherwise the entry at index ``state.color`` is ``-1`` and
    the other entry is ``+1``.

    Args:
        state: Game state; only ``legal_action_mask`` and ``color`` are read.

    Returns:
        Array of shape ``(2,)`` and dtype float32, indexed the same way as
        ``state.color`` (presumably by piece color - confirm with callers).
    """
    # Row c holds the outcome in which index c receives -1.
    outcome_table = jnp.float32([[-1.0, 1.0], [1.0, -1.0]])
    decided = outcome_table[state.color]
    undecided = jnp.zeros(2, dtype=jnp.float32)
    can_move = state.legal_action_mask.any()
    return jax.lax.select(can_move, undecided, decided)


class Action(NamedTuple):
Expand Down Expand Up @@ -437,7 +450,9 @@ def _step(state: GameState, action: Array) -> GameState:
state = jax.lax.cond(a.is_drop, _step_drop, _step_move, *(state, a))
# flip state
state = _flip(state)
return state._replace(turn=(state.turn + 1) % 2)
state = state._replace(color=(state.color + 1) % 2, step_count=state.step_count + 1)
state = state._replace(legal_action_mask=_legal_action_mask(state))
return state


def _step_move(state: GameState, action: Action) -> GameState:
Expand Down
23 changes: 14 additions & 9 deletions pgx/experimental/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def to_sfen(state):

"""
# NOTE: the input must be flipped when it is white's turn
state = state if state._x.turn % 2 == 0 else state.replace(_x=_flip(state._x)) # type: ignore
state = state if state._x.color % 2 == 0 else state.replace(_x=_flip(state._x)) # type: ignore

pb = np.rot90(state._x.board.reshape((9, 9)), k=3)
sfen = ""
Expand Down Expand Up @@ -63,7 +63,7 @@ def to_sfen(state):
else:
sfen += " "
# Turn
if state._x.turn == 0:
if state._x.color == 0:
sfen += "b "
else:
sfen += "w "
Expand All @@ -85,12 +85,12 @@ def to_sfen(state):


@jax.jit
def _from_board(turn, piece_board, hand):
def _from_board(color, piece_board, hand):
"""Mainly for debugging purpose.
terminated, reward, and current_player are not changed"""
state = State(_x=GameState(turn=turn, board=piece_board, hand=hand)) # type: ignore
state = State(_x=GameState(color=color, board=piece_board, hand=hand)) # type: ignore
# fmt: off
state = jax.lax.cond(turn % 2 == 1, lambda: state.replace(_x=_flip(state._x)), lambda: state) # type: ignore
state = jax.lax.cond(color % 2 == 1, lambda: state.replace(_x=_flip(state._x)), lambda: state) # type: ignore
# fmt: on
return state.replace(legal_action_mask=Game().legal_action_mask(state._x)) # type: ignore

Expand All @@ -100,7 +100,7 @@ def from_sfen(sfen):
board_char_dir = ["P", "L", "N", "S", "B", "R", "G", "K", "", "", "", "", "", "", "p", "l", "n", "s", "b", "r", "g", "k"]
hand_char_dir = ["P", "L", "N", "S", "B", "R", "G", "p", "l", "n", "s", "b", "r", "g"]
# fmt: on
board, turn, hand, step_count = sfen.split()
board, color, hand, step_count = sfen.split()
board_ranks = board.split("/")
piece_board = np.zeros(81, dtype=np.int32)
for i in range(9):
Expand Down Expand Up @@ -131,6 +131,11 @@ def from_sfen(sfen):
num_piece = 1
piece_board = np.rot90(piece_board.reshape((9, 9)), k=1).flatten()
hand = np.reshape(s_hand, (2, 7))
turn = 0 if turn == "b" else 1
turn, piece_board, hand, step_count = turn, piece_board, hand, int(step_count) - 1
return _from_board(turn, piece_board, hand).replace(_step_count=np.int32(step_count)) # type: ignore
color = 0 if color == "b" else 1
color, piece_board, hand, step_count = color, piece_board, hand, int(step_count) - 1
state = _from_board(color, piece_board, hand)
state = state.replace( # type: ignore
_step_count=np.int32(step_count),
_x=state._x._replace(step_count=np.int32(step_count)),
)
return state
30 changes: 7 additions & 23 deletions pgx/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pgx.core as core
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey
from pgx._src.games.shogi import MAX_TERMINATION_STEPS, GameState, Game, _observe, INIT_LEGAL_ACTION_MASK
from pgx._src.games.shogi import GameState, Game, _observe, INIT_LEGAL_ACTION_MASK


TRUE = jnp.bool_(True)
Expand All @@ -35,6 +35,7 @@ class State(core.State):
legal_action_mask: Array = INIT_LEGAL_ACTION_MASK # (27 * 81,)
observation: Array = jnp.zeros((119, 9, 9), dtype=jnp.bool_)
_step_count: Array = jnp.int32(0)
_player_order: Array = jnp.array([0, 1], dtype=jnp.int32)
_x: GameState = GameState()

@property
Expand All @@ -50,8 +51,8 @@ def __init__(self):

def _init(self, key: PRNGKey) -> State:
state = State()
current_player = jnp.int32(jax.random.bernoulli(key))
return state.replace(current_player=current_player) # type: ignore
player_order = jnp.array([[0, 1], [1, 0]])[jax.random.bernoulli(key).astype(jnp.int32)]
return state.replace(_player_order=player_order) # type: ignore

def _step(self, state: core.State, action: Array, key) -> State:
del key
Expand All @@ -60,29 +61,12 @@ def _step(self, state: core.State, action: Array, key) -> State:
x = self._game.step(state._x, action)
state = state.replace( # type: ignore
current_player=(state.current_player + 1) % 2,
terminated=self._game.is_terminal(x),
rewards=self._game.rewards(x)[state._player_order],
legal_action_mask=x.legal_action_mask,
_x=x,
)
del x
legal_action_mask = self._game.legal_action_mask(state._x)
terminated = ~legal_action_mask.any()
# fmt: off
reward = jax.lax.select(
terminated,
jnp.ones(2, dtype=jnp.float32).at[state.current_player].set(-1),
jnp.zeros(2, dtype=jnp.float32),
)
# fmt: on
state = state.replace( # type: ignore
legal_action_mask=legal_action_mask,
terminated=terminated,
rewards=reward,
)
state = jax.lax.cond(
(MAX_TERMINATION_STEPS <= state._step_count),
# end with tie
lambda: state.replace(terminated=TRUE), # type: ignore
lambda: state,
)
return state # type: ignore

def _observe(self, state: core.State, player_id: Array) -> Array:
Expand Down
Loading