
Commit 8fe1a6d

puyuan committed
feature(pu): add register_token_shared option
1 parent 557e8f9 · commit 8fe1a6d

File tree

3 files changed: +79 -50 lines changed


lzero/model/unizero_world_models/kv_caching.py

+4 -4

@@ -140,11 +140,11 @@ def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device):
         self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
         self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
 
-        self.register_token_num = 2  # Number of register tokens TODO======
+        # self.register_token_num = 2  # Number of register tokens TODO======
 
-    def set_register_token_num(self, num: int) -> None:
-        """Set the number of register tokens."""
-        self.register_token_num = num
+    # def set_register_token_num(self, num: int) -> None:
+    #     """Set the number of register tokens."""
+    #     self.register_token_num = num
 
     @property
     def shape(self) -> Tuple[int, int, int, int]:
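The hard-coded register-token count (and its setter) leave the KV cache here; the Transformer now owns register_token_num via its config (see transformer.py below). A minimal sketch of the resulting division of labor, assuming a hypothetical trim helper (trim_register_tokens is illustrative, not the repository's API):

import torch

def trim_register_tokens(kv: torch.Tensor, register_token_num: int) -> torch.Tensor:
    # kv: (B, num_heads, L, head_dim); drop the trailing register-token positions,
    # with the count supplied by the caller (the Transformer) rather than stored on the cache.
    return kv[:, :, :-register_token_num, :]

kv = torch.randn(2, 8, 10, 64)
assert trim_register_tokens(kv, 2).shape == (2, 8, 8, 64)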

lzero/model/unizero_world_models/transformer.py

+31 -10

@@ -60,11 +60,29 @@ def __init__(self, config: TransformerConfig, task_embed=None) -> None:
 
         self.task_embed = task_embed
         self.task_embed_option = self.config.task_embed_option  # Strategy for task embeddings
+        self.register_token_shared = True
+
+        # TODO: in shared mode, all tasks use the same parameters
+
         if self.task_embed_option == "register_task_embed":
             self.use_register_token = True  # TODO
             # Register token setup
             self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4
-            self.sim_norm = SimNorm(simnorm_dim=config.embed_dim)  # Normalization for task embeddings
+
+            # Decide whether to use the shared mode
+            self.register_token_shared = getattr(config, "register_token_shared", True)
+            if self.register_token_shared:
+                # print(f'self.register_token_shared:{self.register_token_shared}')
+                # print(f'='*20)
+                # Shared mode: all tasks use the same register_tokens parameter, of shape (register_token_num, embed_dim)
+                self.register_tokens = nn.Parameter(torch.empty(self.register_token_num, config.embed_dim))
+                nn.init.xavier_uniform_(self.register_tokens)
+            else:
+                # Non-shared mode: rely on the externally provided task_embed module to generate the task embedding,
+                # which is normalized by SimNorm and then replicated into the register tokens
+                self.task_embed = task_embed  # externally provided module, e.g. nn.Embedding
+                self.sim_norm = SimNorm(simnorm_dim=config.embed_dim)  # Normalization for task embeddings
+
         else:
             self.use_register_token = False  # TODO
 

@@ -83,16 +101,19 @@ def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor:
         B = sequences.size(0)
         device = sequences.device
 
-        # Generate a learnable task embedding
-        # and apply SimNorm
-        task_embedding = self.task_embed(torch.tensor([task_id], device=device))  # (1, C)
-        task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1)  # (C,)
-        # Expand to register_token_num copies
-        register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1)  # (register_token_num, C)
-        register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, register_token_num, C)
+        if self.register_token_shared:
+            # Shared mode: use the single shared register_tokens parameter,
+            # of shape (register_token_num, embed_dim)
+            register_tokens = self.register_tokens
+            register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, register_token_num, embed_dim)
+        else:
+            # Non-shared mode: generate the task embedding on the fly via task_embed,
+            # then replicate it into the register tokens
+            task_embedding = self.task_embed(torch.tensor([task_id], device=device))  # (1, embed_dim)
+            task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1)  # (embed_dim,)
+            register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1)  # (register_token_num, embed_dim)
+            register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, register_token_num, embed_dim)
 
-        # Concatenate: append the register tokens at the end
-        new_sequences = torch.cat([sequences, register_tokens], dim=1)  # (B, register_token_num + T, C)
+        new_sequences = torch.cat([sequences, register_tokens], dim=1)  # append register tokens at the end of the sequence: (B, T + register_token_num, C)
 
         return new_sequences
 
     def remove_register_tokens_from_kv(self, past_keys_values: KeysValues) -> None:
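A minimal, self-contained sketch of the two modes added above (SimNorm is omitted for brevity; all sizes here are illustrative, not the experiment's settings):

import torch
import torch.nn as nn

B, T, embed_dim, register_token_num = 2, 6, 8, 4
sequences = torch.randn(B, T, embed_dim)

# Shared mode: one learnable (register_token_num, embed_dim) parameter for all tasks.
register_tokens = nn.Parameter(torch.empty(register_token_num, embed_dim))
nn.init.xavier_uniform_(register_tokens)
shared = register_tokens.unsqueeze(0).expand(B, -1, -1)              # (B, 4, 8)
out = torch.cat([sequences, shared], dim=1)                          # (B, T + 4, 8)
assert out.shape == (B, T + register_token_num, embed_dim)

# Non-shared mode: a per-task embedding replicated into register tokens.
task_embed = nn.Embedding(8, embed_dim)                              # e.g. 8 tasks
task_embedding = task_embed(torch.tensor([3]))                       # (1, 8) for task_id = 3
per_task = task_embedding.expand(register_token_num, -1)             # (4, 8)
out = torch.cat([sequences, per_task.unsqueeze(0).expand(B, -1, -1)], dim=1)
assert out.shape == (B, T + register_token_num, embed_dim)

The shared parameter makes the register tokens task-agnostic learned slots, so task identity no longer flows through them; the non-shared path keeps one embedding per task_id.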

zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py

+44 -36

@@ -54,10 +54,13 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collector_env_num, ...):
             obs_type='vector',
             # use_shared_projection=True,  # TODO
             use_shared_projection=False,
-            task_embed_option='concat_task_embed',  # ==============TODO: none ==============
-            # task_embed_option='register_task_embed',  # ==============TODO: none ==============
+            # task_embed_option='concat_task_embed',  # ==============TODO: none ==============
+            task_embed_option='register_task_embed',  # ==============TODO: none ==============
+
             # task_embed_option=None,  # ==============TODO: none ==============
-            register_token_num=2,  # TODO: change register_token_num in kv_caching accordingly
+            # register_token_num=4,
+            register_token_num=2,
+
             use_task_embed=True,  # TODO
             # use_task_embed=False,  # ==============TODO==============
             num_unroll_steps=num_unroll_steps,
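Note that this config never sets register_token_shared, so the getattr(config, "register_token_shared", True) fallback in transformer.py above means this experiment runs the new shared mode. A quick check (Cfg is a hypothetical stand-in for the world-model config object, carrying only the fields relevant here):

class Cfg:
    task_embed_option = 'register_task_embed'
    register_token_num = 2

assert getattr(Cfg(), 'register_token_shared', True) is True  # shared mode by default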
@@ -77,11 +80,11 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collector_env_num, ...):
             # num_layers=2,
             # num_layers=4,  # TODO
 
-            # num_layers=8,  # TODO
-            # num_heads=8,
+            num_layers=8,  # TODO
+            num_heads=8,
 
-            num_layers=12,  # TODO
-            num_heads=12,
+            # num_layers=12,  # TODO
+            # num_heads=12,
 
             embed_dim=768,
             env_num=max(collector_env_num, evaluator_env_num),
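With embed_dim=768 held fixed, dropping from 12 heads to 8 changes the per-head dimension from 64 to 96. A worked check, assuming the usual head_dim = embed_dim // num_heads split:

embed_dim = 768
assert embed_dim // 12 == 64  # previous setting: num_heads=12
assert embed_dim // 8 == 96   # this commit's setting: num_heads=8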
@@ -158,9 +161,9 @@ def generate_configs(env_id_list: List[str],
                      total_batch_size: int):
     configs = []
     # TODO: debug
-    exp_name_prefix = f'data_suz_mt_20250123/{len(env_id_list)}tasks_ddp_8gpu_nlayer12_upc200_no-taskweight_concat-task-embed_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
+    # exp_name_prefix = f'data_suz_mt_20250123/{len(env_id_list)}tasks_ddp_8gpu_nlayer12_upc200_no-taskweight_concat-task-embed_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
 
-    # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_register-task-embed-4_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
+    exp_name_prefix = f'data_suz_mt_20250207/ddp_8gpu_nlayer8_{len(env_id_list)}tasks_upc200_no-taskweight-obsloss-temp1_register-task-embed-2-shared_infer{infer_context_length}_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
 
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_register-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_no-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
@@ -273,26 +276,26 @@ def create_env_manager():
     ]
 
     # DMC 18games
-    env_id_list = [
-        'acrobot-swingup',
-        'cartpole-balance',
-        'cartpole-balance_sparse',
-        'cartpole-swingup',
-        'cartpole-swingup_sparse',
-        'cheetah-run',
-        "ball_in_cup-catch",
-        "finger-spin",
-        "finger-turn_easy",
-        "finger-turn_hard",
-        'hopper-hop',
-        'hopper-stand',
-        'pendulum-swingup',
-        'reacher-easy',
-        'reacher-hard',
-        'walker-run',
-        'walker-stand',
-        'walker-walk',
-    ]
+    # env_id_list = [
+    #     'acrobot-swingup',
+    #     'cartpole-balance',
+    #     'cartpole-balance_sparse',
+    #     'cartpole-swingup',
+    #     'cartpole-swingup_sparse',
+    #     'cheetah-run',
+    #     "ball_in_cup-catch",
+    #     "finger-spin",
+    #     "finger-turn_easy",
+    #     "finger-turn_hard",
+    #     'hopper-hop',
+    #     'hopper-stand',
+    #     'pendulum-swingup',
+    #     'reacher-easy',
+    #     'reacher-hard',
+    #     'walker-run',
+    #     'walker-stand',
+    #     'walker-walk',
+    # ]
 
     # Get the action_space_size and observation_shape of each environment
     action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list]
@@ -303,20 +306,25 @@ def create_env_manager():
     n_episode = 8
     evaluator_env_num = 3
     num_simulations = 50
-    # max_env_step = int(5e5)
-    max_env_step = int(1e6)
+    max_env_step = int(5e5)
+    # max_env_step = int(1e6)
     reanalyze_ratio = 0.0
 
-    # nlayer=4
+    # nlayer=4/8
     total_batch_size = 512
     batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
 
-    # nlayer=8/12
-    total_batch_size = 256
-    batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
+    # nlayer=12
+    # total_batch_size = 256
+    # batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
 
     num_unroll_steps = 5
-    infer_context_length = 4  # there are 4 register tokens at the tail, so infer_context_length is effectively still 2
+    infer_context_length = 5  # the 4 register tokens at the tail have already been removed from the kv_cache
+
+    # Original settings
+    # num_unroll_steps = 5
+    # infer_context_length = 2
+
     norm_type = 'LN'
     buffer_reanalyze_freq = 1 / 100000
     reanalyze_batch_size = 160
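For the active nlayer=8 run, total_batch_size=512 spread across the tasks hits the 64-per-task cap exactly. A worked check, assuming the active env_id_list holds the 8 tasks the filename's "8games" suggests (the list itself sits above the diffed region):

total_batch_size = 512
num_tasks = 8  # assumption: len(env_id_list) == 8
batch_size = [int(min(64, total_batch_size / num_tasks)) for _ in range(num_tasks)]
assert batch_size == [64] * 8  # 512 / 8 = 64.0, so the min(64, ...) cap binds exactly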
