@@ -54,10 +54,13 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
54
54
obs_type = 'vector' ,
55
55
# use_shared_projection=True, # TODO
56
56
use_shared_projection = False ,
57
- task_embed_option = 'concat_task_embed' , # ==============TODO: none ==============
58
- # task_embed_option='register_task_embed', # ==============TODO: none ==============
57
+ # task_embed_option='concat_task_embed', # ==============TODO: none ==============
58
+ task_embed_option = 'register_task_embed' , # ==============TODO: none ==============
59
+
59
60
# task_embed_option=None, # ==============TODO: none ==============
60
- register_token_num = 2 , # TODO: 修改kv_caching中的register_token_num
61
+ # register_token_num=4,
62
+ register_token_num = 2 ,
63
+
61
64
use_task_embed = True , # TODO
62
65
# use_task_embed=False, # ==============TODO==============
63
66
num_unroll_steps = num_unroll_steps ,
@@ -77,11 +80,11 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
77
80
# num_layers=2,
78
81
# num_layers=4, # TODO
79
82
80
- # num_layers=8, # TODO
81
- # num_heads=8,
83
+ num_layers = 8 , # TODO
84
+ num_heads = 8 ,
82
85
83
- num_layers = 12 , # TODO
84
- num_heads = 12 ,
86
+ # num_layers=12, # TODO
87
+ # num_heads=12,
85
88
86
89
embed_dim = 768 ,
87
90
env_num = max (collector_env_num , evaluator_env_num ),
@@ -158,9 +161,9 @@ def generate_configs(env_id_list: List[str],
158
161
total_batch_size : int ):
159
162
configs = []
160
163
# TODO: debug
161
- exp_name_prefix = f'data_suz_mt_20250123/{ len (env_id_list )} tasks_ddp_8gpu_nlayer12_upc200_no-taskweight_concat-task-embed_brf{ buffer_reanalyze_freq } _tbs{ total_batch_size } _seed{ seed } /'
164
+ # exp_name_prefix = f'data_suz_mt_20250123/{len(env_id_list)}tasks_ddp_8gpu_nlayer12_upc200_no-taskweight_concat-task-embed_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
162
165
163
- # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no -taskweight-obsloss-temp1_register-task-embed-4_{len(env_id_list)}tasks_brf {buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
166
+ exp_name_prefix = f'data_suz_mt_20250207/ddp_8gpu_nlayer8_ { len ( env_id_list ) } tasks_upc200_no -taskweight-obsloss-temp1_register-task-embed-2-shared_infer { infer_context_length } _brf { buffer_reanalyze_freq } _tbs{ total_batch_size } _seed{ seed } /'
164
167
165
168
# exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_register-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
166
169
# exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_no-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
@@ -273,26 +276,26 @@ def create_env_manager():
273
276
]
274
277
275
278
# DMC 18games
276
- env_id_list = [
277
- 'acrobot-swingup' ,
278
- 'cartpole-balance' ,
279
- 'cartpole-balance_sparse' ,
280
- 'cartpole-swingup' ,
281
- 'cartpole-swingup_sparse' ,
282
- 'cheetah-run' ,
283
- "ball_in_cup-catch" ,
284
- "finger-spin" ,
285
- "finger-turn_easy" ,
286
- "finger-turn_hard" ,
287
- 'hopper-hop' ,
288
- 'hopper-stand' ,
289
- 'pendulum-swingup' ,
290
- 'reacher-easy' ,
291
- 'reacher-hard' ,
292
- 'walker-run' ,
293
- 'walker-stand' ,
294
- 'walker-walk' ,
295
- ]
279
+ # env_id_list = [
280
+ # 'acrobot-swingup',
281
+ # 'cartpole-balance',
282
+ # 'cartpole-balance_sparse',
283
+ # 'cartpole-swingup',
284
+ # 'cartpole-swingup_sparse',
285
+ # 'cheetah-run',
286
+ # "ball_in_cup-catch",
287
+ # "finger-spin",
288
+ # "finger-turn_easy",
289
+ # "finger-turn_hard",
290
+ # 'hopper-hop',
291
+ # 'hopper-stand',
292
+ # 'pendulum-swingup',
293
+ # 'reacher-easy',
294
+ # 'reacher-hard',
295
+ # 'walker-run',
296
+ # 'walker-stand',
297
+ # 'walker-walk',
298
+ # ]
296
299
297
300
# Look up the action_space_size and observation_shape for each environment
298
301
action_space_size_list = [dmc_state_env_action_space_map [env_id ] for env_id in env_id_list ]
@@ -303,20 +306,25 @@ def create_env_manager():
303
306
n_episode = 8
304
307
evaluator_env_num = 3
305
308
num_simulations = 50
306
- # max_env_step = int(5e5)
307
- max_env_step = int (1e6 )
309
+ max_env_step = int (5e5 )
310
+ # max_env_step = int(1e6)
308
311
reanalyze_ratio = 0.0
309
312
310
- # nlayer=4
313
+ # nlayer=4/8
311
314
total_batch_size = 512
312
315
batch_size = [int (min (64 , total_batch_size / len (env_id_list ))) for _ in range (len (env_id_list ))]
313
316
314
- # nlayer=8/ 12
315
- total_batch_size = 256
316
- batch_size = [int (min (64 , total_batch_size / len (env_id_list ))) for _ in range (len (env_id_list ))]
317
+ # nlayer=12
318
+ # total_batch_size = 256
319
+ # batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
317
320
318
321
num_unroll_steps = 5
319
- infer_context_length = 4 # 尾部有4个register token, 相当于infer_context_length还是2
322
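+ infer_context_length = 5 # register tokens are appended at the tail and are already stripped inside kv_cache (the original note said 4 tokens, but register_token_num is set to 2 — TODO confirm)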
323
+
324
+ # Original settings
325
+ # num_unroll_steps = 5
326
+ # infer_context_length = 2
327
+
320
328
norm_type = 'LN'
321
329
buffer_reanalyze_freq = 1 / 100000
322
330
reanalyze_batch_size = 160
0 commit comments