Skip to content

Commit 6cd57dd

Browse files
author
puyuan
committed
polish(pu): add rope_embed support for cartpole
1 parent b991758 commit 6cd57dd

File tree

6 files changed

+99
-80
lines changed

6 files changed

+99
-80
lines changed

lzero/mcts/buffer/game_buffer_unizero.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -427,15 +427,15 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
427427
m_output = model.initial_inference(m_obs, action_batch, start_pos=step_index_batch) # TODO: step_index
428428
# ======================================================================
429429

430-
if not model.training:
431-
# if not in training, obtain the scalars of the value/reward
432-
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
433-
[
434-
m_output.latent_state,
435-
inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
436-
m_output.policy_logits
437-
]
438-
)
430+
# if not model.training:
431+
# if not in training, obtain the scalars of the value/reward
432+
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
433+
[
434+
m_output.latent_state,
435+
inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
436+
m_output.policy_logits
437+
]
438+
)
439439
network_output.append(m_output)
440440

441441
# concat the output slices after model inference

lzero/model/unizero_world_models/transformer.py

+67-57
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,47 @@ class TransformerConfig:
3333
# for RoPE
3434
rope_theta: float
3535
max_seq_len: int
36+
rotary_emb: bool = False # 增加配置选项控制是否使用 rotary_emb
37+
3638
@property
3739
def max_tokens(self):
3840
return self.tokens_per_block * self.max_blocks
3941

4042

43+
44+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
45+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
46+
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
47+
freqs = torch.outer(t, freqs)
48+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
49+
return freqs_cis
50+
51+
52+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
53+
ndim = x.ndim
54+
# print(f"freqs_cis shape: {freqs_cis.shape}, x shape: {x.shape}")
55+
assert 0 <= 1 < ndim
56+
shape = [d if i == 2 or i == ndim - 1 or i == 0 else 1 for i, d in enumerate(x.shape)]
57+
58+
return freqs_cis.view(*shape)
59+
60+
61+
def apply_rotary_emb(
62+
xq: torch.Tensor,
63+
xk: torch.Tensor,
64+
freqs_cis: torch.Tensor,
65+
) -> Tuple[torch.Tensor, torch.Tensor]:
66+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
67+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
68+
try:
69+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
70+
except:
71+
print('We are at the reset timestep!')
72+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
73+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
74+
return xq_out.type_as(xq), xk_out.type_as(xk)
75+
76+
4177
class Transformer(nn.Module):
4278
"""
4379
Transformer model class.
@@ -59,11 +95,14 @@ def __init__(self, config: TransformerConfig) -> None:
5995
self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
6096
self.ln_f = nn.LayerNorm(config.embed_dim)
6197

62-
self.freqs_cis = precompute_freqs_cis(
63-
self.config.embed_dim // self.config.num_heads,
64-
self.config.max_seq_len * 2,
65-
self.config.rope_theta,
66-
)
98+
# 注册缓存, 自动管理设备转换
99+
if self.config.rotary_emb:
100+
freqs_cis = precompute_freqs_cis(
101+
self.config.embed_dim // self.config.num_heads,
102+
self.config.max_seq_len * 2,
103+
self.config.rope_theta,
104+
)
105+
self.register_buffer("freqs_cis", freqs_cis)
67106

68107
def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
69108
"""
@@ -93,24 +132,31 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
93132
- torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
94133
"""
95134
seqlen = sequences.shape[1]
96-
self.freqs_cis = self.freqs_cis.to(sequences.device)
97135

98-
# freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
99-
100-
# If the start position is greater than the predefined maximum sequence length, wrap around
101-
start_pos = torch.tensor(np.array(start_pos))
102-
if len(start_pos.shape) > 1:
103-
# TODO: train start pos [0]
104-
start_pos = torch.remainder(start_pos, self.config.max_seq_len)[:,0]
136+
# 如果使用 RoPE,则对 freqs_cis 进行切片
137+
if self.config.rotary_emb:
138+
# 修复:如果 start_pos 是标量,则将其扩展为当前 batch 大小的相同数值
139+
# *2是由于step_index只是统计了obs,但是序列是obs act
140+
if isinstance(start_pos, int) or isinstance(start_pos, float):
141+
start_pos_tensor = torch.full((sequences.shape[0],), int(start_pos), device=sequences.device) * 2
142+
else:
143+
# start_pos_tensor = torch.as_tensor(start_pos, device=sequences.device)
144+
try:
145+
start_pos_tensor = torch.as_tensor([x.item() for x in start_pos], device=sequences.device)
146+
except Exception as e:
147+
# print(e)
148+
start_pos_tensor = torch.as_tensor(
149+
[x.reshape(-1)[0].item() for x in start_pos], # 强制展平后取第一个元素
150+
device=sequences.device
151+
) * 2
152+
# 对每个样本根据 start_pos 取对应区间的 freqs_cis
153+
start_pos_tensor = torch.remainder(start_pos_tensor, self.config.max_seq_len)
154+
# 将各个样本的 start_pos 转换为列表
155+
start_pos_list = start_pos_tensor.tolist()
156+
freqs_cis_slices = [self.freqs_cis[int(pos): int(pos) + seqlen] for pos in start_pos_list]
157+
freqs_cis = torch.stack(freqs_cis_slices)
105158
else:
106-
start_pos = torch.remainder(start_pos, self.config.max_seq_len)
107-
108-
start_pos_list = torch.unbind(start_pos)
109-
try:
110-
freqs_cis_slices = [self.freqs_cis[int(pos.item()): int(pos.item()) + seqlen] for pos in start_pos_list]
111-
except:
112-
print('debug')
113-
freqs_cis = torch.stack(freqs_cis_slices).squeeze(1)
159+
freqs_cis = None
114160

115161
assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
116162
x = self.drop(sequences)
@@ -181,42 +227,6 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
181227
return x
182228

183229

184-
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
185-
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
186-
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
187-
freqs = torch.outer(t, freqs)
188-
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
189-
return freqs_cis
190-
191-
192-
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
193-
ndim = x.ndim
194-
# print(f"freqs_cis shape: {freqs_cis.shape}, x shape: {x.shape}")
195-
assert 0 <= 1 < ndim
196-
# assert freqs_cis.shape == (x.shape[2], x.shape[-1])
197-
# shape = [d if i == 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
198-
# TODO: check
199-
shape = [d if i == 2 or i == ndim - 1 or i == 0 else 1 for i, d in enumerate(x.shape)]
200-
201-
return freqs_cis.view(*shape)
202-
203-
204-
def apply_rotary_emb(
205-
xq: torch.Tensor,
206-
xk: torch.Tensor,
207-
freqs_cis: torch.Tensor,
208-
) -> Tuple[torch.Tensor, torch.Tensor]:
209-
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
210-
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
211-
try:
212-
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
213-
except:
214-
print('We are at the reset timestep!')
215-
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
216-
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
217-
return xq_out.type_as(xq), xk_out.type_as(xk)
218-
219-
220230
class SelfAttention(nn.Module):
221231
"""
222232
Implements self-attention mechanism for transformers.

lzero/model/unizero_world_models/world_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
5959
if not self.config.rotary_emb:
6060
self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
6161
self.precompute_pos_emb_diff_kv()
62-
6362
print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")
6463

6564
# Initialize action embedding table
@@ -488,7 +487,8 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor
488487
self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True)
489488

490489
buffer_action = buffer_action[:ready_env_num]
491-
# TODO
490+
491+
# TODO: 顺序可能不对?
492492
start_pos = start_pos[:ready_env_num]
493493

494494
# if ready_env_num < self.env_num:

lzero/policy/unizero.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -723,11 +723,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
723723
network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, step_index)
724724
latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
725725

726-
if not self._eval_model.training:
727-
# if not in training, obtain the scalars of the value/reward
728-
pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
729-
latent_state_roots = latent_state_roots.detach().cpu().numpy()
730-
policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)
726+
# if not self._eval_model.training:
727+
# if not in training, obtain the scalars of the value/reward
728+
pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
729+
latent_state_roots = latent_state_roots.detach().cpu().numpy()
730+
policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)
731731

732732
legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
733733
if self._cfg.mcts_ctree:

zoo/classic_control/cartpole/config/cartpole_unizero_config.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@
99
update_per_collect = None
1010
replay_ratio = 0.25
1111
max_env_step = int(2e5)
12-
reanalyze_ratio = 0
1312
batch_size = 256
1413
num_unroll_steps = 5
1514
# ==============================================================
1615
# end of the most frequently changed config specified by the user
1716
# ==============================================================
1817

1918
cartpole_unizero_config = dict(
20-
exp_name=f'data_unizero/cartpole_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_seed0',
19+
exp_name=f'data_unizero_debug/cartpole_unizero_pos-embed_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed0',
2120
env=dict(
2221
env_name='CartPole-v0',
2322
continuous=False,
@@ -40,16 +39,21 @@
4039
max_tokens=2 * 10,
4140
context_length=2 * 4,
4241
context_length_for_recurrent=2 * 4,
43-
device='cpu',
42+
device='cuda',
4443
action_space_size=2,
4544
num_layers=2,
4645
num_heads=2,
4746
embed_dim=64,
48-
env_num=collector_env_num,
47+
env_num=max(collector_env_num, evaluator_env_num),
4948
collector_env_num=collector_env_num,
5049
evaluator_env_num=evaluator_env_num,
5150
obs_type='vector',
5251
norm_type='BN',
52+
# for RoPE
53+
rotary_emb=False,
54+
# rotary_emb=True,
55+
rope_theta=10000,
56+
max_seq_len=2048,
5357
),
5458
),
5559
# (str) The path of the pretrained model. If None, the model will be initialized by the default model.
@@ -67,7 +71,6 @@
6771
target_update_freq=100,
6872
grad_clip_value=5,
6973
num_simulations=num_simulations,
70-
reanalyze_ratio=reanalyze_ratio,
7174
n_episode=n_episode,
7275
eval_freq=int(1e3),
7376
replay_buffer_size=int(1e6),

zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __init__(self, cfg: dict = {}) -> None:
5151
self._action_space = gym.spaces.Discrete(2)
5252
self._action_space.seed(0) # default seed
5353
self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)
54+
self.step_index = 0
55+
5456

5557
def reset(self) -> Dict[str, np.ndarray]:
5658
"""
@@ -86,7 +88,9 @@ def reset(self) -> Dict[str, np.ndarray]:
8688
obs = to_ndarray(obs)
8789

8890
action_mask = np.ones(self.action_space.n, 'int8')
89-
obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
91+
self.step_index = 0
92+
93+
obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1, 'step_index': self.step_index}
9094

9195
return obs
9296

@@ -120,7 +124,9 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
120124
info['eval_episode_return'] = self._eval_episode_return
121125

122126
action_mask = np.ones(self.action_space.n, 'int8')
123-
obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
127+
self.step_index += 1
128+
129+
obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1, 'step_index': self.step_index}
124130

125131
return BaseEnvTimestep(obs, rew, done, info)
126132

0 commit comments

Comments
 (0)