Commit 04c4a68

Author: puyuan (committed)
polish(pu): adapt to discrete action space env like atari
1 parent 1860f11 commit 04c4a68

10 files changed: +112 -74 lines

lzero/entry/train_muzero_multitask_segment_ddp.py

+1 -2

@@ -380,8 +380,7 @@ def train_muzero_multitask_segment_ddp(
         collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

         # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
-        if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
-
+        if learner.train_iter > 1 and evaluator.should_eval(learner.train_iter):
             print('=' * 20)
             print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')

lzero/mcts/buffer/game_buffer_muzero.py

+4 -1

@@ -64,7 +64,10 @@ def __init__(self, cfg: dict):
         if hasattr(self._cfg, 'task_id'):
             self.task_id = self._cfg.task_id
             print(f"Task ID is set to {self.task_id}.")
-            self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
+            try:
+                self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
+            except Exception as e:
+                self.action_space_size = self._cfg.model.action_space_size

         else:
             self.task_id = None

lzero/mcts/buffer/game_buffer_unizero.py

+4 -2

@@ -52,8 +52,10 @@ def __init__(self, cfg: dict):
         if hasattr(self._cfg, 'task_id'):
             self.task_id = self._cfg.task_id
             print(f"Task ID is set to {self.task_id}.")
-            self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
-
+            try:
+                self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
+            except Exception as e:
+                self.action_space_size = self._cfg.model.action_space_size
         else:
             self.task_id = None
             print("No task_id found in configuration. Task ID is set to None.")

lzero/mcts/buffer/game_segment.py

+9 -6

@@ -60,12 +60,15 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
             # image obs input, e.g. atari environments
             self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])
         else:
-            if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1:
-                # for vector obs input, e.g. classical control and box2d environments
-                self.zero_obs_shape = config.model.observation_shape_list[task_id]
-            elif len(config.model.observation_shape_list[task_id]) == 3:
-                # image obs input, e.g. atari environments
-                self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1])
+            if hasattr(config.model, "observation_shape_list"):
+                if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1:
+                    # for vector obs input, e.g. classical control and box2d environments
+                    self.zero_obs_shape = config.model.observation_shape_list[task_id]
+                elif len(config.model.observation_shape_list[task_id]) == 3:
+                    # image obs input, e.g. atari environments
+                    self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1])
+            else:
+                self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])

         self.obs_segment = []
         self.action_segment = []
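Note: the game_segment.py hunk guards the multi-task branch with hasattr, so a single-task Atari config that only defines observation_shape keeps the image-shaped zero observation. A hedged sketch of that decision, with a hypothetical helper and config rather than the repo's actual API:

from easydict import EasyDict

def resolve_zero_obs_shape(config, task_id=None):
    model = config.model
    if hasattr(model, "observation_shape_list"):
        shape = model.observation_shape_list[task_id]
        if isinstance(shape, int) or len(shape) == 1:
            # vector obs input, e.g. classic control / box2d
            return shape
        if len(shape) == 3:
            # image obs input
            return (model.image_channel, shape[-2], shape[-1])
    # single-task fallback: plain observation_shape, e.g. Atari
    return (model.image_channel, model.observation_shape[-2], model.observation_shape[-1])

atari_cfg = EasyDict(dict(model=dict(image_channel=3, observation_shape=(3, 96, 96))))
print(resolve_zero_obs_shape(atari_cfg))  # (3, 96, 96)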

lzero/model/unizero_model_multitask.py

+12 -2

@@ -112,6 +112,16 @@ def __init__(
             embedding_dim=world_model_cfg.embed_dim,
             group_size=world_model_cfg.group_size,
         ))
+        # self.representation_network = RepresentationNetworkUniZero(
+        #     observation_shape,
+        #     num_res_blocks,
+        #     num_channels,
+        #     self.downsample,
+        #     activation=self.activation,
+        #     norm_type=norm_type,
+        #     embedding_dim=world_model_cfg.embed_dim,
+        #     group_size=world_model_cfg.group_size,
+        # )
         # TODO: we should change the output_shape to the real observation shape
         # self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64))

@@ -187,8 +197,8 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_
                 latent state, W_ is the width of latent state.
         """
         batch_size = obs_batch.size(0)
-        print('=here 5='*20)
-        import ipdb; ipdb.set_trace()
+        # print('=here 5='*20)
+        # import ipdb; ipdb.set_trace()
         obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch}
         _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id)
         latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value

lzero/model/unizero_world_models/tokenizer.py

+3 -3

@@ -96,8 +96,8 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Ten
                 obs_embeddings = self.encoder(x, task_id=task_id)  # TODO: for dmc multitask
                 # obs_embeddings = self.encoder[task_id](x)
             except Exception as e:
-                print(e)
-                obs_embeddings = self.encoder(x)  # TODO: for memory env
+                # print(e)
+                obs_embeddings = self.encoder[0](x)  # TODO: for atari/memory env

             obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
         elif len(shape) == 5:
@@ -106,7 +106,7 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Ten
             try:
                 obs_embeddings = self.encoder[task_id](x)
             except Exception as e:
-                obs_embeddings = self.encoder(x)  # TODO: for memory env
+                obs_embeddings = self.encoder[0](x)  # TODO: for atari/memory env
             obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
         else:
             raise ValueError(f"Invalid input shape: {shape}")
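Note: the tokenizer fallback now indexes encoder[0] instead of calling the encoder container itself. A small hedged sketch of why that works when the per-task encoders live in an nn.ModuleList that holds a single shared module (an illustrative assumption, not necessarily the repo's exact layout):

import torch
import torch.nn as nn

encoder = nn.ModuleList([nn.Linear(64, 8)])  # one shared encoder inside a list container
x = torch.randn(4, 64)
try:
    y = encoder(x)       # calling a ModuleList directly raises (it has no forward)
except Exception:
    y = encoder[0](x)    # fall back to the single shared encoder, as in the hunk above
print(y.shape)           # torch.Size([4, 8])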

lzero/model/unizero_world_models/world_model_multitask.py

+3 -2

@@ -681,7 +681,7 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
             if len(act_tokens.shape) == 3:
                 act_tokens = act_tokens.squeeze(1)
             num_steps = act_tokens.size(1)
-            if self.task_num >= 1:
+            if self.task_num >= 1 and self.continuous_action_space:
                 act_embeddings = self.act_embedding_table[task_id](act_tokens)
             else:
                 act_embeddings = self.act_embedding_table(act_tokens)

@@ -862,7 +862,8 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, ta
                 -1)

             num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1))
-            act_embeddings = self.act_embedding_table[task_id](act_tokens)
+            # act_embeddings = self.act_embedding_table[task_id](act_tokens)
+            act_embeddings = self.act_embedding_table(act_tokens)

             B, L, K, E = obs_embeddings.size()
             obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device)
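Note: this world-model change routes discrete-action tasks through a single shared action embedding table, keeping the per-task indexing only for continuous action spaces. A hedged sketch of the two layouts (names and sizes are illustrative, not the repo's exact definitions):

import torch
import torch.nn as nn

embed_dim, task_num = 768, 8
continuous_action_space = False  # discrete multi-task, e.g. Atari

if continuous_action_space:
    # one embedding module per task, indexed by task_id
    act_embedding_table = nn.ModuleList(nn.Linear(6, embed_dim) for _ in range(task_num))
else:
    # one table shared across all discrete tasks
    act_embedding_table = nn.Embedding(18, embed_dim)  # 18 = full Atari action set

act_tokens = torch.randint(0, 18, (4, 1))   # (batch, num_steps)
act_embeddings = act_embedding_table(act_tokens)
print(act_embeddings.shape)                 # torch.Size([4, 1, 768])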

lzero/policy/unizero_multitask.py

+17 -17

@@ -892,7 +892,7 @@ def _forward_collect(
         if active_collect_env_num < self.collector_env_num:
             print('==========collect_forward============')
             print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}')
-            self._reset_collect(reset_init_data=True)
+            self._reset_collect(reset_init_data=True, task_id=task_id)

         return output

@@ -1001,7 +1001,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
         return output

     #@profile
-    def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True) -> None:
+    def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None:
         """
         Overview:
             This method resets the collection process for a specific environment. It clears caches and memory
@@ -1085,21 +1085,21 @@ def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_dat
             - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
         """
         if reset_init_data:
-            if task_id is not None:
-                self.last_batch_obs_eval = initialize_zeros_batch(
-                    self._cfg.model.observation_shape_list[task_id],
-                    self._cfg.evaluator_env_num,
-                    self._cfg.device
-                )
-                print('unizero_multitask.py task_id is not None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
-
-            else:
-                self.last_batch_obs_eval = initialize_zeros_batch(
-                    self._cfg.model.observation_shape,
-                    self._cfg.evaluator_env_num,
-                    self._cfg.device
-                )
-                print('unizero_multitask.py task_id is None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
+            # if task_id is not None:
+            #     self.last_batch_obs_eval = initialize_zeros_batch(
+            #         self._cfg.model.observation_shape_list[task_id],
+            #         self._cfg.evaluator_env_num,
+            #         self._cfg.device
+            #     )
+            #     print('unizero_multitask.py task_id is not None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
+
+            # else:
+            self.last_batch_obs_eval = initialize_zeros_batch(
+                self._cfg.model.observation_shape,
+                self._cfg.evaluator_env_num,
+                self._cfg.device
+            )
+            print('unizero_multitask.py task_id is None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)

         self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)]
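Note: _reset_eval now always builds last_batch_obs_eval from the plain observation_shape, which fits the single-domain Atari setting where every task shares the same image shape. A hedged stand-in for the initialize_zeros_batch helper used above (the real helper is LightZero's; this sketch only mirrors its apparent contract of returning a zero observation batch):

import torch

def initialize_zeros_batch_sketch(observation_shape, batch_size, device):
    # Build a zero observation batch shaped (batch_size, *observation_shape).
    shape = [observation_shape] if isinstance(observation_shape, int) else list(observation_shape)
    return torch.zeros([batch_size, *shape], device=device)

obs = initialize_zeros_batch_sketch((3, 96, 96), batch_size=3, device='cpu')
print(obs.shape)  # torch.Size([3, 3, 96, 96])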

zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py

+20 -4

@@ -22,6 +22,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
             # eval_max_episode_steps=int(30),
         ),
         policy=dict(
+            use_moco=False,  # ==============TODO==============
             multi_gpu=True,  # Very important for ddp
             learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))),
             grad_correct_params=dict(
@@ -37,24 +38,39 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
                 num_res_blocks=2,
                 num_channels=256,
                 world_model_cfg=dict(
+
+                    task_embed_option=None,  # ==============TODO: none ==============
+                    use_task_embed=False,  # ==============TODO==============
+                    use_shared_projection=False,
+
+
                     max_blocks=num_unroll_steps,
                     max_tokens=2 * num_unroll_steps,
                     context_length=2 * infer_context_length,
                     device='cuda',
                     action_space_size=action_space_size,
                     # batch_size=64: when training 8 games, each GPU uses roughly 12*3 = 36 GB of CUDA memory
-                    num_layers=12,
-                    num_heads=24,
+                    # num_layers=12,
+                    # num_heads=24,
+
+                    num_layers=8,
+                    num_heads=8,
+
                     embed_dim=768,
                     obs_type='image',
                     env_num=8,
                     task_num=len(env_id_list),
                     use_normal_head=True,
                     use_softmoe_head=False,
+                    use_moe_head=False,
+                    num_experts_in_moe_head=4,
                     moe_in_transformer=False,
+                    multiplication_moe_in_transformer=False,
                     num_experts_of_moe_in_transformer=4,
                 ),
             ),
+            use_task_exploitation_weight=False,  # TODO
+            task_complexity_weight=False,  # TODO
             total_batch_size=total_batch_size,
             allocated_batch_sizes=False,
             train_start_after_envsteps=int(0),
@@ -87,7 +103,7 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
                      norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition,
                      num_segments, total_batch_size):
    configs = []
-    exp_name_prefix = f'data_unizero_mt_ddp-8gpu/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/'
+    exp_name_prefix = f'data_unizero_atari_mt_20250212/atari_{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/'

    for task_id, env_id in enumerate(env_id_list):
        config = create_config(
@@ -118,7 +134,7 @@ def create_env_manager():
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
+        python -m torch.distributed.launch --nproc_per_node=5 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
        torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
    """
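Note: the config shrinks the world-model transformer from 12 layers / 24 heads to 8 layers / 8 heads at embed_dim=768. A back-of-the-envelope estimate of what that does to the backbone size, assuming a standard GPT-style block (~4*d^2 for attention plus ~8*d^2 for an MLP with 4x expansion) and ignoring embeddings and output heads; the head count itself does not change the parameter count:

d = 768                     # embed_dim from the config above
per_block = 12 * d ** 2     # ~4*d^2 (attention) + ~8*d^2 (MLP, 4x expansion)
print(f"12 layers: ~{12 * per_block / 1e6:.0f}M transformer params")  # ~85M
print(f" 8 layers: ~{8 * per_block / 1e6:.0f}M transformer params")   # ~57M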

0 commit comments