
Commit 25cf29f

Author: puyuan
Commit message: polish(pu): rename step_index to timestep
Parent: 3dfc1f2

File tree: 5 files changed, +34 -20 lines


lzero/mcts/buffer/game_segment.py (+5 -5)

@@ -68,7 +68,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
 
 self.action_mask_segment = []
 self.to_play_segment = []
-self.step_index_segment = []
+self.timestep_segment = []
 
 self.target_values = []
 self.target_rewards = []
@@ -136,7 +136,7 @@ def append(
 reward: np.ndarray,
 action_mask: np.ndarray = None,
 to_play: int = -1,
-step_index: int = 0,
+timestep: int = 0,
 chance: int = 0,
 ) -> None:
 """
@@ -149,7 +149,7 @@ def append(
 
 self.action_mask_segment.append(action_mask)
 self.to_play_segment.append(to_play)
-self.step_index_segment.append(step_index)
+self.timestep_segment.append(timestep)
 
 if self.use_ture_chance_label_in_chance_encoder:
     self.chance_segment.append(chance)
@@ -300,7 +300,7 @@ def game_segment_to_array(self) -> None:
 
 self.action_mask_segment = np.array(self.action_mask_segment)
 self.to_play_segment = np.array(self.to_play_segment)
-self.step_index_segment = np.array(self.step_index_segment)
+self.timestep_segment = np.array(self.timestep_segment)
 
 if self.use_ture_chance_label_in_chance_encoder:
     self.chance_segment = np.array(self.chance_segment)
@@ -322,7 +322,7 @@ def reset(self, init_observations: np.ndarray) -> None:
 
 self.action_mask_segment = []
 self.to_play_segment = []
-self.step_index_segment = []
+self.timestep_segment = []
 
 if self.use_ture_chance_label_in_chance_encoder:
     self.chance_segment = []
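To make the rename concrete, here is a minimal usage sketch of the new keyword. It assumes `segment` is an already constructed GameSegment and `obs` is one observation dict produced by the env; the positional argument order follows the collector call shown further down in this commit, and the snippet is illustrative rather than standalone-runnable.

# Hedged sketch: `segment`, `action`, `reward`, and `obs` are placeholders.
segment.append(
    action,                          # action executed at this step
    obs['observation'],              # current (stacked) observation
    reward,
    action_mask=obs['action_mask'],
    to_play=obs['to_play'],
    timestep=obs['timestep'],        # renamed keyword, formerly step_index
)
# Internally this now extends segment.timestep_segment rather than segment.step_index_segment.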

lzero/policy/unizero.py (+1)

@@ -182,6 +182,7 @@ class UniZeroPolicy(MuZeroPolicy):
 update_per_collect=None,
 # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None.
 replay_ratio=0.25,
+reanalyze_ratio=0,
 # (int) Minibatch size for one gradient descent.
 batch_size=256,
 # (str) Optimizer for training policy network.
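The single change here registers a `reanalyze_ratio` default in the UniZero policy config. A hedged override sketch, with the neighbouring keys copied from the defaults above (interpreting 0 as "no reanalysis" is an assumption, not stated in this diff):

# Hypothetical user-side override of the new default.
policy = dict(
    update_per_collect=None,
    replay_ratio=0.25,
    reanalyze_ratio=0,  # new default in this commit; 0 is assumed to keep reanalysis off
    batch_size=256,
)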

lzero/worker/muzero_collector.py (+11 -11)

@@ -360,7 +360,7 @@ def collect(self,
 
 action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)}
 to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)}
-step_index_dict = {i: to_ndarray(init_obs[i]['step_index']) for i in range(env_nums)}
+timestep_dict = {i: to_ndarray(init_obs[i]['timestep']) for i in range(env_nums)}
 if self.policy_config.use_ture_chance_label_in_chance_encoder:
     chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)}
 
@@ -421,11 +421,11 @@ def collect(self,
 
 action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id}
 to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id}
-step_index_dict = {env_id: step_index_dict[env_id] for env_id in ready_env_id}
+timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id}
 
 action_mask = [action_mask_dict[env_id] for env_id in ready_env_id]
 to_play = [to_play_dict[env_id] for env_id in ready_env_id]
-step_index = [step_index_dict[env_id] for env_id in ready_env_id]
+timestep = [timestep_dict[env_id] for env_id in ready_env_id]
 
 if self.policy_config.use_ture_chance_label_in_chance_encoder:
     chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id}
@@ -439,13 +439,13 @@ def collect(self,
 # Key policy forward step
 # ==============================================================
 # print(f'ready_env_id:{ready_env_id}')
-policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, step_index=step_index)
+policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep)
 
 # Extract relevant policy outputs
 actions_with_env_id = {k: v['action'] for k, v in policy_output.items()}
 value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()}
 pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()}
-step_index_dict_with_env_id = {k: v['step_index'] for k, v in policy_output.items()}
+timestep_dict_with_env_id = {k: v['timestep'] for k, v in policy_output.items()}
 
 if self.policy_config.sampled_algo:
     root_sampled_actions_dict_with_env_id = {
@@ -467,7 +467,7 @@ def collect(self,
 actions = {}
 value_dict = {}
 pred_value_dict = {}
-step_index_dict = {}
+timestep_dict = {}
 
 if not collect_with_pure_policy:
     distributions_dict = {}
@@ -485,7 +485,7 @@ def collect(self,
 actions[env_id] = actions_with_env_id.pop(env_id)
 value_dict[env_id] = value_dict_with_env_id.pop(env_id)
 pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id)
-step_index_dict[env_id] = step_index_dict_with_env_id.pop(env_id)
+timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id)
 
 if not collect_with_pure_policy:
     distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id)
@@ -536,19 +536,19 @@ def collect(self,
 if self.policy_config.use_ture_chance_label_in_chance_encoder:
     game_segments[env_id].append(
         actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
-        to_play_dict[env_id], chance_dict[env_id], step_index_dict[env_id]
+        to_play_dict[env_id], chance_dict[env_id], timestep_dict[env_id]
     )
 else:
     game_segments[env_id].append(
         actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
-        to_play_dict[env_id], step_index_dict[env_id]
+        to_play_dict[env_id], timestep_dict[env_id]
     )
 
 # NOTE: the position of code snippet is very important.
 # the obs['action_mask'] and obs['to_play'] are corresponding to the next action
 action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
 to_play_dict[env_id] = to_ndarray(obs['to_play'])
-step_index_dict[env_id] = to_ndarray(obs['step_index'])
+timestep_dict[env_id] = to_ndarray(obs['timestep'])
 if self.policy_config.use_ture_chance_label_in_chance_encoder:
     chance_dict[env_id] = to_ndarray(obs['chance'])
 
@@ -679,7 +679,7 @@ def collect(self,
 
 action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask'])
 to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play'])
-step_index_dict[env_id] = to_ndarray(init_obs[env_id]['step_index'])
+timestep_dict[env_id] = to_ndarray(init_obs[env_id]['timestep'])
 if self.policy_config.use_ture_chance_label_in_chance_encoder:
     chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance'])
 
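Taken together, these hunks rename one data path end to end inside collect(). A condensed sketch of that path, with the statements copied from the hunks above and the surrounding control flow omitted:

# 1. Read the renamed key from the initial observations.
timestep_dict = {i: to_ndarray(init_obs[i]['timestep']) for i in range(env_nums)}
# 2. Gather the values for the ready envs and pass them to the policy.
timestep = [timestep_dict[env_id] for env_id in ready_env_id]
policy_output = self._policy.forward(
    stack_obs, action_mask, temperature, to_play, epsilon,
    ready_env_id=ready_env_id, timestep=timestep,
)
# 3. Carry the per-env 'timestep' output into the game segments via append(...).
timestep_dict_with_env_id = {k: v['timestep'] for k, v in policy_output.items()}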

zoo/atari/config/atari_unizero_config.py (+16 -4)

@@ -11,14 +11,25 @@ def main(env_id='PongNoFrameskip-v4', seed=0):
 # ==============================================================
 collector_env_num = 8
 game_segment_length = 20
-evaluator_env_num = 5
+evaluator_env_num = 3
 num_simulations = 50
 max_env_step = int(5e5)
 batch_size = 64
 num_unroll_steps = 10
 infer_context_length = 4
 num_layers = 2
 replay_ratio = 0.25
+
+# collector_env_num = 2
+# game_segment_length = 20
+# evaluator_env_num = 1
+# num_simulations = 2
+# max_env_step = int(5e5)
+# batch_size = 2
+# num_unroll_steps = 5
+# infer_context_length = 2
+# num_layers = 1
+# replay_ratio = 0.1
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
@@ -33,8 +44,8 @@ def main(env_id='PongNoFrameskip-v4', seed=0):
 n_evaluator_episode=evaluator_env_num,
 manager=dict(shared_memory=False, ),
 # TODO: only for debug
-# collect_max_episode_steps=int(20),
-# eval_max_episode_steps=int(20),
+# collect_max_episode_steps=int(50),
+# eval_max_episode_steps=int(50),
 ),
 policy=dict(
 learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000
@@ -68,6 +79,7 @@ def main(env_id='PongNoFrameskip-v4', seed=0):
 learning_rate=0.0001,
 num_simulations=num_simulations,
 train_start_after_envsteps=2000,
+# train_start_after_envsteps=0, # debug
 game_segment_length=game_segment_length,
 replay_buffer_size=int(1e6),
 eval_freq=int(5e3),
@@ -92,7 +104,7 @@ def main(env_id='PongNoFrameskip-v4', seed=0):
 atari_unizero_create_config = EasyDict(atari_unizero_create_config)
 create_config = atari_unizero_create_config
 
-main_config.exp_name = f'data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+main_config.exp_name = f'data_unizero_20250211/{env_id[:-14]}/{env_id[:-14]}_uz_poeembed-mergemain_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
 from lzero.entry import train_unizero
 train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
 
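For reference, a hedged invocation sketch; main is the function whose default arguments appear in the hunk headers above, and the __main__ guard is assumed rather than shown in this diff:

# Run the Pong config with the defaults declared by main().
if __name__ == "__main__":
    main(env_id='PongNoFrameskip-v4', seed=0)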

zoo/atari/envs/atari_lightzero_env.py (+1)

@@ -154,6 +154,7 @@ def step(self, action: int) -> BaseEnvTimestep:
 # print(f'self.timestep: {self.timestep}')
 observation = self.observe()
 if done:
+    print(f'done in self.timestep: {self.timestep}')
     info['eval_episode_return'] = self._eval_episode_return
 return BaseEnvTimestep(observation, self.reward, done, info)
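The env-side counterpart of the rename is implicit here: after this commit the collector reads obs['timestep'], so observe() is expected to expose that key. A hedged sketch of the keys the collector consumes (the actual construction inside observe() is not part of this diff, and the values below are placeholders):

# Keys the collector expects in each observation dict (placeholder values).
obs = {
    'observation': stacked_frames,   # processed frame stack
    'action_mask': action_mask,      # legal-action mask
    'to_play': -1,                   # current player id (-1 for single-player Atari)
    'timestep': self.timestep,       # env step counter, formerly 'step_index'
    # 'chance': chance,              # only when use_ture_chance_label_in_chance_encoder is enabled
}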
