polish(pu): adapt to discrete action space env like atari
puyuan committed Feb 12, 2025
1 parent 1860f11 commit 04c4a68
Showing 10 changed files with 112 additions and 74 deletions.
3 changes: 1 addition & 2 deletions lzero/entry/train_muzero_multitask_segment_ddp.py
@@ -380,8 +380,7 @@ def train_muzero_multitask_segment_ddp(
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

# if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):

if learner.train_iter > 1 and evaluator.should_eval(learner.train_iter):
print('=' * 20)
print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')

5 changes: 4 additions & 1 deletion lzero/mcts/buffer/game_buffer_muzero.py
@@ -64,7 +64,10 @@ def __init__(self, cfg: dict):
if hasattr(self._cfg, 'task_id'):
self.task_id = self._cfg.task_id
print(f"Task ID is set to {self.task_id}.")
self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
try:
self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
except Exception as e:
self.action_space_size = self._cfg.model.action_space_size

else:
self.task_id = None
6 changes: 4 additions & 2 deletions lzero/mcts/buffer/game_buffer_unizero.py
@@ -52,8 +52,10 @@ def __init__(self, cfg: dict):
if hasattr(self._cfg, 'task_id'):
self.task_id = self._cfg.task_id
print(f"Task ID is set to {self.task_id}.")
self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]

try:
self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
except Exception as e:
self.action_space_size = self._cfg.model.action_space_size
else:
self.task_id = None
print("No task_id found in configuration. Task ID is set to None.")
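Both buffer changes above follow the same pattern: multi-task configs carry a per-task `action_space_size_list`, while single-task discrete-action configs (e.g. Atari) only define `action_space_size`. A minimal standalone sketch of that fallback, using a hypothetical `resolve_action_space_size` helper and `getattr` rather than the broad `try/except` from the diff:

```python
from easydict import EasyDict

def resolve_action_space_size(cfg: EasyDict, task_id: int) -> int:
    """Prefer the per-task entry if a list is configured, else the shared scalar."""
    space_list = getattr(cfg.model, 'action_space_size_list', None)
    if space_list is not None:
        return space_list[task_id]
    return cfg.model.action_space_size

# Hypothetical configs, for illustration only.
multi_task_cfg = EasyDict(dict(model=dict(action_space_size_list=[6, 18])))
atari_cfg = EasyDict(dict(model=dict(action_space_size=6)))
assert resolve_action_space_size(multi_task_cfg, task_id=1) == 18
assert resolve_action_space_size(atari_cfg, task_id=0) == 6
```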
15 changes: 9 additions & 6 deletions lzero/mcts/buffer/game_segment.py
@@ -60,12 +60,15 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
# image obs input, e.g. atari environments
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])
else:
if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1:
# for vector obs input, e.g. classical control and box2d environments
self.zero_obs_shape = config.model.observation_shape_list[task_id]
elif len(config.model.observation_shape_list[task_id]) == 3:
# image obs input, e.g. atari environments
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1])
if hasattr(config.model, "observation_shape_list"):
if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1:
# for vector obs input, e.g. classical control and box2d environments
self.zero_obs_shape = config.model.observation_shape_list[task_id]
elif len(config.model.observation_shape_list[task_id]) == 3:
# image obs input, e.g. atari environments
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1])
else:
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])

self.obs_segment = []
self.action_segment = []
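The `game_segment.py` hunk guards the multi-task branch behind `hasattr(config.model, "observation_shape_list")`, so single-task image envs such as Atari fall back to the shared `observation_shape`. A rough sketch of the same decision, with illustrative config fields:

```python
from easydict import EasyDict

def resolve_zero_obs_shape(model_cfg: EasyDict, task_id: int):
    """Shape used to zero-pad observations at game segment boundaries."""
    if hasattr(model_cfg, 'observation_shape_list'):
        obs_shape = model_cfg.observation_shape_list[task_id]
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            # Vector observations, e.g. classic control / box2d.
            return obs_shape
        if len(obs_shape) == 3:
            # Image observations, e.g. Atari: (C, H, W).
            return (model_cfg.image_channel, obs_shape[-2], obs_shape[-1])
    # Single-task fallback: build the shape from the shared observation_shape.
    return (model_cfg.image_channel, model_cfg.observation_shape[-2], model_cfg.observation_shape[-1])

atari_cfg = EasyDict(dict(image_channel=1, observation_shape=(4, 96, 96)))  # illustrative values
print(resolve_zero_obs_shape(atari_cfg, task_id=0))  # (1, 96, 96)
```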
14 changes: 12 additions & 2 deletions lzero/model/unizero_model_multitask.py
@@ -112,6 +112,16 @@ def __init__(
embedding_dim=world_model_cfg.embed_dim,
group_size=world_model_cfg.group_size,
))
# self.representation_network = RepresentationNetworkUniZero(
# observation_shape,
# num_res_blocks,
# num_channels,
# self.downsample,
# activation=self.activation,
# norm_type=norm_type,
# embedding_dim=world_model_cfg.embed_dim,
# group_size=world_model_cfg.group_size,
# )
# TODO: we should change the output_shape to the real observation shape
# self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64))

@@ -187,8 +197,8 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_
latent state, W_ is the width of latent state.
"""
batch_size = obs_batch.size(0)
print('=here 5='*20)
import ipdb; ipdb.set_trace()
# print('=here 5='*20)
# import ipdb; ipdb.set_trace()
obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch}
_, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id)
latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value
6 changes: 3 additions & 3 deletions lzero/model/unizero_world_models/tokenizer.py
@@ -96,8 +96,8 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Ten
obs_embeddings = self.encoder(x, task_id=task_id) # TODO: for dmc multitask
# obs_embeddings = self.encoder[task_id](x)
except Exception as e:
print(e)
obs_embeddings = self.encoder(x) # TODO: for memory env
# print(e)
obs_embeddings = self.encoder[0](x) # TODO: for atari/memory env

obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
elif len(shape) == 5:
@@ -106,7 +106,7 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Ten
try:
obs_embeddings = self.encoder[task_id](x)
except Exception as e:
obs_embeddings = self.encoder(x) # TODO: for memory env
obs_embeddings = self.encoder[0](x) # TODO: for atari/memory env
obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
else:
raise ValueError(f"Invalid input shape: {shape}")
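In `tokenizer.py`, the fallback now indexes `self.encoder[0]` instead of calling `self.encoder` directly, which assumes the encoder is held in an `nn.ModuleList` even when all tasks share one network (Atari/memory envs). A toy sketch of that dispatch; the class and names are illustrative, not the LightZero tokenizer:

```python
import torch
import torch.nn as nn

class ToyTokenizer(nn.Module):
    """Illustrative only: one encoder per task, or a single shared encoder at index 0."""

    def __init__(self, encoders: nn.ModuleList):
        super().__init__()
        self.encoder = encoders

    def encode(self, x: torch.Tensor, task_id: int = None) -> torch.Tensor:
        if task_id is not None and task_id < len(self.encoder):
            # Per-task encoder, e.g. DMC multi-task with differing obs shapes.
            return self.encoder[task_id](x)
        # Shared encoder for homogeneous tasks such as Atari / memory envs.
        return self.encoder[0](x)

tokenizer = ToyTokenizer(nn.ModuleList([nn.Linear(4, 8)]))
print(tokenizer.encode(torch.randn(2, 4), task_id=3).shape)  # falls back to encoder[0]: torch.Size([2, 8])
```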
5 changes: 3 additions & 2 deletions lzero/model/unizero_world_models/world_model_multitask.py
@@ -681,7 +681,7 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
if len(act_tokens.shape) == 3:
act_tokens = act_tokens.squeeze(1)
num_steps = act_tokens.size(1)
if self.task_num >= 1:
if self.task_num >= 1 and self.continuous_action_space:
act_embeddings = self.act_embedding_table[task_id](act_tokens)
else:
act_embeddings = self.act_embedding_table(act_tokens)
@@ -862,7 +862,8 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, ta
-1)

num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1))
act_embeddings = self.act_embedding_table[task_id](act_tokens)
# act_embeddings = self.act_embedding_table[task_id](act_tokens)
act_embeddings = self.act_embedding_table(act_tokens)

B, L, K, E = obs_embeddings.size()
obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device)
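The `world_model_multitask.py` hunks route discrete-action tasks through a shared `act_embedding_table` and keep the per-task table only when `continuous_action_space` is set. A minimal sketch of the two layouts this implies (sizes are made up for illustration):

```python
import torch
import torch.nn as nn

embed_dim, action_space_size, task_num = 16, 6, 2

# Discrete actions (e.g. Atari): a single shared embedding lookup over action indices.
shared_act_embedding = nn.Embedding(action_space_size, embed_dim)
act_tokens = torch.randint(0, action_space_size, (4, 3))            # (batch, num_steps)
print(shared_act_embedding(act_tokens).shape)                        # torch.Size([4, 3, 16])

# Continuous actions: one projection per task, indexed by task_id.
per_task_act_embedding = nn.ModuleList(
    [nn.Linear(action_space_size, embed_dim) for _ in range(task_num)]
)
continuous_actions = torch.randn(4, 3, action_space_size)            # (batch, num_steps, action_dim)
print(per_task_act_embedding[1](continuous_actions).shape)           # torch.Size([4, 3, 16])
```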
34 changes: 17 additions & 17 deletions lzero/policy/unizero_multitask.py
@@ -892,7 +892,7 @@ def _forward_collect(
if active_collect_env_num < self.collector_env_num:
print('==========collect_forward============')
print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}')
self._reset_collect(reset_init_data=True)
self._reset_collect(reset_init_data=True, task_id=task_id)

return output

@@ -1001,7 +1001,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
return output

#@profile
def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True) -> None:
def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None:
"""
Overview:
This method resets the collection process for a specific environment. It clears caches and memory
@@ -1085,21 +1085,21 @@ def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_dat
- reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
"""
if reset_init_data:
if task_id is not None:
self.last_batch_obs_eval = initialize_zeros_batch(
self._cfg.model.observation_shape_list[task_id],
self._cfg.evaluator_env_num,
self._cfg.device
)
print('unizero_multitask.py task_id is not None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)

else:
self.last_batch_obs_eval = initialize_zeros_batch(
self._cfg.model.observation_shape,
self._cfg.evaluator_env_num,
self._cfg.device
)
print('unizero_multitask.py task_id is None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
# if task_id is not None:
# self.last_batch_obs_eval = initialize_zeros_batch(
# self._cfg.model.observation_shape_list[task_id],
# self._cfg.evaluator_env_num,
# self._cfg.device
# )
# print('unizero_multitask.py task_id is not None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)

# else:
self.last_batch_obs_eval = initialize_zeros_batch(
self._cfg.model.observation_shape,
self._cfg.evaluator_env_num,
self._cfg.device
)
print('unizero_multitask.py task_id is None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)

self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)]

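Since every Atari task in this multi-task setup shares the same image observation shape, `_reset_eval` above now always builds the zero batch from `cfg.model.observation_shape`. A stand-in sketch of what `initialize_zeros_batch` is expected to return (this helper is an assumption for illustration, not the LightZero implementation):

```python
import torch

def initialize_zeros_batch_sketch(observation_shape, batch_size: int, device: str) -> torch.Tensor:
    """Stand-in: one zero observation per evaluator/collector environment."""
    shape = [observation_shape] if isinstance(observation_shape, int) else list(observation_shape)
    return torch.zeros(batch_size, *shape, device=device)

# e.g. 3 evaluator envs with stacked 64x64 frames (illustrative shape).
last_batch_obs_eval = initialize_zeros_batch_sketch((3, 64, 64), batch_size=3, device='cpu')
print(last_batch_obs_eval.shape)  # torch.Size([3, 3, 64, 64])
```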
zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
@@ -22,6 +22,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
# eval_max_episode_steps=int(30),
),
policy=dict(
use_moco=False, # ==============TODO==============
multi_gpu=True, # Very important for ddp
learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))),
grad_correct_params=dict(
@@ -37,24 +38,39 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
num_res_blocks=2,
num_channels=256,
world_model_cfg=dict(

task_embed_option=None, # ==============TODO: none ==============
use_task_embed=False, # ==============TODO==============
use_shared_projection=False,


max_blocks=num_unroll_steps,
max_tokens=2 * num_unroll_steps,
context_length=2 * infer_context_length,
device='cuda',
action_space_size=action_space_size,
# batch_size=64: when training 8 games, each GPU uses roughly 12*3 = 36 GB of CUDA memory
num_layers=12,
num_heads=24,
# num_layers=12,
# num_heads=24,

num_layers=8,
num_heads=8,

embed_dim=768,
obs_type='image',
env_num=8,
task_num=len(env_id_list),
use_normal_head=True,
use_softmoe_head=False,
use_moe_head=False,
num_experts_in_moe_head=4,
moe_in_transformer=False,
multiplication_moe_in_transformer=False,
num_experts_of_moe_in_transformer=4,
),
),
use_task_exploitation_weight=False, # TODO
task_complexity_weight=False, # TODO
total_batch_size=total_batch_size,
allocated_batch_sizes=False,
train_start_after_envsteps=int(0),
@@ -87,7 +103,7 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition,
num_segments, total_batch_size):
configs = []
exp_name_prefix = f'data_unizero_mt_ddp-8gpu/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/'
exp_name_prefix = f'data_unizero_atari_mt_20250212/atari_{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/'

for task_id, env_id in enumerate(env_id_list):
config = create_config(
@@ -118,7 +134,7 @@ def create_env_manager():
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
python -m torch.distributed.launch --nproc_per_node=5 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
"""

