
Commit 4a0d1c6

Author: puyuan (committed)

fix(pu): fix action_mask all-zero bug

1 parent 581310b · commit 4a0d1c6

4 files changed: +13 -4 lines

lzero/mcts/buffer/game_buffer_muzero.py  (+1)

@@ -737,6 +737,7 @@ def _compute_target_policy_non_reanalyzed(
             for index, legal_action in enumerate(legal_actions[policy_index]):
                 # only the action in ``legal_action`` the policy logits is nonzero
                 # policy_tmp[legal_action] = distributions[index]
+                # import ipdb;ipdb.set_trace()
                 try:
                     policy_tmp[legal_action] = distributions[index]
                 except Exception as e:
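
For context, this loop scatters the MCTS visit-count distribution onto the indices of the legal actions, leaving illegal actions at zero; if the action mask were all zeros, `legal_actions[policy_index]` would be empty and the target policy would stay all-zero, which is the symptom this commit guards against. A minimal standalone sketch of that mapping (the helper name `scatter_policy` and the toy sizes are illustrative, not the repository API):

import numpy as np

def scatter_policy(distributions, legal_actions, action_space_size):
    # distributions[i] is the (normalized) visit count of the i-th legal action;
    # the returned vector has one entry per action id, zeros for illegal actions.
    policy_tmp = np.zeros(action_space_size, dtype=np.float32)
    for index, legal_action in enumerate(legal_actions):
        policy_tmp[legal_action] = distributions[index]
    return policy_tmp

# e.g. 3 legal actions in an action space of size 5:
print(scatter_policy([0.6, 0.3, 0.1], [0, 2, 4], 5))  # [0.6 0.  0.3 0.  0.1]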

lzero/policy/unizero.py  (+1 -1)

@@ -343,7 +343,7 @@ def _init_learn(self) -> None:
             wandb.watch(self._learn_model.representation_network, log="all")

         # TODO: ========
-        self.accumulation_steps = 1  # number of gradient-accumulation steps
+        self.accumulation_steps = self._cfg.accumulation_steps  # number of gradient-accumulation steps

     # @profile
     def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
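
This change makes the number of gradient-accumulation steps configurable instead of hardcoded to 1. As a rough illustration of how such a counter is typically used inside a training step, here is a generic PyTorch-style sketch under assumed names (`training_step`, `loss_fn`); it is not the body of `_forward_learn`:

import torch

def training_step(model, optimizer, loss_fn, batch, step, accumulation_steps=1):
    # Scale the loss so gradients summed over `accumulation_steps`
    # micro-batches approximate one large-batch update.
    loss = loss_fn(model(batch['obs']), batch['target']) / accumulation_steps
    loss.backward()
    # Only step and reset the optimizer every `accumulation_steps` calls.
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item() * accumulation_steps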

zoo/jericho/configs/jericho_unizero_config.py  (+5 -2)

@@ -16,14 +16,16 @@ def main(env_id='detective.z5', seed=0):

     evaluator_env_num = 2
     num_simulations = 50
-    max_env_step = int(10e6)
+    max_env_step = int(1e6)

     # proj train
     collector_env_num = 4
     n_episode = 4
     batch_size = 16
     num_unroll_steps = 10
     infer_context_length = 4
+    # num_unroll_steps = 5
+    # infer_context_length = 2

     # all train
     # collector_env_num = 2

@@ -93,6 +95,7 @@ def main(env_id='detective.z5', seed=0):
         use_wandb=False,
         learn=dict(learner=dict(
             hook=dict(save_ckpt_after_iter=1000000, ), ), ),
+        accumulation_steps=1,  # TODO
         model=dict(
             observation_shape=512,
             action_space_size=action_space_size,

@@ -136,7 +139,7 @@ def main(env_id='detective.z5', seed=0):
         # game_segment_length=game_segment_length,
         # replay_buffer_size=int(1e6),
         replay_buffer_size=int(1e5),
-        eval_freq=int(5e3),
+        eval_freq=int(1e4),
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
         # ============= The key different params for reanalyze =============
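
For reference, with `batch_size=16` and the new `accumulation_steps=1`, the effective batch per optimizer update stays 16; raising `accumulation_steps` to k would give an effective batch of roughly 16·k without increasing per-forward memory, assuming the policy divides the loss and steps the optimizer every k micro-batches as sketched above.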

zoo/jericho/envs/jericho_env.py  (+6 -1)

@@ -94,13 +94,18 @@ def prepare_obs(self, obs, return_str: bool = False):
         full_obs = np.array(full_obs['input_ids'][0], dtype=np.int32)  # TODO: attn_mask


-        if len(available_actions) <= self.max_action_num:
+        if len(available_actions) == 0:
+            # avoid an all-zero action_mask, which causes a segmentation fault in MCTS
+            action_mask = [1] + [0] * (self.max_action_num - 1)
+        elif 0 < len(available_actions) <= self.max_action_num:
             action_mask = [1] * len(available_actions) + [0] * (self.max_action_num - len(available_actions))
         elif len(available_actions) == self.max_action_num:
             action_mask = [1] * len(available_actions)
         else:
             action_mask = [1] * self.max_action_num

+        # action_mask = [0] * self.max_action_num
+
         action_mask = np.array(action_mask, dtype=np.int8)

         if return_str:  # TODO: unizero needs to add 'to_play' ===============
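
This is the fix named in the commit title: when Jericho returns no available actions, the mask now marks the first slot as legal instead of being all zeros, which previously made MCTS crash with a segmentation fault. A standalone sketch of the same branching for quick testing (the helper `build_action_mask` is hypothetical, not the env's method):

import numpy as np

def build_action_mask(num_available, max_action_num):
    # Hypothetical restatement of the branching above, not the env's code.
    if num_available == 0:
        # Never return an all-zero mask: mark the first slot as legal.
        mask = [1] + [0] * (max_action_num - 1)
    elif num_available <= max_action_num:
        mask = [1] * num_available + [0] * (max_action_num - num_available)
    else:
        # More actions than slots: keep only the first max_action_num.
        mask = [1] * max_action_num
    return np.array(mask, dtype=np.int8)

print(build_action_mask(0, 5))  # [1 0 0 0 0]  <- the case this commit fixes
print(build_action_mask(3, 5))  # [1 1 1 0 0]
print(build_action_mask(7, 5))  # [1 1 1 1 1]

Note that the separate `len(available_actions) == self.max_action_num` branch in the diff is already covered by the `<=` comparison, so the sketch folds the two cases together.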
