Commit 581310b

Author: puyuan (committed)

fix(pu): fix gradient accumulation_steps option

1 parent 51185e3 · commit 581310b

File tree

4 files changed: +66 -38 lines changed

lzero/policy/unizero.py

+32-13
@@ -343,7 +343,7 @@ def _init_learn(self) -> None:
         wandb.watch(self._learn_model.representation_network, log="all")

         # TODO: ========
-        self.accumulation_steps = 4  # number of gradient accumulation steps
+        self.accumulation_steps = 1  # number of gradient accumulation steps

     # @profile
     def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
@@ -467,8 +467,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
         # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"

         # Core learn model update step
-        if train_iter % self.accumulation_steps == 0:  # update the parameters every accumulation_steps steps
-            # print(f'train_iter:{train_iter}')
+        # print(f'train_iter:{train_iter}')
+        # Assume train_iter is counted from 0.
+        if (train_iter % self.accumulation_steps) == 0:
+            # Zero the gradients at the first step of each accumulation cycle.
+            # print(f'train_iter:{train_iter} self._optimizer_world_model.zero_grad()')
             self._optimizer_world_model.zero_grad()

         weighted_total_loss = weighted_total_loss / self.accumulation_steps  # scale the loss when accumulating gradients
@@ -481,16 +484,30 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
         # if param.requires_grad:
         #     print(name, param.grad.norm())

-        if self._cfg.analysis_sim_norm:
-            del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
-            self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
-            self._target_model.encoder_hook.clear_data()
-
-        total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
-                                                                        self._cfg.grad_clip_value)
+        # if self._cfg.analysis_sim_norm:
+        #     del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
+        #     self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
+        #     self._target_model.encoder_hook.clear_data()
+
+        # total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
+        #                                                                 self._cfg.grad_clip_value)
+
+        # Check whether an accumulation cycle has completed (e.g., with accumulation_steps=4, the parameters are updated at iterations 4, 8, 12, ...).
+        if (train_iter + 1) % self.accumulation_steps == 0:
+            # print(f'train_iter:{train_iter} self._optimizer_world_model.step()')
+
+            # ========== gradient-norm analysis ==========
+            if self._cfg.analysis_sim_norm:
+                # Delete the previous analysis results to avoid accumulating too much memory.
+                del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
+                self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
+                self._target_model.encoder_hook.clear_data()
+
+            # Clip the gradients.
+            total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(
+                self._learn_model.world_model.parameters(), self._cfg.grad_clip_value
+            )

-        if train_iter % self.accumulation_steps == 0:  # update the parameters every accumulation_steps steps
-            # print(f'pos 2 train_iter:{train_iter}')

             if self._cfg.multi_gpu:
                 self.sync_gradients(self._learn_model)
@@ -503,8 +520,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
             # Core target model update step
             self._target_model.update(self._learn_model.state_dict())

-            if self.accumulation_steps>1:
+            if self.accumulation_steps > 1:
                 torch.cuda.empty_cache()
+        else:
+            total_grad_norm_before_clip_wm = torch.tensor(0.)

         if torch.cuda.is_available():
             torch.cuda.synchronize()
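
Note: the change above splits zeroing the gradients (first step of an accumulation cycle) from clipping and stepping the optimizer (last step of the cycle). Below is a minimal, self-contained sketch of that pattern, using a plain nn.Linear model, an SGD optimizer, and made-up data in place of the UniZero world model and its configuration; only the control flow mirrors the diff.

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 1)  # placeholder for the world model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    accumulation_steps = 4
    grad_clip_value = 10.0

    for train_iter in range(16):
        x, y = torch.randn(32, 8), torch.randn(32, 1)

        if train_iter % accumulation_steps == 0:
            # First step of a new accumulation cycle: clear stale gradients.
            optimizer.zero_grad()

        # Scale the loss so the accumulated gradient matches a full-batch update.
        loss = nn.functional.mse_loss(model(x), y) / accumulation_steps
        loss.backward()  # gradients accumulate across the cycle

        if (train_iter + 1) % accumulation_steps == 0:
            # Cycle complete (iterations 4, 8, 12, ...): clip, then apply the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_value)
            optimizer.step()

With accumulation_steps = 1, as this commit sets, every iteration both zeros the gradients and steps the optimizer, so the loop reduces to ordinary per-step training.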

zoo/jericho/configs/jericho_ppo_config.py

+10-8
@@ -4,7 +4,12 @@
 action_space_size = 10
 max_steps = 50
 model_name = 'BAAI/bge-base-en-v1.5'
-env_id = 'detective.z5'
+# env_id = 'detective.z5'
+
+action_space_size = 10
+max_steps = 400
+env_id = 'zork1.z5'
+
 evaluator_env_num = 2

 # proj train
@@ -22,14 +27,13 @@
 # num_unroll_steps = 5
 # infer_context_length = 2
 jericho_ppo_config = dict(
-    # exp_name=f"data_ppo_detective/jericho_ppo_projtrain_bs{batch_size}_seed0",
-    exp_name=f"data_ppo_detective_debug/jericho_add-loc-inv_ppo_projtrain_bs{batch_size}_seed0",
+    exp_name=f"data_ppo_detective/jericho_{env_id}_ms{max_steps}_ass{action_space_size}_ppo_projtrain_bs{batch_size}_seed0",
+    # exp_name=f"data_ppo_detective_debug/jericho_add-loc-inv_ppo_projtrain_bs{batch_size}_seed0",
     env=dict(
         remove_stuck_actions=False,
         # remove_stuck_actions=True,
-        add_location_and_inventory=True,
-        # add_location_and_inventory=False,
-
+        # add_location_and_inventory=True,
+        add_location_and_inventory=False,
         stop_value=int(1e6),
         observation_shape=512,
         max_steps=max_steps,
@@ -60,13 +64,11 @@
             epoch_per_collect=4,
             batch_size=batch_size,
             learning_rate=0.0005,
-            # entropy_weight=0.01,
             entropy_weight=0.05,
             value_norm=True,
             grad_clip_value=10,
         ),
         collect=dict(
-            # n_sample=1024,
             n_sample=320,  # TODO: DEBUG
             discount_factor=0.99,
             gae_lambda=0.95,
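
Note: with the values set at the top of this config, the new exp_name f-string expands as sketched below; batch_size is defined elsewhere in the file, so the value here is a hypothetical placeholder.

    env_id = 'zork1.z5'
    max_steps = 400
    action_space_size = 10
    batch_size = 64  # hypothetical; the real value is set elsewhere in the config

    print(f"data_ppo_detective/jericho_{env_id}_ms{max_steps}_ass{action_space_size}_ppo_projtrain_bs{batch_size}_seed0")
    # -> data_ppo_detective/jericho_zork1.z5_ms400_ass10_ppo_projtrain_bs64_seed0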

zoo/jericho/configs/jericho_unizero_config.py

+4-3
@@ -69,8 +69,8 @@ def main(env_id='detective.z5', seed=0):
     # ==============================================================
     jericho_unizero_config = dict(
         env=dict(
-            # remove_stuck_actions=False,
-            remove_stuck_actions=True,
+            remove_stuck_actions=False,
+            # remove_stuck_actions=True,

             stop_value=int(1e6),
             observation_shape=512,
@@ -167,7 +167,8 @@ def main(env_id='detective.z5', seed=0):
     main_config = jericho_unizero_config
     create_config = jericho_unizero_create_config

-    main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}-remove-novalid_proj-train-accstep4_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+    main_config.exp_name = f'data_unizero_detective_20250209/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_proj-train-accstep1_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+    # main_config.exp_name = f'data_unizero_detective_20250209/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}-remove-novalid_proj-train-accstep4_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
     # main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_proj-train-accstep4_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'

     # main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_all-train_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'

zoo/jericho/envs/jericho_env.py

+20-14
@@ -96,17 +96,19 @@ def prepare_obs(self, obs, return_str: bool = False):

         if len(available_actions) <= self.max_action_num:
             action_mask = [1] * len(available_actions) + [0] * (self.max_action_num - len(available_actions))
-        else:
+        elif len(available_actions) == self.max_action_num:
             action_mask = [1] * len(available_actions)
+        else:
+            action_mask = [1] * self.max_action_num

         action_mask = np.array(action_mask, dtype=np.int8)

-        if return_str:  # TODO===============
-            # return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
-            return {'observation': full_obs, 'action_mask': action_mask}
+        if return_str:  # TODO: unizero needs 'to_play' to be included ===============
+            return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
+            # return {'observation': full_obs, 'action_mask': action_mask}
         else:
-            # return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1}
-            return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask}
+            return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1}
+            # return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask}


     def reset(self, return_str: bool = False):
@@ -179,8 +181,8 @@ def step(self, action: int, return_str: bool = False):
         self.timestep += 1
         # print(f'step: {self.timestep}, [OBS]:{observation} self._action_list:{self._action_list}')

-        # TODO: for PPO
-        reward = np.array([float(reward)])
+        # TODO: for PPO; for unizero the line below must be commented out
+        # reward = np.array([float(reward)])

         self.env_step += 1
         self.episode_return += reward
@@ -234,16 +236,20 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
     from easydict import EasyDict
     env_cfg = EasyDict(
         dict(
-            max_steps=100,
+            max_steps=400,
             # game_path="z-machine-games-master/jericho-game-suite/zork1.z5",
-            game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/detective.z5",
+            game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/zork1.z5",
+            # game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/detective.z5",
             # game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/905.z5",
-            max_action_num=50,
-            max_env_step=100,
+            # max_action_num=50,
+            max_action_num=10,
+            # max_env_step=100,
             tokenizer_path="google-bert/bert-base-uncased",
             max_seq_len=512,
-            remove_stuck_actions=True,  # whether to remove stuck (invalid) actions
-            add_location_and_inventory=True
+            remove_stuck_actions=False,  # whether to remove stuck (invalid) actions
+            add_location_and_inventory=False
+            # remove_stuck_actions=True,  # whether to remove stuck (invalid) actions
+            # add_location_and_inventory=True
         )
     )
     env = JerichoEnv(env_cfg)
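
Note: the prepare_obs change above pads the action mask with zeros when fewer than max_action_num actions are available and caps it when there are more. A condensed sketch of that behavior (two branches instead of the diff's three, with hypothetical action lists):

    import numpy as np

    max_action_num = 10  # matches the max_action_num=10 set in the evaluator config above

    def build_action_mask(available_actions):
        n = len(available_actions)
        if n <= max_action_num:
            # Mark the valid actions and pad the remaining slots with zeros.
            mask = [1] * n + [0] * (max_action_num - n)
        else:
            # More candidate actions than slots: cap the mask at max_action_num.
            mask = [1] * max_action_num
        return np.array(mask, dtype=np.int8)

    print(build_action_mask(["go north", "take lamp"]))           # [1 1 0 0 0 0 0 0 0 0]
    print(build_action_mask([f"action {i}" for i in range(15)]))  # ten ones, length capped at 10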
