@@ -22,6 +22,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
             # eval_max_episode_steps=int(30),
         ),
         policy=dict(
+            use_moco=False,  # ==============TODO==============
             multi_gpu=True,  # Very important for ddp
             learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))),
             grad_correct_params=dict(
@@ -37,24 +38,39 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
                 num_res_blocks=2,
                 num_channels=256,
                 world_model_cfg=dict(
+
+                    task_embed_option=None,  # ==============TODO: none==============
+                    use_task_embed=False,  # ==============TODO==============
+                    use_shared_projection=False,
+
+
                     max_blocks=num_unroll_steps,
                     max_tokens=2 * num_unroll_steps,
                     context_length=2 * infer_context_length,
                     device='cuda',
                     action_space_size=action_space_size,
                     # With batch_size=64 and 8 games, each GPU uses roughly 12*3 = 36 GB of CUDA memory.
-                    num_layers=12,
-                    num_heads=24,
+                    # num_layers=12,
+                    # num_heads=24,
+
+                    num_layers=8,
+                    num_heads=8,
+
                     embed_dim=768,
                     obs_type='image',
                     env_num=8,
                     task_num=len(env_id_list),
                     use_normal_head=True,
                     use_softmoe_head=False,
+                    use_moe_head=False,
+                    num_experts_in_moe_head=4,
                     moe_in_transformer=False,
+                    multiplication_moe_in_transformer=False,
                     num_experts_of_moe_in_transformer=4,
                 ),
             ),
+            use_task_exploitation_weight=False,  # TODO
+            task_complexity_weight=False,  # TODO
             total_batch_size=total_batch_size,
             allocated_batch_sizes=False,
             train_start_after_envsteps=int(0),
@@ -87,7 +103,7 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
                      norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition,
                      num_segments, total_batch_size):
     configs = []
-    exp_name_prefix = f'data_unizero_mt_ddp-8gpu/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/'
+    exp_name_prefix = f'data_unizero_atari_mt_20250212/atari_{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/'
 
     for task_id, env_id in enumerate(env_id_list):
         config = create_config(
@@ -118,7 +134,7 @@ def create_env_manager():
     Overview:
         This script should be executed with <nproc_per_node> GPUs.
         Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
+        python -m torch.distributed.launch --nproc_per_node=5 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
         torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py
     """