
Commit d53a402

Author: puyuan
Commit message: feature(pu): add unizero_multitask atari concat_task_embed support
1 parent: a04b896

7 files changed: +137 −67 lines

lzero/entry/train_unizero_multitask_segment_ddp.py (+1 −1)

@@ -467,7 +467,7 @@ def train_unizero_multitask_segment_ddp(

         # Determine whether evaluation should be performed
         # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
-        if learner.train_iter > 10 or evaluator.should_eval(learner.train_iter):  # only for debug
+        if learner.train_iter > 10 and evaluator.should_eval(learner.train_iter):  # only for debug
         # if evaluator.should_eval(learner.train_iter):
             print('=' * 20)
             print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')
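Note on the one-line change above: with `or`, every iteration past the tenth would trigger evaluation regardless of the evaluator's schedule; with `and`, evaluation still has to be approved by `evaluator.should_eval`. A minimal sketch of the two gates, assuming a hypothetical schedule that fires every 100 iterations:

    # Minimal sketch of the two gating conditions (hypothetical schedule: eval every 100 iters).
    def should_eval(train_iter: int, eval_freq: int = 100) -> bool:
        return train_iter % eval_freq == 0

    def gate_or(train_iter: int) -> bool:
        # Old debug gate: fires on essentially every iteration once train_iter exceeds 10.
        return train_iter > 10 or should_eval(train_iter)

    def gate_and(train_iter: int) -> bool:
        # New debug gate: still respects the evaluator's schedule after the warm-up.
        return train_iter > 10 and should_eval(train_iter)

    print([i for i in range(0, 301, 50) if gate_or(i)])   # [0, 50, 100, 150, 200, 250, 300]
    print([i for i in range(0, 301, 50) if gate_and(i)])  # [100, 200, 300]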

lzero/mcts/buffer/game_buffer_muzero.py (+18 −18)

@@ -467,15 +467,15 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
             m_output = model.initial_inference(m_obs)


-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # if not model.training:
+            # if not in training, obtain the scalars of the value/reward
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )

             network_output.append(m_output)

@@ -595,15 +595,15 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
             else:
                 m_output = model.initial_inference(m_obs)

-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # if not model.training:
+            # if not in training, obtain the scalars of the value/reward
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )

             network_output.append(m_output)
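For context: the `if not model.training:` guard used to skip this conversion whenever the target model was in training mode; with it commented out, the latent state, value, and policy logits are unconditionally detached to CPU numpy (and the categorical value mapped back to a scalar) before reanalyzed targets are built. A minimal sketch of the detach step, with `to_detach_cpu_numpy_sketch` as a hypothetical stand-in for the helper of the same name, not its actual implementation:

    import torch

    def to_detach_cpu_numpy_sketch(tensors):
        # Cut each tensor out of the autograd graph and convert it to a CPU numpy array,
        # so the replay buffer can store targets without keeping the graph alive.
        return [t.detach().cpu().numpy() for t in tensors]

    latent_state = torch.randn(4, 64, requires_grad=True)
    policy_logits = torch.randn(4, 18, requires_grad=True)
    latent_np, policy_np = to_detach_cpu_numpy_sketch([latent_state, policy_logits])
    print(type(latent_np), latent_np.shape)  # <class 'numpy.ndarray'> (4, 64)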

lzero/mcts/buffer/game_buffer_unizero.py (+18 −18)

@@ -438,15 +438,15 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

             # =======================================================================

-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # if not model.training:
+            # if not in training, obtain the scalars of the value/reward
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )

             network_output.append(m_output)

@@ -556,15 +556,15 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A

             # ======================================================================

-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # if not model.training:
+            # if not in training, obtain the scalars of the value/reward
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )

             network_output.append(m_output)

lzero/model/unizero_model_multitask.py (+8 −2)

@@ -80,10 +80,15 @@ def __init__(
         world_model_cfg.norm_type = norm_type
         assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, 'max_tokens should be 2 * max_blocks, because each timestep has 2 tokens: obs and action'

+        if world_model_cfg.task_embed_option == "concat_task_embed":
+            obs_act_embed_dim = world_model_cfg.embed_dim - 96
+        else:
+            obs_act_embed_dim = world_model_cfg.embed_dim
+
         if world_model_cfg.obs_type == 'vector':
             self.representation_network = RepresentationNetworkMLP(
                 observation_shape,
-                hidden_channels=world_model_cfg.embed_dim,
+                hidden_channels=obs_act_embed_dim,
                 layer_num=2,
                 activation=self.activation,
                 group_size=world_model_cfg.group_size,

@@ -109,7 +114,7 @@ def __init__(
                 self.downsample,
                 activation=self.activation,
                 norm_type=norm_type,
-                embedding_dim=world_model_cfg.embed_dim,
+                embedding_dim=obs_act_embed_dim,
                 group_size=world_model_cfg.group_size,
             ))
             # self.representation_network = RepresentationNetworkUniZero(

@@ -138,6 +143,7 @@ def __init__(
             print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
             print('==' * 20)
         elif world_model_cfg.obs_type == 'image_memory':
+            # todo for concat_task_embed
             self.representation_network = LatentEncoderForMemoryEnv(
                 image_shape=(3, 5, 5),
                 embedding_size=world_model_cfg.embed_dim,
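The new branch above reserves part of the transformer width for the task embedding: when task_embed_option == "concat_task_embed", the observation/action encoders emit embed_dim - 96 channels, and a 96-dimensional task embedding (the size this commit appears to hard-code) is concatenated later in the world model so that the combined token is again embed_dim wide. A small sketch of the dimension bookkeeping, with illustrative numbers:

    embed_dim = 768          # transformer token width (illustrative value)
    task_embed_dim = 96      # dimension implicitly reserved for the task embedding in this commit

    task_embed_option = "concat_task_embed"
    if task_embed_option == "concat_task_embed":
        obs_act_embed_dim = embed_dim - task_embed_dim   # encoder output width
    else:
        obs_act_embed_dim = embed_dim

    # After concatenation the token is back to the full transformer width.
    assert obs_act_embed_dim + task_embed_dim == embed_dim
    print(obs_act_embed_dim)  # 672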

lzero/model/unizero_world_models/world_model_multitask.py (+63 −4)

@@ -180,6 +180,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
         self.act_embedding_table = nn.Embedding(config.action_space_size, self.obs_act_embed_dim, device=self.device)
         print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}")

+        print(f'='*20)
+        print(f"self.obs_act_embed_dim:{self.obs_act_embed_dim}")
+        print(f'='*20)
+
+
         # if self.num_experts_in_moe_head == -1:
         assert self.num_experts_in_moe_head > 0
         if self.use_normal_head:

@@ -647,10 +652,12 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
             if self.task_embed_option == "add_task_embed":
                 obs_embeddings = obs_embeddings + self.task_embeddings
             elif self.task_embed_option == "concat_task_embed":
+
                 # print(f'=='*20)
                 # print(f'obs_embeddings.shape:{obs_embeddings.shape}')
                 # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}')
                 # print(f'=='*20)
+
                 if is_init_infer:
                     # Note: during inference, task embeddings are concatenated only in is_init_infer; recurrent_infer already carries the task-embedding information added in init_infer
                     # Expand task embeddings to match the sequence shape

@@ -862,21 +869,73 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, ta
                                                  -1)

         num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1))
-        # act_embeddings = self.act_embedding_table[task_id](act_tokens)
         act_embeddings = self.act_embedding_table(act_tokens)

         B, L, K, E = obs_embeddings.size()
-        obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device)
+        if self.task_embed_option == "concat_task_embed":
+            # B, L*2, E
+            obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device)
+        else:
+            # B, L*2, E
+            obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device)
+
+        if self.task_embed_option == "concat_task_embed":
+            # Expand task embeddings to match the sequence shape
+            task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(B, 1, -1)
+

         for i in range(L):
-            # obs = obs_embeddings[:, i, :, :]
-            obs = obs_embeddings[:, i, :, :] + self.task_embeddings  # Shape: (B, K, E) TODO: task_embeddings
+            if self.task_embed_option == "add_task_embed":
+                obs = obs_embeddings[:, i, :, :] + self.task_embeddings  # Shape: (B, K, E) TODO: task_embeddings
+            elif self.task_embed_option == "concat_task_embed":
+                obs = torch.cat([obs_embeddings[:, i, :, :], task_emb_expanded], dim=-1)
+            else:
+                obs = obs_embeddings[:, i, :, :]  # Shape: (B, K, E)
+
             act = act_embeddings[:, i, 0, :].unsqueeze(1)
+            if self.task_embed_option == "concat_task_embed":
+                act = torch.cat([act, task_emb_expanded], dim=-1)
+
             obs_act = torch.cat([obs, act], dim=1)
+            # print(f'obs_act.shape:{obs_act.shape}')
+
             obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act

         return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps

+
+    #@profile
+    # def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0):
+    #     """
+    #     Process combined observation embeddings and action tokens.

+    #     Arguments:
+    #         - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens.
+    #         - prev_steps (:obj:`torch.Tensor`): Previous steps.
+    #     Returns:
+    #         - torch.Tensor: Combined observation and action embeddings with position information added.
+    #     """
+    #     obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens']
+    #     if len(obs_embeddings.shape) == 3:
+    #         obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens,
+    #                                              -1)

+    #     num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1))
+    #     # act_embeddings = self.act_embedding_table[task_id](act_tokens)
+    #     act_embeddings = self.act_embedding_table(act_tokens)

+    #     B, L, K, E = obs_embeddings.size()
+    #     obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device)

+    #     for i in range(L):
+    #         # obs = obs_embeddings[:, i, :, :]
+    #         obs = obs_embeddings[:, i, :, :] + self.task_embeddings  # Shape: (B, K, E) TODO: task_embeddings
+    #         act = act_embeddings[:, i, 0, :].unsqueeze(1)
+    #         obs_act = torch.cat([obs, act], dim=1)
+    #         obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act

+    #     return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps
+
     #@profile
     def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths, task_id=0):
         """

lzero/policy/unizero_multitask.py (+5 −5)

@@ -948,11 +948,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
         network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id)
         latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)

-        if not self._eval_model.training:
-            # if not in training, obtain the scalars of the value/reward
-            pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()  # shape(B, 1)
-            latent_state_roots = latent_state_roots.detach().cpu().numpy()
-            policy_logits = policy_logits.detach().cpu().numpy().tolist()  # list shape(B, A)
+        # if not self._eval_model.training:
+        # if not in training, obtain the scalars of the value/reward
+        pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()  # shape(B, 1)
+        latent_state_roots = latent_state_roots.detach().cpu().numpy()
+        policy_logits = policy_logits.detach().cpu().numpy().tolist()  # list shape(B, A)

         legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
         if self._cfg.mcts_ctree:
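This mirrors the buffer change: the training-mode guard is commented out, so predicted values are always converted from the network's categorical support representation back to scalars and moved to numpy before MCTS. For reference, a minimal sketch of the standard MuZero-style inverse value transform that inverse_scalar_transform_handle is expected to perform (expectation over the support, then inversion of the h(x) value rescaling); this is a generic illustration under those assumptions, not the exact LZero implementation:

    import torch

    def inverse_value_transform_sketch(value_logits: torch.Tensor, support_scale: int,
                                       eps: float = 0.001) -> torch.Tensor:
        # 1) Expectation over the discrete support [-support_scale, ..., support_scale].
        support = torch.arange(-support_scale, support_scale + 1, dtype=torch.float32)
        scaled = (torch.softmax(value_logits, dim=-1) * support).sum(dim=-1, keepdim=True)
        # 2) Invert the rescaling h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x.
        sign = torch.sign(scaled)
        y = torch.abs(scaled)
        inv = ((torch.sqrt(1 + 4 * eps * (y + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1
        return sign * inv

    pred_values = inverse_value_transform_sketch(torch.randn(8, 601), support_scale=300)
    print(pred_values.shape)  # torch.Size([8, 1])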

zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py (+24 −19)

@@ -15,15 +15,17 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
            n_evaluator_episode=evaluator_env_num,
            manager=dict(shared_memory=False),
            full_action_space=True,
-            # collect_max_episode_steps=int(5e3),
-            # eval_max_episode_steps=int(5e3),
+            collect_max_episode_steps=int(5e3),
+            eval_max_episode_steps=int(5e3),
            # ===== only for debug =====
-            collect_max_episode_steps=int(30),
-            eval_max_episode_steps=int(30),
+            # collect_max_episode_steps=int(30),
+            # eval_max_episode_steps=int(30),
        ),
        policy=dict(
-            use_moco=False,  # ==============TODO==============
            multi_gpu=True,  # Very important for ddp
+            only_use_moco_stats=False,
+            use_moco=False,  # ==============TODO==============
+            # use_moco=True,  # ==============TODO==============
            learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))),
            grad_correct_params=dict(
                MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0,

@@ -41,11 +43,13 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
                world_model_cfg=dict(
                    continuous_action_space=False,

-                    task_embed_option=None,  # ==============TODO: none ==============
-                    use_task_embed=False,  # ==============TODO==============
-                    use_shared_projection=False,
-
+                    # task_embed_option=None,  # ==============TODO: none ==============
+                    # use_task_embed=False,  # ==============TODO==============

+                    task_embed_option='concat_task_embed',  # ==============TODO: none ==============
+                    use_task_embed=True,  # ==============TODO==============
+
+                    use_shared_projection=False,
                    max_blocks=num_unroll_steps,
                    max_tokens=2 * num_unroll_steps,
                    context_length=2 * infer_context_length,

@@ -105,7 +109,8 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
                     norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition,
                     num_segments, total_batch_size):
    configs = []
-    exp_name_prefix = f'data_unizero_atari_mt_20250212_debug/atari_{len(env_id_list)}games_bs64_brf{buffer_reanalyze_freq}_seed{seed}/'
+    exp_name_prefix = f'data_unizero_atari_mt_20250217/atari_{len(env_id_list)}games_concattaskembed_bs64_brf{buffer_reanalyze_freq}_seed{seed}_dev-uz-mz-mt-cont/'
+    # exp_name_prefix = f'data_unizero_atari_mt_20250217/atari_{len(env_id_list)}games_notaskembed_bs64_brf{buffer_reanalyze_freq}_seed{seed}_dev-uz-mz-mt-cont/'

    for task_id, env_id in enumerate(env_id_list):
        config = create_config(

@@ -136,7 +141,7 @@ def create_env_manager():
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
+        python -m torch.distributed.launch --nproc_per_node=4 --master_port=29502 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
        torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
    """

@@ -161,8 +166,8 @@ def create_env_manager():
    reanalyze_ratio = 0.0
    total_batch_size = 512

-    batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
-    # batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
+    # batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
+    batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]

    num_unroll_steps = 10
    infer_context_length = 4

@@ -172,12 +177,12 @@ def create_env_manager():
    reanalyze_partition = 0.75

    # ======== TODO: only for debug ========
-    collector_env_num = 2
-    num_segments = 2
-    n_episode = 2
-    evaluator_env_num = 2
-    num_simulations = 2
-    batch_size = [4, 4, 4, 4, 4, 4, 4, 4]
+    # collector_env_num = 2
+    # num_segments = 2
+    # n_episode = 2
+    # evaluator_env_num = 2
+    # num_simulations = 1
+    # batch_size = [4, 4, 4, 4, 4, 4, 4, 4]


    for seed in [0]:
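One practical consequence of the batch-size switch above: with total_batch_size = 512 and 8 games, 512 / 8 = 64, so the old line capped each task at min(64, 64) = 64 samples per update, while the new line caps it at min(32, 64) = 32, halving every task's per-update batch for the concat_task_embed run. A quick check:

    total_batch_size = 512
    num_tasks = 8  # the 8 Atari games configured in this script

    old_batch_size = [int(min(64, total_batch_size / num_tasks)) for _ in range(num_tasks)]
    new_batch_size = [int(min(32, total_batch_size / num_tasks)) for _ in range(num_tasks)]

    print(old_batch_size)  # [64, 64, 64, 64, 64, 64, 64, 64]
    print(new_batch_size)  # [32, 32, 32, 32, 32, 32, 32, 32]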
