Skip to content

Commit b7c9295

Browse files
author
puyuan
committed
fix(pu): fix encoder dormant_ratio
1 parent ed9d7c1 commit b7c9295

File tree

3 files changed

+21
-16
lines changed

3 files changed

+21
-16
lines changed

lzero/model/unizero_world_models/world_model.py

+6 -6
Original file line number | Diff line number | Diff line change
@@ -98,12 +98,12 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
9898
self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size)
9999

100100
# 对于 head 部分,查找所有以 "head_" 开头的子模块
101-
self.head_modules = {}
101+
self.head_dict = {}
102102
for name, module in self.named_children():
103103
if name.startswith("head_"):
104-
self.head_modules[name] = module
105-
if self.head_modules:
106-
self.head_modules = nn.ModuleDict(self.head_modules)
104+
self.head_dict[name] = module
105+
if self.head_dict:
106+
self.head_dict = nn.ModuleDict(self.head_dict)
107107

108108
# Apply weight initialization, the order is important
109109
self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type))
@@ -1171,8 +1171,8 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
11711171
# 计算全局平均权重绝对值
11721172
avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer)
11731173
# print("Average Weight Magnitude of transformer:", avg_weight_mag_transformer)
1174-
# print(f"self.head_modules:{self.head_modules}")
1175-
avg_weight_mag_head = compute_average_weight_magnitude(self.head_modules)
1174+
# print(f"self.head_dict:{self.head_dict}")
1175+
avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict)
11761176
# print("Average Weight Magnitude of head:", avg_weight_mag_head)
11771177

11781178
# 计算 effective rank,对于 representation 层,注意:

lzero/model/utils.py

+11 -8
Original file line number | Diff line number | Diff line change
@@ -187,18 +187,19 @@ def cal_dormant_ratio(
187187
parts["transformer"] = model.transformer
188188

189189
# 对于 head 部分,查找所有以 "head_" 开头的子模块
190-
# head_modules = {}
190+
# head_dict = {}
191191
# for name, module in model.named_children():
192192
# if name.startswith("head_"):
193-
# head_modules[name] = module
194-
# if head_modules:
195-
# parts["head"] = nn.ModuleDict(head_modules)
193+
# head_dict[name] = module
194+
# if head_dict:
195+
# parts["head"] = nn.ModuleDict(head_dict)
196196

197-
if hasattr(model, "head_modules"):
198-
parts["head"] = model.head_modules
197+
if hasattr(model, "head_dict"):
198+
parts["head"] = model.head_dict
199199

200-
# if not hasattr(model, "encoder") and not hasattr(model, "transformer") and not hasattr(model, "head"):
201-
# parts["model"] = model
200+
if not hasattr(model, "encoder") and not hasattr(model, "transformer") and not hasattr(model, "head"):
201+
# 如果传入的是self.tokenizer.encoder
202+
parts["model"] = model
202203

203204
# 定义要捕获的目标模块类型 TODO: 增加更多模块
204205
target_modules = (nn.Conv2d, nn.Linear)
@@ -235,6 +236,8 @@ def cal_dormant_ratio(
235236
part_dormant = 0
236237
for full_name, hook in hooks:
237238
layer_total, layer_dormant = compute_dormant_stats(hook.outputs, dormant_threshold)
239+
# if part == "model":
240+
# print(hook.outputs)
238241
# 可打印日志,也可记录更详细信息
239242
# print(f"{full_name}: {layer_dormant}/{layer_total} -> {layer_dormant / layer_total * 100.0 if layer_total > 0 else 0.0}%")
240243
part_total += layer_total

zoo/atari/config/atari_unizero_segment_config.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -49,8 +49,8 @@ def main(env_id, seed):
4949
n_evaluator_episode=evaluator_env_num,
5050
manager=dict(shared_memory=False, ),
5151
# TODO: only for debug
52-
# collect_max_episode_steps=int(50),
53-
# eval_max_episode_steps=int(50),
52+
# collect_max_episode_steps=int(20),
53+
# eval_max_episode_steps=int(20),
5454
),
5555
policy=dict(
5656
learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000
@@ -102,6 +102,7 @@ def main(env_id, seed):
102102
num_simulations=num_simulations,
103103
num_segments=num_segments,
104104
td_steps=5,
105+
# train_start_after_envsteps=0, # only for debug
105106
train_start_after_envsteps=2000,
106107
game_segment_length=game_segment_length,
107108
grad_clip_value=5,
@@ -137,6 +138,7 @@ def main(env_id, seed):
137138

138139
# ============ use muzero_segment_collector instead of muzero_collector =============
139140
from lzero.entry import train_unizero_segment
141+
# TODO: only for debug
140142
main_config.exp_name = f'data_unizero_atari_st_lop/{env_id[:-14]}/{env_id[:-14]}_uz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
141143
train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
142144

0 commit comments

Comments
 (0)