
Commit ed9d7c1

Author: puyuan
feature(pu): add analysis_dormant_ratio_weight_rank option in single-task setting
Parent: 3a25c08

File tree

4 files changed: +214 -31 lines changed


lzero/model/unizero_world_models/world_model.py (+47 -4)
@@ -12,7 +12,7 @@
 from torch.distributions import TransformedDistribution, TanhTransform

 from lzero.model.common import SimNorm
-from lzero.model.utils import cal_dormant_ratio
+from lzero.model.utils import cal_dormant_ratio, compute_average_weight_magnitude, cal_effective_rank
 from .kv_caching import KeysValues
 from .slicer import Head, PolicyHeadCont
 from .tokenizer import Tokenizer
@@ -97,6 +97,14 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
         self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size)
         self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size)

+        # For the heads, collect every submodule whose name starts with "head_"
+        self.head_modules = {}
+        for name, module in self.named_children():
+            if name.startswith("head_"):
+                self.head_modules[name] = module
+        if self.head_modules:
+            self.head_modules = nn.ModuleDict(self.head_modules)
+
         # Apply weight initialization, the order is important
         self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type))
         self._initialize_last_layer()
@@ -259,7 +267,7 @@ def _initialize_config_parameters(self) -> None:
         self.gamma = self.config.gamma
         self.context_length = self.config.context_length
         self.dormant_threshold = self.config.dormant_threshold
-        self.analysis_dormant_ratio = self.config.analysis_dormant_ratio
+        self.analysis_dormant_ratio_weight_rank = self.config.analysis_dormant_ratio_weight_rank
         self.num_observations_tokens = self.config.tokens_per_block - 1
         self.latent_recon_loss_weight = self.config.latent_recon_loss_weight
         self.perceptual_loss_weight = self.config.perceptual_loss_weight
@@ -1149,18 +1157,43 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
         # self.save_as_image_with_timestep(batch['observations'], suffix='visual_match_memlen1-60-15_tsne')

         # ========= logging for analysis =========
-        if self.analysis_dormant_ratio:
+        if self.analysis_dormant_ratio_weight_rank:
             # Calculate dormant ratio of the encoder
             shape = batch['observations'].shape  # (..., C, H, W)
             inputs = batch['observations'].contiguous().view(-1, *shape[-3:])  # (32,5,3,64,64) -> (160,3,64,64)
             dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.encoder, inputs.detach(),
                                                       dormant_threshold=self.dormant_threshold)
             dormant_ratio_encoder = dormant_ratio_encoder['global']
+
+            # Compute the global average weight magnitude of each part of the model
+            avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder)
+            avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer)
+            avg_weight_mag_head = compute_average_weight_magnitude(self.head_modules)
+
+            # Compute the effective rank of the encoder's representation layers; the layer
+            # names must be resolvable via self.tokenizer.encoder.named_modules()
+            e_rank_last_linear = cal_effective_rank(self.tokenizer.encoder, inputs, representation_layer_name="last_linear")
+            e_rank_sim_norm = cal_effective_rank(self.tokenizer.encoder, inputs, representation_layer_name="sim_norm")
+
             self.past_kv_cache_recurrent_infer.clear()
             self.keys_values_wm_list.clear()
             torch.cuda.empty_cache()
         else:
             dormant_ratio_encoder = torch.tensor(0.)
+            avg_weight_mag_encoder = torch.tensor(0.)
+            avg_weight_mag_transformer = torch.tensor(0.)
+            avg_weight_mag_head = torch.tensor(0.)
+            e_rank_last_linear = torch.tensor(0.)
+            e_rank_sim_norm = torch.tensor(0.)

         # Calculate the L2 norm of the latent state roots
         latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean()
@@ -1228,7 +1261,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
         outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)})

         # ========= logging for analysis =========
-        if self.analysis_dormant_ratio:
+        if self.analysis_dormant_ratio_weight_rank:
             # Calculate dormant ratio of the world model
             dormant_ratio_world_model = cal_dormant_ratio(self, {
                 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())},
@@ -1396,6 +1429,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
             dormant_ratio_encoder=dormant_ratio_encoder,
             dormant_ratio_transformer=dormant_ratio_transformer,
             dormant_ratio_head=dormant_ratio_head,
+            avg_weight_mag_encoder=avg_weight_mag_encoder,
+            avg_weight_mag_transformer=avg_weight_mag_transformer,
+            avg_weight_mag_head=avg_weight_mag_head,
+            e_rank_last_linear=e_rank_last_linear,
+            e_rank_sim_norm=e_rank_sim_norm,
             latent_state_l2_norms=latent_state_l2_norms,
             policy_mu=mu,
             policy_sigma=sigma,
@@ -1419,6 +1457,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
             last_step_losses=last_step_losses,
             dormant_ratio_transformer=dormant_ratio_transformer,
             dormant_ratio_head=dormant_ratio_head,
+            avg_weight_mag_encoder=avg_weight_mag_encoder,
+            avg_weight_mag_transformer=avg_weight_mag_transformer,
+            avg_weight_mag_head=avg_weight_mag_head,
+            e_rank_last_linear=e_rank_last_linear,
+            e_rank_sim_norm=e_rank_sim_norm,
             latent_state_l2_norms=latent_state_l2_norms,
         )

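Note: the "head_*" grouping added to __init__ above is what lets compute_average_weight_magnitude and cal_dormant_ratio (below) treat all prediction heads as a single module. A minimal, self-contained sketch of the pattern; ToyModel and its layers are hypothetical stand-ins, not part of this commit:

import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head_policy = nn.Linear(8, 4)
        self.head_value = nn.Linear(8, 1)
        # Collect every child whose name starts with "head_" into one ModuleDict
        heads = {name: m for name, m in self.named_children() if name.startswith("head_")}
        if heads:
            self.head_modules = nn.ModuleDict(heads)

model = ToyModel()
print(list(model.head_modules.keys()))  # ['head_policy', 'head_value']

The ModuleDict holds references to the existing children rather than copies, and PyTorch deduplicates shared parameters in model.parameters(), so the heads are not double-counted.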
lzero/model/utils.py (+129 -10)
@@ -9,6 +9,54 @@
 import torch
 import torch.nn as nn

+###############################
+# 1. Compute average_weight_magnitude
+###############################
+def compute_average_weight_magnitude(model: nn.Module) -> float:
+    """
+    Compute the average absolute value of all parameters in the model.
+
+    Arguments:
+        model: the model to evaluate, an nn.Module.
+
+    Returns:
+        The average weight magnitude (float).
+    """
+    num_weights = 0
+    # Use the device of the model's first parameter so the computation stays on one device
+    device = next(model.parameters()).device
+    sum_weight_magnitude = torch.tensor(0.0, device=device)
+
+    for p in model.parameters():
+        num_weights += p.numel()
+        sum_weight_magnitude += torch.sum(torch.abs(p))
+
+    if num_weights == 0:
+        return 0.0
+    return sum_weight_magnitude.cpu().item() / num_weights
+
+###############################
+# 2. Compute effective_rank
+###############################
+def compute_effective_rank(singular_values: np.ndarray) -> float:
+    """
+    Compute the effective rank from an array of singular values, using:
+        effective_rank = exp( - sum_i [p_i * log(p_i)] )
+    where the p_i are the normalized singular values (p_i = s_i / sum_j s_j).
+
+    Arguments:
+        singular_values: the array of singular values, an np.ndarray.
+
+    Returns:
+        The effective rank (float).
+    """
+    norm_sv = singular_values / np.sum(np.abs(singular_values))
+    entropy = 0.0
+    for p in norm_sv:
+        if p > 0.0:
+            entropy -= p * np.log(p)
+    return np.e ** entropy
+

 # A hook class used to capture the outputs of intermediate layers
 class IntermediateOutputHook:
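A quick sanity check on the effective-rank formula above: when a matrix has k equal nonzero singular values, the normalized values are 1/k each, the entropy is log(k), and the effective rank is exactly k. A standalone numeric sketch (hypothetical values, not part of the commit):

import numpy as np

sv = np.array([2.0, 2.0, 2.0, 0.0])    # k = 3 equal nonzero singular values
p = sv / np.sum(np.abs(sv))            # normalized: [1/3, 1/3, 1/3, 0]
entropy = -sum(x * np.log(x) for x in p if x > 0.0)
print(np.e ** entropy)                 # ~3.0, i.e. effective rank 3

Conversely, one dominant singular value drives the effective rank toward 1, which is why a falling e_rank_* curve indicates collapsing diversity in the latent representation.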
@@ -22,6 +70,73 @@ def __call__(self, module: nn.Module, input: Tuple[torch.Tensor], output: torch.
         # Detach here to avoid interfering with backpropagation, and move to CPU for later statistics
         self.outputs.append(output.detach().cpu())

+def cal_effective_rank(
+        model: nn.Module,
+        inputs: Union[torch.Tensor, List[torch.Tensor]],
+        representation_layer_name: str,
+) -> float:
+    """
+    Capture the output of a specified intermediate layer (the representation layer)
+    of the model via a hook, and compute its effective rank.
+
+    Arguments:
+        model: the model to evaluate, an nn.Module.
+        inputs: the inputs for model.forward, a tensor or a list of tensors.
+        representation_layer_name: the name of the representation layer in the model;
+            it must be resolvable via model.named_modules().
+
+    Returns:
+        The effective rank (float).
+    """
+    # Look up the representation-layer module (raise KeyError if the name is missing)
+    module_dict = dict(model.named_modules())
+    if representation_layer_name not in module_dict:
+        raise KeyError(f"Representation layer '{representation_layer_name}' not found in model.named_modules().")
+    representation_module = module_dict[representation_layer_name]
+
+    # Register the hook
+    hook = IntermediateOutputHook()
+    handle = representation_module.register_forward_hook(hook)
+
+    # Run the forward pass
+    model.eval()
+    with torch.no_grad():
+        if isinstance(inputs, (list, tuple)):
+            _ = model(*inputs)
+        else:
+            _ = model(inputs)
+
+    # Remove the hook to avoid memory leaks
+    handle.remove()
+
+    if not hook.outputs:
+        raise RuntimeError("No outputs captured from the representation layer.")
+
+    # One or more forward passes may have been captured (e.g. batched or repeated calls);
+    # concatenate all captured outputs along the batch dimension
+    if len(hook.outputs) > 1:
+        rep_tensor = torch.cat(hook.outputs, dim=0)
+    else:
+        rep_tensor = hook.outputs[0]
+
+    # Flatten the representation into a 2D matrix: (samples, features)
+    rep_tensor = rep_tensor.view(rep_tensor.size(0), -1)
+
+    # Convert to a numpy array for numpy.linalg.svd
+    rep_np = rep_tensor.cpu().numpy()
+
+    # Compute the singular values
+    singular_values = np.linalg.svd(rep_np, full_matrices=False, compute_uv=False)
+
+    # Compute the effective rank
+    e_rank = compute_effective_rank(singular_values)
+
+    # Clear the hook storage (keeps state clean across repeated calls)
+    hook.outputs.clear()
+    return e_rank
+
+
 def compute_dormant_stats(outputs: List[torch.Tensor], threshold: float) -> Tuple[int, int]:
     """
     Element-wise statistics over a group of outputs (the same layer may run forward multiple times).
@@ -70,18 +185,22 @@ def cal_dormant_ratio(
         parts["encoder"] = model.encoder
     if hasattr(model, "transformer"):
         parts["transformer"] = model.transformer
+
     # For the head part, use the "head_*" submodules collected by the model itself
-    head_modules = {}
-    for name, module in model.named_children():
-        if name.startswith("head_"):
-            head_modules[name] = module
+    if hasattr(model, "head_modules"):
+        parts["head"] = model.head_modules

-    if head_modules:
-        parts["head"] = nn.ModuleDict(head_modules)
-    if not hasattr(model, "encoder") and not hasattr(model, "transformer") and not hasattr(model, "head"):
-        parts["model"] = model

-    # Target module types to capture
+    # Target module types to capture  TODO: add more module types
     target_modules = (nn.Conv2d, nn.Linear)

     # Store the hooks of each part (dict: part name -> list of (module_name, hook))
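A behavioral consequence of this hunk worth noting: cal_dormant_ratio no longer wraps an arbitrary model as a fallback, so only modules exposing encoder, transformer, or head_modules attributes contribute parts to the analysis. A hedged sketch of just the selection logic (select_parts is a toy name, not the library API):

import torch.nn as nn

def select_parts(model: nn.Module) -> dict:
    # Mirrors the attribute-based selection above
    parts = {}
    if hasattr(model, "encoder"):
        parts["encoder"] = model.encoder
    if hasattr(model, "transformer"):
        parts["transformer"] = model.transformer
    if hasattr(model, "head_modules"):
        parts["head"] = model.head_modules
    return parts

print(select_parts(nn.Linear(4, 4)))  # {} -> no part is hooked for such a model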
@@ -117,7 +236,7 @@ def cal_dormant_ratio(
         for full_name, hook in hooks:
             layer_total, layer_dormant = compute_dormant_stats(hook.outputs, dormant_threshold)
             # A per-layer log could be printed here, or more detailed statistics recorded
-            print(f"{full_name}: {layer_dormant}/{layer_total} -> {layer_dormant / layer_total * 100.0 if layer_total > 0 else 0.0}%")
+            # print(f"{full_name}: {layer_dormant}/{layer_total} -> {layer_dormant / layer_total * 100.0 if layer_total > 0 else 0.0}%")
             part_total += layer_total
             part_dormant += layer_dormant
         if part_total > 0:

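Putting the two new utilities together: a hedged usage sketch on a toy encoder (TinyEncoder and the layer name "last_linear" are illustrative stand-ins; the import path is the one this commit adds to world_model.py):

import torch
import torch.nn as nn
from lzero.model.utils import compute_average_weight_magnitude, cal_effective_rank

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.last_linear = nn.Linear(8 * 64 * 64, 32)

    def forward(self, x):
        h = torch.relu(self.conv(x))
        return self.last_linear(h.flatten(1))

encoder = TinyEncoder()
x = torch.randn(16, 3, 64, 64)

avg_mag = compute_average_weight_magnitude(encoder)  # one float over all parameters
e_rank = cal_effective_rank(encoder, x, representation_layer_name="last_linear")
print(avg_mag, e_rank)

The effective rank is bounded by min(batch size, feature dimension), so a batch that is too small caps the measurable rank; here e_rank is at most 16.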
lzero/policy/unizero.py (+25 -5)
@@ -80,8 +80,8 @@ class UniZeroPolicy(MuZeroPolicy):
             device='cpu',
             # (bool) Whether to analyze simulation normalization.
             analysis_sim_norm=False,
-            # (bool) Whether to analyze dormant ratio.
-            analysis_dormant_ratio=False,
+            # (bool) Whether to analyze the dormant ratio, the average weight magnitude of the networks, and the effective rank of the latent.
+            analysis_dormant_ratio_weight_rank=False,
             # (int) The shape of the action space.
             action_space_size=6,
             # (int) The size of the group, related to simulation normalization.
@@ -119,7 +119,7 @@ class UniZeroPolicy(MuZeroPolicy):
             # (float) The discount factor for future rewards.
             gamma=1,
             # (float) The threshold for a dormant neuron.
-            dormant_threshold=0.025,
+            dormant_threshold=0.01,
         ),
     ),
     # ****** common ******
@@ -415,8 +415,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         )

         weighted_total_loss = losses.loss_total
-        for loss_name, loss_value in losses.intermediate_losses.items():
-            self.intermediate_losses[f"{loss_name}"] = loss_value
+        # Merge the intermediate_losses dict in one step instead of assigning entry by entry
+        self.intermediate_losses.update(losses.intermediate_losses)

         obs_loss = self.intermediate_losses['loss_obs']
         reward_loss = self.intermediate_losses['loss_rewards']
@@ -432,6 +435,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder']
         dormant_ratio_transformer = self.intermediate_losses['dormant_ratio_transformer']
         dormant_ratio_head = self.intermediate_losses['dormant_ratio_head']
+        avg_weight_mag_encoder = self.intermediate_losses['avg_weight_mag_encoder']
+        avg_weight_mag_transformer = self.intermediate_losses['avg_weight_mag_transformer']
+        avg_weight_mag_head = self.intermediate_losses['avg_weight_mag_head']
+        e_rank_last_linear = self.intermediate_losses['e_rank_last_linear']
+        e_rank_sim_norm = self.intermediate_losses['e_rank_sim_norm']
         latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms']

         assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
@@ -515,6 +523,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
             'analysis/dormant_ratio_transformer': dormant_ratio_transformer,#.item(),
             'analysis/dormant_ratio_head': dormant_ratio_head,#.item(),

+            'analysis/avg_weight_mag_encoder': avg_weight_mag_encoder,
+            'analysis/avg_weight_mag_transformer': avg_weight_mag_transformer,
+            'analysis/avg_weight_mag_head': avg_weight_mag_head,
+            'analysis/e_rank_last_linear': e_rank_last_linear,
+            'analysis/e_rank_sim_norm': e_rank_sim_norm,
+
             'analysis/latent_state_l2_norms': latent_state_l2_norms.item(),
             'analysis/l2_norm_before': self.l2_norm_before,
             'analysis/l2_norm_after': self.l2_norm_after,
@@ -896,6 +910,12 @@ def _monitor_vars_learn(self) -> List[str]:
             'analysis/dormant_ratio_transformer',
             'analysis/dormant_ratio_head',

+            'analysis/avg_weight_mag_encoder',
+            'analysis/avg_weight_mag_transformer',
+            'analysis/avg_weight_mag_head',
+            'analysis/e_rank_last_linear',
+            'analysis/e_rank_sim_norm',
+
             'analysis/latent_state_l2_norms',
             'analysis/l2_norm_before',
             'analysis/l2_norm_after',

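To actually produce these curves in a single-task run, the new flag has to be switched on in the policy config. A hypothetical excerpt (the nesting follows the default_config shown above, but field placement in a real experiment config may differ):

from easydict import EasyDict

unizero_config = EasyDict(dict(
    policy=dict(
        model=dict(
            world_model_cfg=dict(
                analysis_dormant_ratio_weight_rank=True,  # enable the new diagnostics
                dormant_threshold=0.01,  # default updated by this commit
            ),
        ),
    ),
))

With the flag enabled, the avg_weight_mag_* and e_rank_* series appear alongside the existing analysis/dormant_ratio_* entries registered in _monitor_vars_learn.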