@@ -70,9 +70,9 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
70
70
71
71
# NOTE: gato-79M (small) transformer
72
72
# batch_size=64 8games训练时,每张卡大约占12*2=24G cuda存储
73
- # num_layers=8,
74
- # num_heads=24,
75
- # embed_dim=768,
73
+ num_layers = 8 ,
74
+ num_heads = 24 ,
75
+ embed_dim = 768 ,
76
76
77
77
# NOTE: gato-medium 修改版 transformer
78
78
# batch_size=64 8games训练时,每张卡大约占12*3=36G cuda存储
@@ -83,9 +83,9 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
83
83
# NOTE: gato-medium 修改版 transformer
84
84
# batch_size=64 8games训练时,每张卡大约占12*2*4 cuda存储
85
85
# batch_size=32 8games训练时,每张卡大约占12*2*4/2 cuda存储
86
- num_layers = 8 ,
87
- num_heads = 24 ,
88
- embed_dim = 1536 ,
86
+ # num_layers=8,
87
+ # num_heads=24,
88
+ # embed_dim=1536,
89
89
90
90
# NOTE: gato-364M (medium) transformer
91
91
# batch_size=64 8games训练时,每张卡大约占12*3*4 cuda存储
@@ -164,7 +164,8 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
164
164
# exp_name_prefix = f'data_unizero_mt_ddp-8gpu_1124/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_nlayer8-nhead24_seed{seed}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed{seed}/'
165
165
# exp_name_prefix = f'data_unizero_mt_ddp-8gpu_1124/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_nlayer12-nhead24_seed{seed}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer12-nh24_mbs-512-bs64_upc80_seed{seed}/'
166
166
167
- exp_name_prefix = f'data_unizero_mt_ddp-8gpu_1127/{ len (env_id_list )} games_eval60min_brf{ buffer_reanalyze_freq } _nlayer8-nhead24-embed1536_seed{ seed } /{ len (env_id_list )} games_brf{ buffer_reanalyze_freq } _1-encoder-{ norm_type } -res2-channel256_gsl20_{ len (env_id_list )} -pred-head_nlayer8-nhead24-embed1536_mbs-256-bs32_upc80_seed{ seed } /'
167
+ # exp_name_prefix = f'data_unizero_atari_mt_20250217/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_nlayer8-nhead24-embed768_seed{seed}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_seed{seed}/'
168
+ exp_name_prefix = f'data_unizero_atari_mt_20250217/{ len (env_id_list )} games_brf{ buffer_reanalyze_freq } _nlayer8-nhead24-embed768_seed{ seed } /'
168
169
169
170
for task_id , env_id in enumerate (env_id_list ):
170
171
config = create_config (
@@ -272,8 +273,8 @@ def create_env_manager():
272
273
os .environ ["NCCL_TIMEOUT" ] = "3600000000"
273
274
274
275
# for seed in [2, 3, 0, 1]: # TODO
275
- for seed in [0 , 1 , 2 ]: # TODO
276
- # for seed in [1 ]: # TODO
276
+ # for seed in [0, 1, 2]: # TODO
277
+ for seed in [0 ]: # TODO
277
278
# for seed in [2,3]: # TODO
278
279
279
280
collector_env_num = 8
@@ -289,14 +290,14 @@ def create_env_manager():
289
290
# total_batch_size = 2048
290
291
291
292
#应该根据一个样本sequence的占用显存量,和最大显存来设置
292
- total_batch_size = 256
293
- batch_size = [int (min (32 , total_batch_size / len (env_id_list ))) for _ in range (len (env_id_list ))]
294
- print (f'=========== batch_size: { batch_size } ===========' )
295
-
296
- # total_batch_size = 512
297
- # batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
293
+ # total_batch_size = 256
294
+ # batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
298
295
# print(f'=========== batch_size: {batch_size} ===========')
299
296
297
+ total_batch_size = 512
298
+ batch_size = [int (min (64 , total_batch_size / len (env_id_list ))) for _ in range (len (env_id_list ))]
299
+ print (f'=========== batch_size: { batch_size } ===========' )
300
+
300
301
num_unroll_steps = 10
301
302
infer_context_length = 4
302
303
norm_type = 'LN'
0 commit comments