from torch.distributions import TransformedDistribution, TanhTransform
from lzero.model.common import SimNorm
- from lzero.model.utils import cal_dormant_ratio
+ from lzero.model.utils import cal_dormant_ratio, compute_average_weight_magnitude, cal_effective_rank
from .kv_caching import KeysValues
from .slicer import Head, PolicyHeadCont
from .tokenizer import Tokenizer
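The two new imports are analysis utilities from `lzero.model.utils`. Their exact implementations live in that module; as a rough mental model, `compute_average_weight_magnitude` reduces to a mean over absolute parameter values, per the minimal sketch below (an illustrative assumption, not the library code; `cal_effective_rank` is sketched further down, next to its call site):

```python
import torch
import torch.nn as nn

def average_weight_magnitude(model: nn.Module) -> torch.Tensor:
    """Mean absolute value over all parameters of a module (illustrative sketch)."""
    total_abs, total_count = 0.0, 0
    for param in model.parameters():
        total_abs += param.detach().abs().sum().item()
        total_count += param.numel()
    # Guard against parameter-less modules.
    return torch.tensor(total_abs / max(total_count, 1))
```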
@@ -97,6 +97,14 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size)
self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size)

+ # Collect every submodule whose name starts with "head_" so the heads can be
+ # analyzed as a single group.
+ self.head_modules = {}
+ for name, module in self.named_children():
+     if name.startswith("head_"):
+         self.head_modules[name] = module
+ if self.head_modules:
+     self.head_modules = nn.ModuleDict(self.head_modules)
+
# Apply weight initialization, the order is important
self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type))
self._initialize_last_layer()
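Wrapping the collected heads in an `nn.ModuleDict` gives the analysis helpers a single module handle over all heads, and PyTorch deduplicates shared parameters during traversal, so the heads are not double-counted. A self-contained toy illustration of the pattern (class and layer names are made up):

```python
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head_policy = nn.Linear(8, 4)
        self.head_value = nn.Linear(8, 1)
        # Same pattern as in the diff: gather every "head_*" child into one container.
        heads = {name: mod for name, mod in self.named_children()
                 if name.startswith("head_")}
        if heads:
            self.head_modules = nn.ModuleDict(heads)

model = ToyModel()
print(sorted(model.head_modules.keys()))  # ['head_policy', 'head_value']
print(len(list(model.parameters())))      # 6, not 10: shared parameters deduplicated
```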
@@ -259,7 +267,7 @@ def _initialize_config_parameters(self) -> None:
self.gamma = self.config.gamma
self.context_length = self.config.context_length
self.dormant_threshold = self.config.dormant_threshold
- self.analysis_dormant_ratio = self.config.analysis_dormant_ratio
+ self.analysis_dormant_ratio_weight_rank = self.config.analysis_dormant_ratio_weight_rank
self.num_observations_tokens = self.config.tokens_per_block - 1
self.latent_recon_loss_weight = self.config.latent_recon_loss_weight
self.perceptual_loss_weight = self.config.perceptual_loss_weight
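The rename means existing configs that set `analysis_dormant_ratio` must switch to the new key; a hedged sketch of the corresponding config entry (only the renamed flag is confirmed by this diff, the other field is a placeholder):

```python
# Hypothetical excerpt of the world-model config.
world_model_cfg = dict(
    dormant_threshold=0.025,                  # assumed value
    analysis_dormant_ratio_weight_rank=True,  # was: analysis_dormant_ratio
)
```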
@@ -1149,18 +1157,43 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
# self.save_as_image_with_timestep(batch['observations'], suffix='visual_match_memlen1-60-15_tsne')

# ========= logging for analysis =========
- if self.analysis_dormant_ratio:
+ if self.analysis_dormant_ratio_weight_rank:
    # Calculate dormant ratio of the encoder
    shape = batch['observations'].shape  # (..., C, H, W)
    inputs = batch['observations'].contiguous().view(-1, *shape[-3:])  # (32,5,3,64,64) -> (160,3,64,64)
    dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.encoder, inputs.detach(),
                                              dormant_threshold=self.dormant_threshold)
    dormant_ratio_encoder = dormant_ratio_encoder['global']
+
+     # Compute the global average absolute weight magnitude of each component.
+     avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder)
+     avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer)
+     avg_weight_mag_head = compute_average_weight_magnitude(self.head_modules)
+
+     # Compute the effective rank of the encoder's representation layers; within
+     # self.tokenizer.encoder.named_modules() they are named "last_linear" and "sim_norm".
+     e_rank_last_linear = cal_effective_rank(self.tokenizer.encoder, inputs, representation_layer_name="last_linear")
+     e_rank_sim_norm = cal_effective_rank(self.tokenizer.encoder, inputs, representation_layer_name="sim_norm")
+
    self.past_kv_cache_recurrent_infer.clear()
    self.keys_values_wm_list.clear()
    torch.cuda.empty_cache()
else:
    dormant_ratio_encoder = torch.tensor(0.)
+     avg_weight_mag_encoder = torch.tensor(0.)
+     avg_weight_mag_transformer = torch.tensor(0.)
+     avg_weight_mag_head = torch.tensor(0.)
+     e_rank_last_linear = torch.tensor(0.)
+     e_rank_sim_norm = torch.tensor(0.)

# Calculate the L2 norm of the latent state roots
latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean()
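`cal_effective_rank` takes the encoder, a batch of inputs, and the name of an intermediate layer (`"last_linear"` or `"sim_norm"` above). A plausible mechanism, shown purely as an assumption about the helper in `lzero.model.utils`, is a forward hook that captures that layer's output and computes the entropy-based effective rank (Roy & Vetterli, 2007) of the resulting feature matrix:

```python
import torch
import torch.nn as nn

def effective_rank_of_layer(model: nn.Module, inputs: torch.Tensor,
                            representation_layer_name: str) -> torch.Tensor:
    """Illustrative stand-in for cal_effective_rank; assumes the named layer
    outputs a single tensor."""
    captured = {}
    layer = dict(model.named_modules())[representation_layer_name]
    hook = layer.register_forward_hook(
        lambda module, args, output: captured.update(feat=output.detach()))
    with torch.no_grad():
        model(inputs)
    hook.remove()
    feats = captured["feat"].flatten(1)            # (batch, feature_dim)
    singular_values = torch.linalg.svdvals(feats)
    p = singular_values / singular_values.sum()
    # Effective rank = exp(Shannon entropy of the normalized singular values).
    return torch.exp(-(p * torch.log(p + 1e-12)).sum())
```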
@@ -1228,7 +1261,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)})

# ========= logging for analysis =========
- if self.analysis_dormant_ratio:
+ if self.analysis_dormant_ratio_weight_rank:
    # Calculate dormant ratio of the world model
    dormant_ratio_world_model = cal_dormant_ratio(self, {
        'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())},
@@ -1396,6 +1429,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
dormant_ratio_encoder=dormant_ratio_encoder,
dormant_ratio_transformer=dormant_ratio_transformer,
dormant_ratio_head=dormant_ratio_head,
+ avg_weight_mag_encoder=avg_weight_mag_encoder,
+ avg_weight_mag_transformer=avg_weight_mag_transformer,
+ avg_weight_mag_head=avg_weight_mag_head,
+ e_rank_last_linear=e_rank_last_linear,
+ e_rank_sim_norm=e_rank_sim_norm,
latent_state_l2_norms=latent_state_l2_norms,
policy_mu=mu,
policy_sigma=sigma,
@@ -1419,6 +1457,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
last_step_losses=last_step_losses,
dormant_ratio_transformer=dormant_ratio_transformer,
dormant_ratio_head=dormant_ratio_head,
+ avg_weight_mag_encoder=avg_weight_mag_encoder,
+ avg_weight_mag_transformer=avg_weight_mag_transformer,
+ avg_weight_mag_head=avg_weight_mag_head,
+ e_rank_last_linear=e_rank_last_linear,
+ e_rank_sim_norm=e_rank_sim_norm,
latent_state_l2_norms=latent_state_l2_norms,
)