Skip to content

Commit f00603b

Browse files
author
puyuan
committed
tmp
1 parent 4c45d77 commit f00603b

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py

+16 -15
Original file line number | Diff line number | Diff line change
@@ -70,9 +70,9 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
7070

7171
# NOTE: gato-79M (small) transformer
7272
# batch_size=64 8games训练时,每张卡大约占12*2=24G cuda存储
73-
# num_layers=8,
74-
# num_heads=24,
75-
# embed_dim=768,
73+
num_layers=8,
74+
num_heads=24,
75+
embed_dim=768,
7676

7777
# NOTE: gato-medium 修改版 transformer
7878
# batch_size=64 8games训练时,每张卡大约占12*3=36G cuda存储
@@ -83,9 +83,9 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
8383
# NOTE: gato-medium 修改版 transformer
8484
# batch_size=64 8games训练时,每张卡大约占12*2*4 cuda存储
8585
# batch_size=32 8games训练时,每张卡大约占12*2*4/2 cuda存储
86-
num_layers=8,
87-
num_heads=24,
88-
embed_dim=1536,
86+
# num_layers=8,
87+
# num_heads=24,
88+
# embed_dim=1536,
8989

9090
# NOTE: gato-364M (medium) transformer
9191
# batch_size=64 8games训练时,每张卡大约占12*3*4 cuda存储
@@ -164,7 +164,8 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
164164
# exp_name_prefix = f'data_unizero_mt_ddp-8gpu_1124/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_nlayer8-nhead24_seed{seed}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed{seed}/'
165165
# exp_name_prefix = f'data_unizero_mt_ddp-8gpu_1124/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_nlayer12-nhead24_seed{seed}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer12-nh24_mbs-512-bs64_upc80_seed{seed}/'
166166

167-
exp_name_prefix = f'data_unizero_mt_ddp-8gpu_1127/{len(env_id_list)}games_eval60min_brf{buffer_reanalyze_freq}_nlayer8-nhead24-embed1536_seed{seed}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_nlayer8-nhead24-embed1536_mbs-256-bs32_upc80_seed{seed}/'
167+
# exp_name_prefix = f'data_unizero_atari_mt_20250217/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_nlayer8-nhead24-embed768_seed{seed}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_seed{seed}/'
168+
exp_name_prefix = f'data_unizero_atari_mt_20250217/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_nlayer8-nhead24-embed768_seed{seed}/'
168169

169170
for task_id, env_id in enumerate(env_id_list):
170171
config = create_config(
@@ -272,8 +273,8 @@ def create_env_manager():
272273
os.environ["NCCL_TIMEOUT"] = "3600000000"
273274

274275
# for seed in [2, 3, 0, 1]: # TODO
275-
for seed in [0, 1, 2]: # TODO
276-
# for seed in [1]: # TODO
276+
# for seed in [0, 1, 2]: # TODO
277+
for seed in [0]: # TODO
277278
# for seed in [2,3]: # TODO
278279

279280
collector_env_num = 8
@@ -289,14 +290,14 @@ def create_env_manager():
289290
# total_batch_size = 2048
290291

291292
#应该根据一个样本sequence的占用显存量,和最大显存来设置
292-
total_batch_size = 256
293-
batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
294-
print(f'=========== batch_size: {batch_size} ===========')
295-
296-
# total_batch_size = 512
297-
# batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
293+
# total_batch_size = 256
294+
# batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
298295
# print(f'=========== batch_size: {batch_size} ===========')
299296

297+
total_batch_size = 512
298+
batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
299+
print(f'=========== batch_size: {batch_size} ===========')
300+
300301
num_unroll_steps = 10
301302
infer_context_length = 4
302303
norm_type = 'LN'

0 commit comments

Comments
 (0)