
Commit bb5572b

Author: puyuan
Commit message: fix(pu): fix model.training bug
Parent: 0d15204

File tree

2 files changed, +26 −25 lines


lzero/mcts/buffer/game_buffer_muzero.py

+18 −18
@@ -461,15 +461,15 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
                 m_output = model.initial_inference(m_obs)


-                if not model.training:
-                    # if not in training, obtain the scalars of the value/reward
-                    [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                        [
-                            m_output.latent_state,
-                            inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                            m_output.policy_logits
-                        ]
-                    )
+                # if not model.training:
+                # if not in training, obtain the scalars of the value/reward
+                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                    [
+                        m_output.latent_state,
+                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                        m_output.policy_logits
+                    ]
+                )

                 network_output.append(m_output)

@@ -589,15 +589,15 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
             else:
                 m_output = model.initial_inference(m_obs)

-                if not model.training:
-                    # if not in training, obtain the scalars of the value/reward
-                    [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                        [
-                            m_output.latent_state,
-                            inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                            m_output.policy_logits
-                        ]
-                    )
+                # if not model.training:
+                # if not in training, obtain the scalars of the value/reward
+                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                    [
+                        m_output.latent_state,
+                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                        m_output.policy_logits
+                    ]
+                )

                 network_output.append(m_output)
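
Effect of the change in game_buffer_muzero.py: the conversion of the initial-inference outputs no longer depends on model.training, so both _compute_target_reward_value and _compute_target_policy_reanalyzed always work with detached CPU numpy arrays and a scalar value, even if reanalysis runs while the model happens to be in training mode (the most likely reading of the "model.training bug" named in the commit message). A minimal sketch of the resulting code path, assuming the LZero helpers to_detach_cpu_numpy and inverse_scalar_transform keep the signatures used above:

```python
# Sketch of the code path after this commit; not the full method, just the
# fragment touched by the two hunks above.
m_output = model.initial_inference(m_obs)

# if not model.training:   # guard removed: the conversion now always runs
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
    [
        m_output.latent_state,
        # map the categorical (support) value back to a scalar
        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
        m_output.policy_logits,
    ]
)
network_output.append(m_output)
```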

zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py

+8 −7
@@ -73,8 +73,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
             model_path=None,
             num_unroll_steps=num_unroll_steps,
             game_segment_length=20,
-            # update_per_collect=80,
-            update_per_collect=10, # only for debug
+            update_per_collect=80,
+            # update_per_collect=10, # only for debug
             replay_ratio=0.25,
             batch_size=batch_size,
             num_segments=num_segments,
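
In DI-engine-style training loops, update_per_collect sets how many gradient-update steps run after each data-collection phase; this hunk restores the production value of 80 after a temporary debug value of 10. A rough sketch of where these keys sit, assuming the usual LZero Atari config layout with a policy=dict(...) block (the exact nesting is not shown in the diff):

```python
# Hypothetical fragment of the config after this commit (keys from the hunk above).
policy = dict(
    model_path=None,
    num_unroll_steps=num_unroll_steps,
    game_segment_length=20,
    update_per_collect=80,   # restored; was temporarily 10 "only for debug"
    replay_ratio=0.25,
    batch_size=batch_size,
    num_segments=num_segments,
)
```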
@@ -96,7 +96,7 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
         norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition,
         num_segments, total_batch_size):
     configs = []
-    exp_name_prefix = f'data_unizero_atari_mt_20250216/{len(env_id_list)}games_nlayer8_bs64_brf{buffer_reanalyze_freq}_seed{seed}/'
+    exp_name_prefix = f'data_unizero_atari_mt_20250217/{len(env_id_list)}games_nlayer8_bs64_brf{buffer_reanalyze_freq}_seed{seed}_dev-uz-mz/'

     for task_id, env_id in enumerate(env_id_list):
         config = create_config(
@@ -164,8 +164,9 @@ def create_env_manager():
     # num_segments = 2
     # n_episode = 2
     # evaluator_env_num = 2
-    # num_simulations = 2
-    # batch_size = [4, 4, 4, 4, 4, 4, 4, 4]
+    # num_simulations = 5
+    # # batch_size = [4, 4, 4, 4, 4, 4, 4, 4]
+    # batch_size = [4, 4,4,4]


     for seed in [0]:
@@ -175,5 +176,5 @@ def create_env_manager():
                               num_segments, total_batch_size)

     with DDPContext():
-        # train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step)
-        train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) # train on the first four tasks, only for debug
+        train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step)
+        # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) # train on the first four tasks, only for debug
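
The last hunk switches the DDP launch back from the four-task debug slice to the full config list. A minimal sketch of the effective launch loop after this commit, using only names that appear in this config file (imports shown are assumptions about what the file's header already pulls in; the generate_configs arguments are elided here):

```python
# Assumed imports, as presumably declared at the top of this config file.
from ding.utils import DDPContext
from lzero.entry import train_unizero_multitask_segment_ddp

for seed in [0]:
    configs = generate_configs(...)  # full argument list as in the hunks above
    with DDPContext():
        # Train on all tasks; the configs[:4] debug slice is now commented out.
        train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step)
```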
