
Commit 243675b

Author: puyuan

fix(pu): fix lr target_model_update bug when accumulation_steps>1

1 parent 4a0d1c6 · commit 243675b

File tree: 3 files changed, +13 -13 lines changed


lzero/policy/unizero.py (+5 -6)

@@ -514,17 +514,16 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
 
             self._optimizer_world_model.step()
 
-            if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler:
-                self.lr_scheduler.step()
-
-            # Core target model update step
-            self._target_model.update(self._learn_model.state_dict())
-
             if self.accumulation_steps > 1:
                 torch.cuda.empty_cache()
         else:
             total_grad_norm_before_clip_wm = torch.tensor(0.)
 
+        if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler:
+            self.lr_scheduler.step()
+        # Core target model update step
+        self._target_model.update(self._learn_model.state_dict())
+
         if torch.cuda.is_available():
             torch.cuda.synchronize()
             current_memory_allocated = torch.cuda.memory_allocated()
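For context, a minimal sketch of the pattern this hunk switches to. This is not the actual LZero training loop; the function and argument names below are hypothetical stand-ins. With gradient accumulation, the optimizer only steps every `accumulation_steps` iterations, so anything nested under that condition is skipped on the other iterations; the diff moves the LR scheduler step and the target-model update outside that branch so they run once per learner iteration regardless of the accumulation setting.

import torch

def learn_step(iter_idx, loss, model, target_model, optimizer, scheduler,
               accumulation_steps=4, grad_clip=5.0):
    # Scale the loss so the accumulated gradient matches a full-batch update.
    (loss / accumulation_steps).backward()

    if (iter_idx + 1) % accumulation_steps == 0:
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        optimizer.zero_grad()
    else:
        grad_norm = torch.tensor(0.)  # no optimizer step on this iteration

    # Mirroring the diff above: these run on every learner iteration,
    # not only on iterations where the optimizer actually steps.
    scheduler.step()
    target_model.load_state_dict(model.state_dict())  # stand-in for self._target_model.update(...)
    return grad_norm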

zoo/jericho/configs/jericho_ppo_config.py (+6 -5)

@@ -1,21 +1,22 @@
 from easydict import EasyDict
 import torch.nn as nn
 
-action_space_size = 10
-max_steps = 50
+
 model_name = 'BAAI/bge-base-en-v1.5'
+evaluator_env_num = 2
+
 # env_id = 'detective.z5'
+action_space_size = 10
+max_steps = 50
 
+env_id = 'zork1.z5'
 action_space_size = 10
 max_steps = 400
-env_id = 'zork1.z5'
 
-evaluator_env_num = 2
 
 # proj train
 # collector_env_num = 18
 # batch_size = 320
-
 collector_env_num = 4
 batch_size = 32
 
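One reading of this reordering (mine, not a statement from the commit): the `action_space_size = 10` / `max_steps = 50` pair now sits next to the commented-out `detective.z5` setting, while the assignments under `env_id = 'zork1.z5'` come later and therefore take effect, since in a plain Python config module the last assignment wins. A toy illustration with the same variable names:

# Toy illustration: later module-level assignments override earlier ones.
action_space_size = 10
max_steps = 50      # grouped with the commented-out detective.z5 setting

env_id = 'zork1.z5'
action_space_size = 10
max_steps = 400     # this is the value the config actually uses

print(max_steps)    # -> 400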

zoo/jericho/envs/jericho_env.py (+2 -2)

@@ -108,7 +108,7 @@ def prepare_obs(self, obs, return_str: bool = False):
 
         action_mask = np.array(action_mask, dtype=np.int8)
 
-        if return_str:  # TODO: UniZero needs 'to_play' ===============
+        if return_str:  # TODO: UniZero needs 'to_play'; PPO must not include 'to_play' ===============
             return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
             # return {'observation': full_obs, 'action_mask': action_mask}
         else:

@@ -172,7 +172,7 @@ def step(self, action: int, return_str: bool = False):
             action_str = self._action_list[action]
         else:
             action_str = 'go'
-            print(f'rank {self.rank}, len(self._action_list) == 0, self._env.get_valid_actions():{self._env.get_valid_actions()}')
+            print(f"rank {self.rank}, len(self._action_list) == 0, self._env.get_valid_actions():{self._env.get_valid_actions()}, so we pass action_str='go'")
 
         # Record the previous observation
         if self.remove_stuck_actions and self.last_observation is not None:
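To make the updated TODO concrete, here is a small sketch of the two observation payloads it distinguishes. The helper below is hypothetical and not code from the repo; it only mirrors the dict shapes visible in the hunk above, where UniZero's pipeline expects a 'to_play' key (set to -1 for this single-player environment) and the PPO pipeline expects the dict without it.

import numpy as np

def build_obs(full_obs: str, action_mask: np.ndarray, for_unizero: bool) -> dict:
    """Hypothetical helper mirroring prepare_obs(): UniZero consumers get
    'to_play' (=-1 in single-player games), PPO consumers do not."""
    obs = {'observation': full_obs, 'action_mask': action_mask}
    if for_unizero:
        obs['to_play'] = -1
    return obs

mask = np.ones(10, dtype=np.int8)
print(build_obs('You are in a maze.', mask, for_unizero=True).keys())   # includes 'to_play'
print(build_obs('You are in a maze.', mask, for_unizero=False).keys())  # no 'to_play'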
