Commit 3346e08

Author: puyuan

fix(pu): fix start_pos *2 bug

1 parent ccb21f4 commit 3346e08
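
Summary (read from the diff below): in UniZero's world model each environment timestep occupies two tokens (obs, then action), while start_pos counts timesteps only. The RoPE table freqs_cis is indexed by token position, so start_pos must be doubled before slicing. The scalar and exception paths already multiplied by 2, but the tensor fast path did not, misaligning rotary positions whenever start_pos arrived as a batch of tensors. This commit adds the missing * 2, squeezes a stray singleton dimension out of the stacked freqs_cis slices, and doubles max_seq_len in the Atari config so the table covers token positions rather than timesteps.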

File tree: 2 files changed, +39 -26 lines

lzero/model/unizero_world_models/transformer.py (+14 -8)
@@ -50,11 +50,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
 
 
 def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    # https://github.com/meta-llama/llama3/blob/main/llama/model.py#L61
     ndim = x.ndim
-    # print(f"freqs_cis shape: {freqs_cis.shape}, x shape: {x.shape}")
-    assert 0 <= 1 < ndim
-    shape = [d if i == 2 or i == ndim - 1 or i == 0 else 1 for i, d in enumerate(x.shape)]
-
+    shape = [d if i == ndim - 1 or i == 2 or i == 0 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
 
 
@@ -66,7 +64,9 @@ def apply_rotary_emb(
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     try:
+        # print(f"freqs_cis shape: {freqs_cis.shape}, xq_ shape: {xq_.shape}")
         freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+        # print(f"new freqs_cis shape: {freqs_cis.shape}")
     except Exception as e:
         print(e)
         print('We are at the reset timestep!')
@@ -137,25 +137,31 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
         # If RoPE is used, slice freqs_cis accordingly
         if self.config.rotary_emb:
             # Fix: if start_pos is a scalar, expand it to the same value across the current batch
-            # *2 because timestep only counts obs, while the sequence is obs + act
+            # t==========*2 because timestep only counts obs, while the sequence is obs + act==========
             if isinstance(start_pos, int) or isinstance(start_pos, float):
                 start_pos_tensor = torch.full((sequences.shape[0],), int(start_pos), device=sequences.device) * 2
             else:
                 # start_pos_tensor = torch.as_tensor(start_pos, device=sequences.device)
                 try:
-                    start_pos_tensor = torch.as_tensor([x.item() for x in start_pos], device=sequences.device)
+                    start_pos_tensor = torch.as_tensor([x.item() for x in start_pos], device=sequences.device) * 2
                 except Exception as e:
                     # print(e)
                     start_pos_tensor = torch.as_tensor(
                         [x.reshape(-1)[0].item() for x in start_pos],  # force-flatten, then take the first element
                         device=sequences.device
                     ) * 2
+
             # For each sample, take the freqs_cis slice for the interval given by its start_pos
             start_pos_tensor = torch.remainder(start_pos_tensor, self.config.max_seq_len)
             # Convert the per-sample start positions to a list
             start_pos_list = start_pos_tensor.tolist()
             freqs_cis_slices = [self.freqs_cis[int(pos): int(pos) + seqlen] for pos in start_pos_list]
             freqs_cis = torch.stack(freqs_cis_slices)
+
+            if freqs_cis.ndim == 3 and freqs_cis.shape[1] == 1:
+                # Convert shape [seq_len, 1, num_pairs] to [seq_len, num_pairs]
+                freqs_cis = freqs_cis.squeeze(1)
+            # print(f'165 freqs_cis.shape:{freqs_cis.shape}')
         else:
             freqs_cis = None
 
@@ -307,8 +313,8 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
             for i in range(B):
                 mask[i] = self.mask[L:L + T, :L + T].clone()
                 mask[i, :, :(L - valid_context_lengths[i])] = 0  # Set invalid parts to 0.
-                # Adjust mask dimensions to match the last two dimensions of att.
-                # (B, T, L + T) -> (B, 1, T, L + T) -> (B, num_heads, T, L + T)
+            # Adjust mask dimensions to match the last two dimensions of att.
+            # (B, T, L + T) -> (B, 1, T, L + T) -> (B, num_heads, T, L + T)
             mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1)
         else:
            # mask.shape: (T, L + T)
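
For orientation, here is a minimal, self-contained sketch of the per-sample slicing that the fixed branch performs. The precompute_freqs_cis helper below is modeled on the Llama 3 reference linked in the diff, not copied from this repository, and the sizes and offsets are made up for illustration:

    import torch

    def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
        # Complex rotations e^(i * t * theta_k), one row per token position t.
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        return torch.polar(torch.ones_like(freqs), freqs)  # shape: (end, dim // 2)

    max_seq_len, head_dim, seqlen = 16, 8, 4
    table = precompute_freqs_cis(head_dim, max_seq_len)

    # start_pos counts environment timesteps (obs only); each timestep occupies
    # two tokens (obs, act), so the token-space offset is start_pos * 2.
    start_pos = torch.tensor([0, 3])                         # per-sample timestep offsets
    token_pos = torch.remainder(start_pos * 2, max_seq_len)  # tensor([0, 6])

    slices = [table[int(p): int(p) + seqlen] for p in token_pos]
    freqs_cis = torch.stack(slices)                          # (batch, seqlen, head_dim // 2)
    print(freqs_cis.shape)                                   # torch.Size([2, 4, 4])

Without the * 2, the second sample would read rotations for token positions 3..6 even though, by timestep 3, six tokens (3 obs + 3 act) have already been consumed, so previously cached keys and fresh queries would disagree about absolute position; that is the misalignment the commit title points at.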

zoo/atari/config/atari_unizero_config.py (+25 -18)
@@ -13,23 +13,24 @@ def main(env_id='PongNoFrameskip-v4', seed=0):
     game_segment_length = 20
     evaluator_env_num = 3
     num_simulations = 50
-    max_env_step = int(5e5)
+    max_env_step = int(4e5)
     batch_size = 64
     num_unroll_steps = 10
     infer_context_length = 4
     num_layers = 2
     replay_ratio = 0.25
 
-    # collector_env_num = 2
-    # game_segment_length = 20
-    # evaluator_env_num = 1
-    # num_simulations = 2
-    # max_env_step = int(5e5)
-    # batch_size = 2
-    # num_unroll_steps = 5
-    # infer_context_length = 2
-    # num_layers = 1
-    # replay_ratio = 0.1
+    # only for debug
+    collector_env_num = 2
+    game_segment_length = 20
+    evaluator_env_num = 1
+    num_simulations = 2
+    max_env_step = int(5e5)
+    batch_size = 2
+    num_unroll_steps = 5
+    infer_context_length = 2
+    num_layers = 1
+    replay_ratio = 0.1
     # ==============================================================
     # end of the most frequently changed config specified by the user
     # ==============================================================
@@ -44,16 +45,16 @@ def main(env_id='PongNoFrameskip-v4', seed=0):
             n_evaluator_episode=evaluator_env_num,
             manager=dict(shared_memory=False, ),
             # TODO: only for debug
-            # collect_max_episode_steps=int(50),
-            # eval_max_episode_steps=int(50),
+            collect_max_episode_steps=int(50),
+            eval_max_episode_steps=int(50),
         ),
         policy=dict(
             learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ),  # default is 10000
             model=dict(
                 observation_shape=(3, 96, 96),
                 action_space_size=action_space_size,
                 world_model_cfg=dict(
-                    policy_entropy_weight=5e-3,
+                    policy_entropy_weight=1e-4,
                     continuous_action_space=False,
                     max_blocks=num_unroll_steps,
                     max_tokens=2 * num_unroll_steps,  # NOTE: each timestep has 2 tokens: obs and action
@@ -69,7 +70,10 @@ def main(env_id='PongNoFrameskip-v4', seed=0):
                     # rotary_emb=False,
                     rotary_emb=True,
                     rope_theta=10000,
-                    max_seq_len=2048,
+                    # max_seq_len=2048,
+                    # max_seq_len=4096,
+                    max_seq_len=int(4096*2),
+
                 ),
             ),
             model_path=None,
@@ -78,8 +82,8 @@ def main(env_id='PongNoFrameskip-v4', seed=0):
             batch_size=batch_size,
             learning_rate=0.0001,
             num_simulations=num_simulations,
-            train_start_after_envsteps=2000,
-            # train_start_after_envsteps=0, # debug
+            # train_start_after_envsteps=2000,
+            train_start_after_envsteps=0,  # debug
             game_segment_length=game_segment_length,
             replay_buffer_size=int(1e6),
             eval_freq=int(5e3),
@@ -104,7 +108,10 @@ def main(env_id='PongNoFrameskip-v4', seed=0):
     atari_unizero_create_config = EasyDict(atari_unizero_create_config)
     create_config = atari_unizero_create_config
 
-    main_config.exp_name = f'data_unizero_20250211/{env_id[:-14]}/{env_id[:-14]}_uz_rope-mergemain_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+    # main_config.exp_name = f'data_unizero_20250211/{env_id[:-14]}/{env_id[:-14]}_uz_posembed-mergemain_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+    # main_config.exp_name = f'data_unizero_20250211/{env_id[:-14]}/{env_id[:-14]}_uz_rope-mergemain-msl4096*2_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+    main_config.exp_name = f'data_unizero_20250211_debug/{env_id[:-14]}/{env_id[:-14]}_uz_rope-mergemain-msl4096*2_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+
     from lzero.entry import train_unizero
     train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
 
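
The max_seq_len bump to int(4096*2) in this config follows the same bookkeeping as the transformer fix: the freqs_cis table is indexed by token position, and each timestep costs two tokens. Because forward wraps offsets with torch.remainder(start_pos_tensor, max_seq_len), a table of N entries yields unique rotary positions for only N/2 timesteps before aliasing. A small illustrative check (config values from above; the start_pos offset of 3000 is hypothetical):

    max_tokens = 2 * 10               # num_unroll_steps = 10 in the non-debug config
    old_table, new_table = 4096, int(4096 * 2)

    start_pos = 3000                  # hypothetical timestep offset (obs count only)
    token_pos = start_pos * 2         # two tokens per timestep -> 6000

    print(token_pos % old_table)      # 1904 -> wrapped: aliases with timestep 952
    print(token_pos % new_table)      # 6000 -> still a unique rotary position

With the previous max_seq_len=2048 the wrap would occur after only 1024 timesteps, so doubling to 8192 entries keeps episodes of up to 4096 timesteps alias-free.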